diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) | |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/rust/tokio/src | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/tokio/src')
287 files changed, 71297 insertions, 0 deletions
diff --git a/third_party/rust/tokio/src/blocking.rs b/third_party/rust/tokio/src/blocking.rs new file mode 100644 index 0000000000..f172399d5e --- /dev/null +++ b/third_party/rust/tokio/src/blocking.rs @@ -0,0 +1,63 @@ +cfg_rt! { + pub(crate) use crate::runtime::spawn_blocking; + + cfg_fs! { + #[allow(unused_imports)] + pub(crate) use crate::runtime::spawn_mandatory_blocking; + } + + pub(crate) use crate::task::JoinHandle; +} + +cfg_not_rt! { + use std::fmt; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + pub(crate) fn spawn_blocking<F, R>(_f: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + assert_send_sync::<JoinHandle<std::cell::Cell<()>>>(); + panic!("requires the `rt` Tokio feature flag") + } + + cfg_fs! { + pub(crate) fn spawn_mandatory_blocking<F, R>(_f: F) -> Option<JoinHandle<R>> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + panic!("requires the `rt` Tokio feature flag") + } + } + + pub(crate) struct JoinHandle<R> { + _p: std::marker::PhantomData<R>, + } + + unsafe impl<T: Send> Send for JoinHandle<T> {} + unsafe impl<T: Send> Sync for JoinHandle<T> {} + + impl<R> Future for JoinHandle<R> { + type Output = Result<R, std::io::Error>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + unreachable!() + } + } + + impl<T> fmt::Debug for JoinHandle<T> + where + T: fmt::Debug, + { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("JoinHandle").finish() + } + } + + fn assert_send_sync<T: Send + Sync>() { + } +} diff --git a/third_party/rust/tokio/src/coop.rs b/third_party/rust/tokio/src/coop.rs new file mode 100644 index 0000000000..96905319b6 --- /dev/null +++ b/third_party/rust/tokio/src/coop.rs @@ -0,0 +1,284 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + +//! Yield points for improved cooperative scheduling. +//! +//! Documentation for this can be found in the [`tokio::task`] module. +//! +//! [`tokio::task`]: crate::task. + +// ```ignore +// # use tokio_stream::{Stream, StreamExt}; +// async fn drop_all<I: Stream + Unpin>(mut input: I) { +// while let Some(_) = input.next().await { +// tokio::coop::proceed().await; +// } +// } +// ``` +// +// The `proceed` future will coordinate with the executor to make sure that +// every so often control is yielded back to the executor so it can run other +// tasks. +// +// # Placing yield points +// +// Voluntary yield points should be placed _after_ at least some work has been +// done. If they are not, a future sufficiently deep in the task hierarchy may +// end up _never_ getting to run because of the number of yield points that +// inevitably appear before it is reached. In general, you will want yield +// points to only appear in "leaf" futures -- those that do not themselves poll +// other futures. By doing this, you avoid double-counting each iteration of +// the outer future against the cooperating budget. + +use std::cell::Cell; + +thread_local! { + static CURRENT: Cell<Budget> = Cell::new(Budget::unconstrained()); +} + +/// Opaque type tracking the amount of "work" a task may still do before +/// yielding back to the scheduler. +#[derive(Debug, Copy, Clone)] +pub(crate) struct Budget(Option<u8>); + +impl Budget { + /// Budget assigned to a task on each poll. + /// + /// The value itself is chosen somewhat arbitrarily. It needs to be high + /// enough to amortize wakeup and scheduling costs, but low enough that we + /// do not starve other tasks for too long. The value also needs to be high + /// enough that particularly deep tasks are able to do at least some useful + /// work at all. + /// + /// Note that as more yield points are added in the ecosystem, this value + /// will probably also have to be raised. + const fn initial() -> Budget { + Budget(Some(128)) + } + + /// Returns an unconstrained budget. Operations will not be limited. + const fn unconstrained() -> Budget { + Budget(None) + } + + fn has_remaining(self) -> bool { + self.0.map(|budget| budget > 0).unwrap_or(true) + } +} + +/// Runs the given closure with a cooperative task budget. When the function +/// returns, the budget is reset to the value prior to calling the function. +#[inline(always)] +pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R { + with_budget(Budget::initial(), f) +} + +/// Runs the given closure with an unconstrained task budget. When the function returns, the budget +/// is reset to the value prior to calling the function. +#[inline(always)] +pub(crate) fn with_unconstrained<R>(f: impl FnOnce() -> R) -> R { + with_budget(Budget::unconstrained(), f) +} + +#[inline(always)] +fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R { + struct ResetGuard<'a> { + cell: &'a Cell<Budget>, + prev: Budget, + } + + impl<'a> Drop for ResetGuard<'a> { + fn drop(&mut self) { + self.cell.set(self.prev); + } + } + + CURRENT.with(move |cell| { + let prev = cell.get(); + + cell.set(budget); + + let _guard = ResetGuard { cell, prev }; + + f() + }) +} + +#[inline(always)] +pub(crate) fn has_budget_remaining() -> bool { + CURRENT.with(|cell| cell.get().has_remaining()) +} + +cfg_rt_multi_thread! { + /// Sets the current task's budget. + pub(crate) fn set(budget: Budget) { + CURRENT.with(|cell| cell.set(budget)) + } +} + +cfg_rt! { + /// Forcibly removes the budgeting constraints early. + /// + /// Returns the remaining budget + pub(crate) fn stop() -> Budget { + CURRENT.with(|cell| { + let prev = cell.get(); + cell.set(Budget::unconstrained()); + prev + }) + } +} + +cfg_coop! { + use std::task::{Context, Poll}; + + #[must_use] + pub(crate) struct RestoreOnPending(Cell<Budget>); + + impl RestoreOnPending { + pub(crate) fn made_progress(&self) { + self.0.set(Budget::unconstrained()); + } + } + + impl Drop for RestoreOnPending { + fn drop(&mut self) { + // Don't reset if budget was unconstrained or if we made progress. + // They are both represented as the remembered budget being unconstrained. + let budget = self.0.get(); + if !budget.is_unconstrained() { + CURRENT.with(|cell| { + cell.set(budget); + }); + } + } + } + + /// Returns `Poll::Pending` if the current task has exceeded its budget and should yield. + /// + /// When you call this method, the current budget is decremented. However, to ensure that + /// progress is made every time a task is polled, the budget is automatically restored to its + /// former value if the returned `RestoreOnPending` is dropped. It is the caller's + /// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure + /// that the budget empties appropriately. + /// + /// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**. + /// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and + /// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates + /// that progress was made. + #[inline] + pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> { + CURRENT.with(|cell| { + let mut budget = cell.get(); + + if budget.decrement() { + let restore = RestoreOnPending(Cell::new(cell.get())); + cell.set(budget); + Poll::Ready(restore) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + } + + impl Budget { + /// Decrements the budget. Returns `true` if successful. Decrementing fails + /// when there is not enough remaining budget. + fn decrement(&mut self) -> bool { + if let Some(num) = &mut self.0 { + if *num > 0 { + *num -= 1; + true + } else { + false + } + } else { + true + } + } + + fn is_unconstrained(self) -> bool { + self.0.is_none() + } + } +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::wasm_bindgen_test as test; + + fn get() -> Budget { + CURRENT.with(|cell| cell.get()) + } + + #[test] + fn bugeting() { + use futures::future::poll_fn; + use tokio_test::*; + + assert!(get().0.is_none()); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + + assert!(get().0.is_none()); + drop(coop); + assert!(get().0.is_none()); + + budget(|| { + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + drop(coop); + // we didn't make progress + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + // we _did_ make progress + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + + budget(|| { + assert_eq!(get().0, Budget::initial().0); + + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + coop.made_progress(); + drop(coop); + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1); + }); + + assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2); + }); + + assert!(get().0.is_none()); + + budget(|| { + let n = get().0.unwrap(); + + for _ in 0..n { + let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx))); + coop.made_progress(); + } + + let mut task = task::spawn(poll_fn(|cx| { + let coop = ready!(poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + })); + + assert_pending!(task.poll()); + }); + } +} diff --git a/third_party/rust/tokio/src/doc/mod.rs b/third_party/rust/tokio/src/doc/mod.rs new file mode 100644 index 0000000000..3a94934490 --- /dev/null +++ b/third_party/rust/tokio/src/doc/mod.rs @@ -0,0 +1,24 @@ +//! Types which are documented locally in the Tokio crate, but does not actually +//! live here. +//! +//! **Note** this module is only visible on docs.rs, you cannot use it directly +//! in your own code. + +/// The name of a type which is not defined here. +/// +/// This is typically used as an alias for another type, like so: +/// +/// ```rust,ignore +/// /// See [some::other::location](https://example.com). +/// type DEFINED_ELSEWHERE = crate::doc::NotDefinedHere; +/// ``` +/// +/// This type is uninhabitable like the [`never` type] to ensure that no one +/// will ever accidentally use it. +/// +/// [`never` type]: https://doc.rust-lang.org/std/primitive.never.html +#[derive(Debug)] +pub enum NotDefinedHere {} + +pub mod os; +pub mod winapi; diff --git a/third_party/rust/tokio/src/doc/os.rs b/third_party/rust/tokio/src/doc/os.rs new file mode 100644 index 0000000000..0ddf86959b --- /dev/null +++ b/third_party/rust/tokio/src/doc/os.rs @@ -0,0 +1,26 @@ +//! See [std::os](https://doc.rust-lang.org/std/os/index.html). + +/// Platform-specific extensions to `std` for Windows. +/// +/// See [std::os::windows](https://doc.rust-lang.org/std/os/windows/index.html). +pub mod windows { + /// Windows-specific extensions to general I/O primitives. + /// + /// See [std::os::windows::io](https://doc.rust-lang.org/std/os/windows/io/index.html). + pub mod io { + /// See [std::os::windows::io::RawHandle](https://doc.rust-lang.org/std/os/windows/io/type.RawHandle.html) + pub type RawHandle = crate::doc::NotDefinedHere; + + /// See [std::os::windows::io::AsRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html) + pub trait AsRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html#tymethod.as_raw_handle) + fn as_raw_handle(&self) -> RawHandle; + } + + /// See [std::os::windows::io::FromRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html) + pub trait FromRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html#tymethod.from_raw_handle) + unsafe fn from_raw_handle(handle: RawHandle) -> Self; + } + } +} diff --git a/third_party/rust/tokio/src/doc/winapi.rs b/third_party/rust/tokio/src/doc/winapi.rs new file mode 100644 index 0000000000..be68749e00 --- /dev/null +++ b/third_party/rust/tokio/src/doc/winapi.rs @@ -0,0 +1,66 @@ +//! See [winapi]. +//! +//! [winapi]: https://docs.rs/winapi + +/// See [winapi::shared](https://docs.rs/winapi/*/winapi/shared/index.html). +pub mod shared { + /// See [winapi::shared::winerror](https://docs.rs/winapi/*/winapi/shared/winerror/index.html). + #[allow(non_camel_case_types)] + pub mod winerror { + /// See [winapi::shared::winerror::ERROR_ACCESS_DENIED][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_ACCESS_DENIED.html + pub type ERROR_ACCESS_DENIED = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_PIPE_BUSY][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_PIPE_BUSY.html + pub type ERROR_PIPE_BUSY = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_MORE_DATA][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_MORE_DATA.html + pub type ERROR_MORE_DATA = crate::doc::NotDefinedHere; + } +} + +/// See [winapi::um](https://docs.rs/winapi/*/winapi/um/index.html). +pub mod um { + /// See [winapi::um::winbase](https://docs.rs/winapi/*/winapi/um/winbase/index.html). + #[allow(non_camel_case_types)] + pub mod winbase { + /// See [winapi::um::winbase::PIPE_TYPE_MESSAGE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_MESSAGE.html + pub type PIPE_TYPE_MESSAGE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_TYPE_BYTE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_BYTE.html + pub type PIPE_TYPE_BYTE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_CLIENT_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_CLIENT_END.html + pub type PIPE_CLIENT_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_SERVER_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_SERVER_END.html + pub type PIPE_SERVER_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::SECURITY_IDENTIFICATION][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.SECURITY_IDENTIFICATION.html + pub type SECURITY_IDENTIFICATION = crate::doc::NotDefinedHere; + } + + /// See [winapi::um::minwinbase](https://docs.rs/winapi/*/winapi/um/minwinbase/index.html). + #[allow(non_camel_case_types)] + pub mod minwinbase { + /// See [winapi::um::minwinbase::SECURITY_ATTRIBUTES][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/minwinbase/constant.SECURITY_ATTRIBUTES.html + pub type SECURITY_ATTRIBUTES = crate::doc::NotDefinedHere; + } +} diff --git a/third_party/rust/tokio/src/fs/canonicalize.rs b/third_party/rust/tokio/src/fs/canonicalize.rs new file mode 100644 index 0000000000..403662685c --- /dev/null +++ b/third_party/rust/tokio/src/fs/canonicalize.rs @@ -0,0 +1,51 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::{Path, PathBuf}; + +/// Returns the canonical, absolute form of a path with all intermediate +/// components normalized and symbolic links resolved. +/// +/// This is an async version of [`std::fs::canonicalize`][std] +/// +/// [std]: std::fs::canonicalize +/// +/// # Platform-specific behavior +/// +/// This function currently corresponds to the `realpath` function on Unix +/// and the `CreateFile` and `GetFinalPathNameByHandle` functions on Windows. +/// Note that, this [may change in the future][changes]. +/// +/// On Windows, this converts the path to use [extended length path][path] +/// syntax, which allows your program to use longer path names, but means you +/// can only join backslash-delimited paths to it, and it may be incompatible +/// with other applications (if passed to the application on the command-line, +/// or written to a file another application may read). +/// +/// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior +/// [path]: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247(v=vs.85).aspx#maxpath +/// +/// # Errors +/// +/// This function will return an error in the following situations, but is not +/// limited to just these cases: +/// +/// * `path` does not exist. +/// * A non-final component in path is not a directory. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let path = fs::canonicalize("../a/../foo.txt").await?; +/// Ok(()) +/// } +/// ``` +pub async fn canonicalize(path: impl AsRef<Path>) -> io::Result<PathBuf> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::canonicalize(path)).await +} diff --git a/third_party/rust/tokio/src/fs/copy.rs b/third_party/rust/tokio/src/fs/copy.rs new file mode 100644 index 0000000000..b47f287285 --- /dev/null +++ b/third_party/rust/tokio/src/fs/copy.rs @@ -0,0 +1,27 @@ +use crate::fs::asyncify; +use std::path::Path; + +/// Copies the contents of one file to another. This function will also copy the permission bits +/// of the original file to the destination file. +/// This function will overwrite the contents of to. +/// +/// This is the async equivalent of [`std::fs::copy`][std]. +/// +/// [std]: fn@std::fs::copy +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// +/// # async fn dox() -> std::io::Result<()> { +/// fs::copy("foo.txt", "bar.txt").await?; +/// # Ok(()) +/// # } +/// ``` + +pub async fn copy(from: impl AsRef<Path>, to: impl AsRef<Path>) -> Result<u64, std::io::Error> { + let from = from.as_ref().to_owned(); + let to = to.as_ref().to_owned(); + asyncify(|| std::fs::copy(from, to)).await +} diff --git a/third_party/rust/tokio/src/fs/create_dir.rs b/third_party/rust/tokio/src/fs/create_dir.rs new file mode 100644 index 0000000000..411969500f --- /dev/null +++ b/third_party/rust/tokio/src/fs/create_dir.rs @@ -0,0 +1,52 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Creates a new, empty directory at the provided path. +/// +/// This is an async version of [`std::fs::create_dir`][std] +/// +/// [std]: std::fs::create_dir +/// +/// # Platform-specific behavior +/// +/// This function currently corresponds to the `mkdir` function on Unix +/// and the `CreateDirectory` function on Windows. +/// Note that, this [may change in the future][changes]. +/// +/// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior +/// +/// **NOTE**: If a parent of the given path doesn't exist, this function will +/// return an error. To create a directory and all its missing parents at the +/// same time, use the [`create_dir_all`] function. +/// +/// # Errors +/// +/// This function will return an error in the following situations, but is not +/// limited to just these cases: +/// +/// * User lacks permissions to create directory at `path`. +/// * A parent of the given path doesn't exist. (To create a directory and all +/// its missing parents at the same time, use the [`create_dir_all`] +/// function.) +/// * `path` already exists. +/// +/// [`create_dir_all`]: super::create_dir_all() +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// fs::create_dir("/some/dir").await?; +/// Ok(()) +/// } +/// ``` +pub async fn create_dir(path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::create_dir(path)).await +} diff --git a/third_party/rust/tokio/src/fs/create_dir_all.rs b/third_party/rust/tokio/src/fs/create_dir_all.rs new file mode 100644 index 0000000000..21f0c82d11 --- /dev/null +++ b/third_party/rust/tokio/src/fs/create_dir_all.rs @@ -0,0 +1,53 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Recursively creates a directory and all of its parent components if they +/// are missing. +/// +/// This is an async version of [`std::fs::create_dir_all`][std] +/// +/// [std]: std::fs::create_dir_all +/// +/// # Platform-specific behavior +/// +/// This function currently corresponds to the `mkdir` function on Unix +/// and the `CreateDirectory` function on Windows. +/// Note that, this [may change in the future][changes]. +/// +/// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior +/// +/// # Errors +/// +/// This function will return an error in the following situations, but is not +/// limited to just these cases: +/// +/// * If any directory in the path specified by `path` does not already exist +/// and it could not be created otherwise. The specific error conditions for +/// when a directory is being created (after it is determined to not exist) are +/// outlined by [`fs::create_dir`]. +/// +/// Notable exception is made for situations where any of the directories +/// specified in the `path` could not be created as it was being created concurrently. +/// Such cases are considered to be successful. That is, calling `create_dir_all` +/// concurrently from multiple threads or processes is guaranteed not to fail +/// due to a race condition with itself. +/// +/// [`fs::create_dir`]: std::fs::create_dir +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// +/// #[tokio::main] +/// async fn main() -> std::io::Result<()> { +/// fs::create_dir_all("/some/dir").await?; +/// Ok(()) +/// } +/// ``` +pub async fn create_dir_all(path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::create_dir_all(path)).await +} diff --git a/third_party/rust/tokio/src/fs/dir_builder.rs b/third_party/rust/tokio/src/fs/dir_builder.rs new file mode 100644 index 0000000000..97168bff70 --- /dev/null +++ b/third_party/rust/tokio/src/fs/dir_builder.rs @@ -0,0 +1,137 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// A builder for creating directories in various manners. +/// +/// This is a specialized version of [`std::fs::DirBuilder`] for usage on +/// the Tokio runtime. +/// +/// [std::fs::DirBuilder]: std::fs::DirBuilder +#[derive(Debug, Default)] +pub struct DirBuilder { + /// Indicates whether to create parent directories if they are missing. + recursive: bool, + + /// Sets the Unix mode for newly created directories. + #[cfg(unix)] + pub(super) mode: Option<u32>, +} + +impl DirBuilder { + /// Creates a new set of options with default mode/security settings for all + /// platforms and also non-recursive. + /// + /// This is an async version of [`std::fs::DirBuilder::new`][std] + /// + /// [std]: std::fs::DirBuilder::new + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let builder = DirBuilder::new(); + /// ``` + pub fn new() -> Self { + Default::default() + } + + /// Indicates whether to create directories recursively (including all parent directories). + /// Parents that do not exist are created with the same security and permissions settings. + /// + /// This option defaults to `false`. + /// + /// This is an async version of [`std::fs::DirBuilder::recursive`][std] + /// + /// [std]: std::fs::DirBuilder::recursive + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.recursive(true); + /// ``` + pub fn recursive(&mut self, recursive: bool) -> &mut Self { + self.recursive = recursive; + self + } + + /// Creates the specified directory with the configured options. + /// + /// It is considered an error if the directory already exists unless + /// recursive mode is enabled. + /// + /// This is an async version of [`std::fs::DirBuilder::create`][std] + /// + /// [std]: std::fs::DirBuilder::create + /// + /// # Errors + /// + /// An error will be returned under the following circumstances: + /// + /// * Path already points to an existing file. + /// * Path already points to an existing directory and the mode is + /// non-recursive. + /// * The calling process doesn't have permissions to create the directory + /// or its missing parents. + /// * Other I/O error occurred. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// DirBuilder::new() + /// .recursive(true) + /// .create("/tmp/foo/bar/baz") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub async fn create(&self, path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + let mut builder = std::fs::DirBuilder::new(); + builder.recursive(self.recursive); + + #[cfg(unix)] + { + if let Some(mode) = self.mode { + std::os::unix::fs::DirBuilderExt::mode(&mut builder, mode); + } + } + + asyncify(move || builder.create(path)).await + } +} + +feature! { + #![unix] + + impl DirBuilder { + /// Sets the mode to create new directories with. + /// + /// This option defaults to 0o777. + /// + /// # Examples + /// + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.mode(0o775); + /// ``` + pub fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = Some(mode); + self + } + } +} diff --git a/third_party/rust/tokio/src/fs/file.rs b/third_party/rust/tokio/src/fs/file.rs new file mode 100644 index 0000000000..2c38e8059f --- /dev/null +++ b/third_party/rust/tokio/src/fs/file.rs @@ -0,0 +1,779 @@ +//! Types for working with [`File`]. +//! +//! [`File`]: File + +use self::State::*; +use crate::fs::asyncify; +use crate::io::blocking::Buf; +use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; +use crate::sync::Mutex; + +use std::fmt; +use std::fs::{Metadata, Permissions}; +use std::future::Future; +use std::io::{self, Seek, SeekFrom}; +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::Poll::*; + +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(test)] +use super::mocks::MockFile as StdFile; +#[cfg(test)] +use super::mocks::{spawn_blocking, spawn_mandatory_blocking}; +#[cfg(not(test))] +use crate::blocking::JoinHandle; +#[cfg(not(test))] +use crate::blocking::{spawn_blocking, spawn_mandatory_blocking}; +#[cfg(not(test))] +use std::fs::File as StdFile; + +/// A reference to an open file on the filesystem. +/// +/// This is a specialized version of [`std::fs::File`][std] for usage from the +/// Tokio runtime. +/// +/// An instance of a `File` can be read and/or written depending on what options +/// it was opened with. Files also implement [`AsyncSeek`] to alter the logical +/// cursor that the file contains internally. +/// +/// A file will not be closed immediately when it goes out of scope if there +/// are any IO operations that have not yet completed. To ensure that a file is +/// closed immediately when it is dropped, you should call [`flush`] before +/// dropping it. Note that this does not ensure that the file has been fully +/// written to disk; the operating system might keep the changes around in an +/// in-memory buffer. See the [`sync_all`] method for telling the OS to write +/// the data to disk. +/// +/// Reading and writing to a `File` is usually done using the convenience +/// methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] traits. +/// +/// [std]: struct@std::fs::File +/// [`AsyncSeek`]: trait@crate::io::AsyncSeek +/// [`flush`]: fn@crate::io::AsyncWriteExt::flush +/// [`sync_all`]: fn@crate::fs::File::sync_all +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// +/// # Examples +/// +/// Create a new file and asynchronously write bytes to it: +/// +/// ```no_run +/// use tokio::fs::File; +/// use tokio::io::AsyncWriteExt; // for write_all() +/// +/// # async fn dox() -> std::io::Result<()> { +/// let mut file = File::create("foo.txt").await?; +/// file.write_all(b"hello, world!").await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// Read the contents of a file into a buffer: +/// +/// ```no_run +/// use tokio::fs::File; +/// use tokio::io::AsyncReadExt; // for read_to_end() +/// +/// # async fn dox() -> std::io::Result<()> { +/// let mut file = File::open("foo.txt").await?; +/// +/// let mut contents = vec![]; +/// file.read_to_end(&mut contents).await?; +/// +/// println!("len = {}", contents.len()); +/// # Ok(()) +/// # } +/// ``` +pub struct File { + std: Arc<StdFile>, + inner: Mutex<Inner>, +} + +struct Inner { + state: State, + + /// Errors from writes/flushes are returned in write/flush calls. If a write + /// error is observed while performing a read, it is saved until the next + /// write / flush call. + last_write_err: Option<io::ErrorKind>, + + pos: u64, +} + +#[derive(Debug)] +enum State { + Idle(Option<Buf>), + Busy(JoinHandle<(Operation, Buf)>), +} + +#[derive(Debug)] +enum Operation { + Read(io::Result<usize>), + Write(io::Result<()>), + Seek(io::Result<u64>), +} + +impl File { + /// Attempts to open a file in read-only mode. + /// + /// See [`OpenOptions`] for more details. + /// + /// [`OpenOptions`]: super::OpenOptions + /// + /// # Errors + /// + /// This function will return an error if called from outside of the Tokio + /// runtime or if path does not already exist. Other errors may also be + /// returned according to OpenOptions::open. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncReadExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::open("foo.txt").await?; + /// + /// let mut contents = vec![]; + /// file.read_to_end(&mut contents).await?; + /// + /// println!("len = {}", contents.len()); + /// # Ok(()) + /// # } + /// ``` + /// + /// The [`read_to_end`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read_to_end`]: fn@crate::io::AsyncReadExt::read_to_end + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + pub async fn open(path: impl AsRef<Path>) -> io::Result<File> { + let path = path.as_ref().to_owned(); + let std = asyncify(|| StdFile::open(path)).await?; + + Ok(File::from_std(std)) + } + + /// Opens a file in write-only mode. + /// + /// This function will create a file if it does not exist, and will truncate + /// it if it does. + /// + /// See [`OpenOptions`] for more details. + /// + /// [`OpenOptions`]: super::OpenOptions + /// + /// # Errors + /// + /// Results in an error if called from outside of the Tokio runtime or if + /// the underlying [`create`] call results in an error. + /// + /// [`create`]: std::fs::File::create + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncWriteExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"hello, world!").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn create(path: impl AsRef<Path>) -> io::Result<File> { + let path = path.as_ref().to_owned(); + let std_file = asyncify(move || StdFile::create(path)).await?; + Ok(File::from_std(std_file)) + } + + /// Converts a [`std::fs::File`][std] to a [`tokio::fs::File`][file]. + /// + /// [std]: std::fs::File + /// [file]: File + /// + /// # Examples + /// + /// ```no_run + /// // This line could block. It is not recommended to do this on the Tokio + /// // runtime. + /// let std_file = std::fs::File::open("foo.txt").unwrap(); + /// let file = tokio::fs::File::from_std(std_file); + /// ``` + pub fn from_std(std: StdFile) -> File { + File { + std: Arc::new(std), + inner: Mutex::new(Inner { + state: State::Idle(Some(Buf::with_capacity(0))), + last_write_err: None, + pos: 0, + }), + } + } + + /// Attempts to sync all OS-internal metadata to disk. + /// + /// This function will attempt to ensure that all in-core data reaches the + /// filesystem before returning. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncWriteExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"hello, world!").await?; + /// file.sync_all().await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn sync_all(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; + + let std = self.std.clone(); + asyncify(move || std.sync_all()).await + } + + /// This function is similar to `sync_all`, except that it may not + /// synchronize file metadata to the filesystem. + /// + /// This is intended for use cases that must synchronize content, but don't + /// need the metadata on disk. The goal of this method is to reduce disk + /// operations. + /// + /// Note that some platforms may simply implement this in terms of `sync_all`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncWriteExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"hello, world!").await?; + /// file.sync_data().await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn sync_data(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; + + let std = self.std.clone(); + asyncify(move || std.sync_data()).await + } + + /// Truncates or extends the underlying file, updating the size of this file to become size. + /// + /// If the size is less than the current file's size, then the file will be + /// shrunk. If it is greater than the current file's size, then the file + /// will be extended to size and have all of the intermediate data filled in + /// with 0s. + /// + /// # Errors + /// + /// This function will return an error if the file is not opened for + /// writing. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::AsyncWriteExt; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"hello, world!").await?; + /// file.set_len(10).await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn set_len(&self, size: u64) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; + + let mut buf = match inner.state { + Idle(ref mut buf_cell) => buf_cell.take().unwrap(), + _ => unreachable!(), + }; + + let seek = if !buf.is_empty() { + Some(SeekFrom::Current(buf.discard_read())) + } else { + None + }; + + let std = self.std.clone(); + + inner.state = Busy(spawn_blocking(move || { + let res = if let Some(seek) = seek { + (&*std).seek(seek).and_then(|_| std.set_len(size)) + } else { + std.set_len(size) + } + .map(|_| 0); // the value is discarded later + + // Return the result as a seek + (Operation::Seek(res), buf) + })); + + let (op, buf) = match inner.state { + Idle(_) => unreachable!(), + Busy(ref mut rx) => rx.await?, + }; + + inner.state = Idle(Some(buf)); + + match op { + Operation::Seek(res) => res.map(|pos| { + inner.pos = pos; + }), + _ => unreachable!(), + } + } + + /// Queries metadata about the underlying file. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let file = File::open("foo.txt").await?; + /// let metadata = file.metadata().await?; + /// + /// println!("{:?}", metadata); + /// # Ok(()) + /// # } + /// ``` + pub async fn metadata(&self) -> io::Result<Metadata> { + let std = self.std.clone(); + asyncify(move || std.metadata()).await + } + + /// Creates a new `File` instance that shares the same underlying file handle + /// as the existing `File` instance. Reads, writes, and seeks will affect both + /// File instances simultaneously. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let file = File::open("foo.txt").await?; + /// let file_clone = file.try_clone().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn try_clone(&self) -> io::Result<File> { + let std = self.std.clone(); + let std_file = asyncify(move || std.try_clone()).await?; + Ok(File::from_std(std_file)) + } + + /// Destructures `File` into a [`std::fs::File`][std]. This function is + /// async to allow any in-flight operations to complete. + /// + /// Use `File::try_into_std` to attempt conversion immediately. + /// + /// [std]: std::fs::File + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let tokio_file = File::open("foo.txt").await?; + /// let std_file = tokio_file.into_std().await; + /// # Ok(()) + /// # } + /// ``` + pub async fn into_std(mut self) -> StdFile { + self.inner.get_mut().complete_inflight().await; + Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") + } + + /// Tries to immediately destructure `File` into a [`std::fs::File`][std]. + /// + /// [std]: std::fs::File + /// + /// # Errors + /// + /// This function will return an error containing the file if some + /// operation is in-flight. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let tokio_file = File::open("foo.txt").await?; + /// let std_file = tokio_file.try_into_std().unwrap(); + /// # Ok(()) + /// # } + /// ``` + pub fn try_into_std(mut self) -> Result<StdFile, Self> { + match Arc::try_unwrap(self.std) { + Ok(file) => Ok(file), + Err(std_file_arc) => { + self.std = std_file_arc; + Err(self) + } + } + } + + /// Changes the permissions on the underlying file. + /// + /// # Platform-specific behavior + /// + /// This function currently corresponds to the `fchmod` function on Unix and + /// the `SetFileInformationByHandle` function on Windows. Note that, this + /// [may change in the future][changes]. + /// + /// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior + /// + /// # Errors + /// + /// This function will return an error if the user lacks permission change + /// attributes on the underlying file. It may also return an error in other + /// os-specific unspecified cases. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let file = File::open("foo.txt").await?; + /// let mut perms = file.metadata().await?.permissions(); + /// perms.set_readonly(true); + /// file.set_permissions(perms).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn set_permissions(&self, perm: Permissions) -> io::Result<()> { + let std = self.std.clone(); + asyncify(move || std.set_permissions(perm)).await + } +} + +impl AsyncRead for File { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + loop { + match inner.state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + if !buf.is_empty() { + buf.copy_to(dst); + *buf_cell = Some(buf); + return Ready(Ok(())); + } + + buf.ensure_capacity_for(dst); + let std = me.std.clone(); + + inner.state = Busy(spawn_blocking(move || { + let res = buf.read_from(&mut &*std); + (Operation::Read(res), buf) + })); + } + Busy(ref mut rx) => { + let (op, mut buf) = ready!(Pin::new(rx).poll(cx))?; + + match op { + Operation::Read(Ok(_)) => { + buf.copy_to(dst); + inner.state = Idle(Some(buf)); + return Ready(Ok(())); + } + Operation::Read(Err(e)) => { + assert!(buf.is_empty()); + + inner.state = Idle(Some(buf)); + return Ready(Err(e)); + } + Operation::Write(Ok(_)) => { + assert!(buf.is_empty()); + inner.state = Idle(Some(buf)); + continue; + } + Operation::Write(Err(e)) => { + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); + inner.state = Idle(Some(buf)); + } + Operation::Seek(result) => { + assert!(buf.is_empty()); + inner.state = Idle(Some(buf)); + if let Ok(pos) = result { + inner.pos = pos; + } + continue; + } + } + } + } + } + } +} + +impl AsyncSeek for File { + fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + loop { + match inner.state { + Busy(_) => panic!("must wait for poll_complete before calling start_seek"), + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + // Factor in any unread data from the buf + if !buf.is_empty() { + let n = buf.discard_read(); + + if let SeekFrom::Current(ref mut offset) = pos { + *offset += n; + } + } + + let std = me.std.clone(); + + inner.state = Busy(spawn_blocking(move || { + let res = (&*std).seek(pos); + (Operation::Seek(res), buf) + })); + return Ok(()); + } + } + } + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let inner = self.inner.get_mut(); + + loop { + match inner.state { + Idle(_) => return Poll::Ready(Ok(inner.pos)), + Busy(ref mut rx) => { + let (op, buf) = ready!(Pin::new(rx).poll(cx))?; + inner.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => {} + Operation::Write(Err(e)) => { + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); + } + Operation::Write(_) => {} + Operation::Seek(res) => { + if let Ok(pos) = res { + inner.pos = pos; + } + return Ready(res); + } + } + } + } + } + } +} + +impl AsyncWrite for File { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + src: &[u8], + ) -> Poll<io::Result<usize>> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + if let Some(e) = inner.last_write_err.take() { + return Ready(Err(e.into())); + } + + loop { + match inner.state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + let seek = if !buf.is_empty() { + Some(SeekFrom::Current(buf.discard_read())) + } else { + None + }; + + let n = buf.copy_from(src); + let std = me.std.clone(); + + let blocking_task_join_handle = spawn_mandatory_blocking(move || { + let res = if let Some(seek) = seek { + (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) + } else { + buf.write_to(&mut &*std) + }; + + (Operation::Write(res), buf) + }) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "background task failed") + })?; + + inner.state = Busy(blocking_task_join_handle); + + return Ready(Ok(n)); + } + Busy(ref mut rx) => { + let (op, buf) = ready!(Pin::new(rx).poll(cx))?; + inner.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => { + // We don't care about the result here. The fact + // that the cursor has advanced will be reflected in + // the next iteration of the loop + continue; + } + Operation::Write(res) => { + // If the previous write was successful, continue. + // Otherwise, error. + res?; + continue; + } + Operation::Seek(_) => { + // Ignore the seek + continue; + } + } + } + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + let inner = self.inner.get_mut(); + inner.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.poll_flush(cx) + } +} + +impl From<StdFile> for File { + fn from(std: StdFile) -> Self { + Self::from_std(std) + } +} + +impl fmt::Debug for File { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("tokio::fs::File") + .field("std", &self.std) + .finish() + } +} + +#[cfg(unix)] +impl std::os::unix::io::AsRawFd for File { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + self.std.as_raw_fd() + } +} + +#[cfg(unix)] +impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self { + StdFile::from_raw_fd(fd).into() + } +} + +#[cfg(windows)] +impl std::os::windows::io::AsRawHandle for File { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + self.std.as_raw_handle() + } +} + +#[cfg(windows)] +impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self { + StdFile::from_raw_handle(handle).into() + } +} + +impl Inner { + async fn complete_inflight(&mut self) { + use crate::future::poll_fn; + + if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await { + self.last_write_err = Some(e.kind()); + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + if let Some(e) = self.last_write_err.take() { + return Ready(Err(e.into())); + } + + let (op, buf) = match self.state { + Idle(_) => return Ready(Ok(())), + Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?, + }; + + // The buffer is not used here + self.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => Ready(Ok(())), + Operation::Write(res) => Ready(res), + Operation::Seek(_) => Ready(Ok(())), + } + } +} + +#[cfg(test)] +mod tests; diff --git a/third_party/rust/tokio/src/fs/file/tests.rs b/third_party/rust/tokio/src/fs/file/tests.rs new file mode 100644 index 0000000000..18a4c07859 --- /dev/null +++ b/third_party/rust/tokio/src/fs/file/tests.rs @@ -0,0 +1,957 @@ +use super::*; +use crate::{ + fs::mocks::*, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; +use mockall::{predicate::eq, Sequence}; +use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task}; + +const HELLO: &[u8] = b"hello world..."; +const FOO: &[u8] = b"foo bar baz..."; + +#[test] +fn open_read() { + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + let mut file = File::from_std(file); + + let mut buf = [0; 1024]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_eq!(0, pool::len()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + + pool::run_one(); + + assert!(t.is_woken()); + + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, HELLO.len()); + assert_eq!(&buf[..n], HELLO); +} + +#[test] +fn read_twice_before_dispatch() { + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + let mut file = File::from_std(file); + + let mut buf = [0; 1024]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_pending!(t.poll()); + assert_pending!(t.poll()); + + assert_eq!(pool::len(), 1); + pool::run_one(); + + assert!(t.is_woken()); + + let n = assert_ready_ok!(t.poll()); + assert_eq!(&buf[..n], HELLO); +} + +#[test] +fn read_with_smaller_buf() { + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + + let mut file = File::from_std(file); + + { + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + } + + pool::run_one(); + + { + let mut buf = [0; 4]; + let mut t = task::spawn(file.read(&mut buf)); + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, 4); + assert_eq!(&buf[..], &HELLO[..n]); + } + + // Calling again immediately succeeds with the rest of the buffer + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, 10); + assert_eq!(&buf[..n], &HELLO[4..]); + + assert_eq!(0, pool::len()); +} + +#[test] +fn read_with_bigger_buf() { + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..4].copy_from_slice(&HELLO[..4]); + Ok(4) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len() - 4].copy_from_slice(&HELLO[4..]); + Ok(HELLO.len() - 4) + }); + + let mut file = File::from_std(file); + + { + let mut buf = [0; 4]; + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + } + + pool::run_one(); + + { + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, 4); + assert_eq!(&buf[..n], &HELLO[..n]); + } + + // Calling again immediately succeeds with the rest of the buffer + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, 10); + assert_eq!(&buf[..n], &HELLO[4..]); + + assert_eq!(0, pool::len()); +} + +#[test] +fn read_err_then_read_success() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + + let mut file = File::from_std(file); + + { + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + + pool::run_one(); + + assert_ready_err!(t.poll()); + } + + { + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + + pool::run_one(); + + let n = assert_ready_ok!(t.poll()); + + assert_eq!(n, HELLO.len()); + assert_eq!(&buf[..n], HELLO); + } +} + +#[test] +fn open_write() { + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + + assert_eq!(0, pool::len()); + assert_ready_ok!(t.poll()); + + assert_eq!(1, pool::len()); + + pool::run_one(); + + assert!(!t.is_woken()); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn flush_while_idle() { + let file = MockFile::default(); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); +} + +#[test] +#[cfg_attr(miri, ignore)] // takes a really long time with miri +fn read_with_buffer_larger_than_max() { + // Chunks + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; + + assert_eq!(chunk_d / 1024, 64); + + let mut data = vec![]; + for i in 0..(chunk_d - 1) { + data.push((i % 151) as u8); + } + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[0..chunk_a].copy_from_slice(&d0[0..chunk_a]); + Ok(chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d1[chunk_a..chunk_b]); + Ok(chunk_b - chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d2[chunk_b..chunk_c]); + Ok(chunk_c - chunk_b) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a - 1].copy_from_slice(&d3[chunk_c..]); + Ok(chunk_a - 1) + }); + let mut file = File::from_std(file); + + let mut actual = vec![0; chunk_d]; + let mut pos = 0; + + while pos < data.len() { + let mut t = task::spawn(file.read(&mut actual[pos..])); + + assert_pending!(t.poll()); + pool::run_one(); + assert!(t.is_woken()); + + let n = assert_ready_ok!(t.poll()); + assert!(n <= chunk_a); + + pos += n; + } + + assert_eq!(&data[..], &actual[..data.len()]); +} + +#[test] +#[cfg_attr(miri, ignore)] // takes a really long time with miri +fn write_with_buffer_larger_than_max() { + // Chunks + let chunk_a = 16 * 1024; + let chunk_b = chunk_a * 2; + let chunk_c = chunk_a * 3; + let chunk_d = chunk_a * 4; + + assert_eq!(chunk_d / 1024, 64); + + let mut data = vec![]; + for i in 0..(chunk_d - 1) { + data.push((i % 151) as u8); + } + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d0[0..chunk_a]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d1[chunk_a..chunk_b]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d2[chunk_b..chunk_c]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d3[chunk_c..chunk_d - 1]) + .returning(|buf| Ok(buf.len())); + + let mut file = File::from_std(file); + + let mut rem = &data[..]; + + let mut first = true; + + while !rem.is_empty() { + let mut task = task::spawn(file.write(rem)); + + if !first { + assert_pending!(task.poll()); + pool::run_one(); + assert!(task.is_woken()); + } + + first = false; + + let n = assert_ready_ok!(task.poll()); + + rem = &rem[n..]; + } + + pool::run_one(); +} + +#[test] +fn write_twice_before_dispatch() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|buf| Ok(buf.len())); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.write(FOO)); + assert_pending!(t.poll()); + + assert_eq!(pool::len(), 1); + pool::run_one(); + + assert!(t.is_woken()); + + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.flush()); + assert_pending!(t.poll()); + + assert_eq!(pool::len(), 1); + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn incomplete_read_followed_by_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); + + let mut file = File::from_std(file); + + let mut buf = [0; 32]; + + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_ok!(t.poll()); + + assert_eq!(pool::len(), 1); + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn incomplete_partial_read_followed_by_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-10))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); + + let mut file = File::from_std(file); + + let mut buf = [0; 32]; + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + + pool::run_one(); + + let mut buf = [0; 4]; + let mut t = task::spawn(file.read(&mut buf)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_ok!(t.poll()); + + assert_eq!(pool::len(), 1); + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn incomplete_read_followed_by_flush() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); + + let mut file = File::from_std(file); + + let mut buf = [0; 32]; + + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); +} + +#[test] +fn incomplete_flush_followed_by_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + let n = assert_ready_ok!(t.poll()); + assert_eq!(n, HELLO.len()); + + let mut t = task::spawn(file.flush()); + assert_pending!(t.poll()); + + // TODO: Move under write + pool::run_one(); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn read_err() { + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); + + let mut file = File::from_std(file); + + let mut buf = [0; 1024]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_pending!(t.poll()); + + pool::run_one(); + assert!(t.is_woken()); + + assert_ready_err!(t.poll()); +} + +#[test] +fn write_write_err() { + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_err!(t.poll()); +} + +#[test] +fn write_read_write_err() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + let mut buf = [0; 1024]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_pending!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_err!(t.poll()); +} + +#[test] +fn write_read_flush_err() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + let mut buf = [0; 1024]; + let mut t = task::spawn(file.read(&mut buf)); + + assert_pending!(t.poll()); + + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_err!(t.poll()); +} + +#[test] +fn write_seek_write_err() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + { + let mut t = task::spawn(file.seek(SeekFrom::Start(0))); + assert_pending!(t.poll()); + } + + pool::run_one(); + + let mut t = task::spawn(file.write(FOO)); + assert_ready_err!(t.poll()); +} + +#[test] +fn write_seek_flush_err() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + + let mut file = File::from_std(file); + + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + pool::run_one(); + + { + let mut t = task::spawn(file.seek(SeekFrom::Start(0))); + assert_pending!(t.poll()); + } + + pool::run_one(); + + let mut t = task::spawn(file.flush()); + assert_ready_err!(t.poll()); +} + +#[test] +fn sync_all_ordered_after_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all().once().returning(|| Ok(())); + + let mut file = File::from_std(file); + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.sync_all()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn sync_all_err_ordered_after_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); + + let mut file = File::from_std(file); + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.sync_all()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_err!(t.poll()); +} + +#[test] +fn sync_data_ordered_after_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data().once().returning(|| Ok(())); + + let mut file = File::from_std(file); + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.sync_data()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn sync_data_err_ordered_after_write() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); + + let mut file = File::from_std(file); + let mut t = task::spawn(file.write(HELLO)); + assert_ready_ok!(t.poll()); + + let mut t = task::spawn(file.sync_data()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_pending!(t.poll()); + + assert_eq!(1, pool::len()); + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_err!(t.poll()); +} + +#[test] +fn open_set_len_ok() { + let mut file = MockFile::default(); + file.expect_set_len().with(eq(123)).returning(|_| Ok(())); + + let file = File::from_std(file); + let mut t = task::spawn(file.set_len(123)); + + assert_pending!(t.poll()); + + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_ok!(t.poll()); +} + +#[test] +fn open_set_len_err() { + let mut file = MockFile::default(); + file.expect_set_len() + .with(eq(123)) + .returning(|_| Err(io::ErrorKind::Other.into())); + + let file = File::from_std(file); + let mut t = task::spawn(file.set_len(123)); + + assert_pending!(t.poll()); + + pool::run_one(); + + assert!(t.is_woken()); + assert_ready_err!(t.poll()); +} + +#[test] +fn partial_read_set_len_ok() { + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_set_len() + .once() + .in_sequence(&mut seq) + .with(eq(123)) + .returning(|_| Ok(())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..FOO.len()].copy_from_slice(FOO); + Ok(FOO.len()) + }); + + let mut buf = [0; 32]; + let mut file = File::from_std(file); + + { + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + } + + pool::run_one(); + + { + let mut t = task::spawn(file.set_len(123)); + + assert_pending!(t.poll()); + pool::run_one(); + assert_ready_ok!(t.poll()); + } + + let mut t = task::spawn(file.read(&mut buf)); + assert_pending!(t.poll()); + pool::run_one(); + let n = assert_ready_ok!(t.poll()); + + assert_eq!(n, FOO.len()); + assert_eq!(&buf[..n], FOO); +} diff --git a/third_party/rust/tokio/src/fs/hard_link.rs b/third_party/rust/tokio/src/fs/hard_link.rs new file mode 100644 index 0000000000..50cc17d286 --- /dev/null +++ b/third_party/rust/tokio/src/fs/hard_link.rs @@ -0,0 +1,46 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Creates a new hard link on the filesystem. +/// +/// This is an async version of [`std::fs::hard_link`][std] +/// +/// [std]: std::fs::hard_link +/// +/// The `dst` path will be a link pointing to the `src` path. Note that systems +/// often require these two paths to both be located on the same filesystem. +/// +/// # Platform-specific behavior +/// +/// This function currently corresponds to the `link` function on Unix +/// and the `CreateHardLink` function on Windows. +/// Note that, this [may change in the future][changes]. +/// +/// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior +/// +/// # Errors +/// +/// This function will return an error in the following situations, but is not +/// limited to just these cases: +/// +/// * The `src` path is not a file or doesn't exist. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// +/// #[tokio::main] +/// async fn main() -> std::io::Result<()> { +/// fs::hard_link("a.txt", "b.txt").await?; // Hard link a.txt to b.txt +/// Ok(()) +/// } +/// ``` +pub async fn hard_link(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> { + let src = src.as_ref().to_owned(); + let dst = dst.as_ref().to_owned(); + + asyncify(move || std::fs::hard_link(src, dst)).await +} diff --git a/third_party/rust/tokio/src/fs/metadata.rs b/third_party/rust/tokio/src/fs/metadata.rs new file mode 100644 index 0000000000..ff9cded79a --- /dev/null +++ b/third_party/rust/tokio/src/fs/metadata.rs @@ -0,0 +1,47 @@ +use crate::fs::asyncify; + +use std::fs::Metadata; +use std::io; +use std::path::Path; + +/// Given a path, queries the file system to get information about a file, +/// directory, etc. +/// +/// This is an async version of [`std::fs::metadata`][std] +/// +/// This function will traverse symbolic links to query information about the +/// destination file. +/// +/// # Platform-specific behavior +/// +/// This function currently corresponds to the `stat` function on Unix and the +/// `GetFileAttributesEx` function on Windows. Note that, this [may change in +/// the future][changes]. +/// +/// [std]: std::fs::metadata +/// [changes]: https://doc.rust-lang.org/std/io/index.html#platform-specific-behavior +/// +/// # Errors +/// +/// This function will return an error in the following situations, but is not +/// limited to just these cases: +/// +/// * The user lacks permissions to perform `metadata` call on `path`. +/// * `path` does not exist. +/// +/// # Examples +/// +/// ```rust,no_run +/// use tokio::fs; +/// +/// #[tokio::main] +/// async fn main() -> std::io::Result<()> { +/// let attr = fs::metadata("/some/file/path.txt").await?; +/// // inspect attr ... +/// Ok(()) +/// } +/// ``` +pub async fn metadata(path: impl AsRef<Path>) -> io::Result<Metadata> { + let path = path.as_ref().to_owned(); + asyncify(|| std::fs::metadata(path)).await +} diff --git a/third_party/rust/tokio/src/fs/mocks.rs b/third_party/rust/tokio/src/fs/mocks.rs new file mode 100644 index 0000000000..b186172677 --- /dev/null +++ b/third_party/rust/tokio/src/fs/mocks.rs @@ -0,0 +1,151 @@ +//! Mock version of std::fs::File; +use mockall::mock; + +use crate::sync::oneshot; +use std::{ + cell::RefCell, + collections::VecDeque, + fs::{Metadata, Permissions}, + future::Future, + io::{self, Read, Seek, SeekFrom, Write}, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +mock! { + #[derive(Debug)] + pub File { + pub fn create(pb: PathBuf) -> io::Result<Self>; + // These inner_ methods exist because std::fs::File has two + // implementations for each of these methods: one on "&mut self" and + // one on "&&self". Defining both of those in terms of an inner_ method + // allows us to specify the expectation the same way, regardless of + // which method is used. + pub fn inner_flush(&self) -> io::Result<()>; + pub fn inner_read(&self, dst: &mut [u8]) -> io::Result<usize>; + pub fn inner_seek(&self, pos: SeekFrom) -> io::Result<u64>; + pub fn inner_write(&self, src: &[u8]) -> io::Result<usize>; + pub fn metadata(&self) -> io::Result<Metadata>; + pub fn open(pb: PathBuf) -> io::Result<Self>; + pub fn set_len(&self, size: u64) -> io::Result<()>; + pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>; + pub fn sync_all(&self) -> io::Result<()>; + pub fn sync_data(&self) -> io::Result<()>; + pub fn try_clone(&self) -> io::Result<Self>; + } + #[cfg(windows)] + impl std::os::windows::io::AsRawHandle for File { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle; + } + #[cfg(windows)] + impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(h: std::os::windows::io::RawHandle) -> Self; + } + #[cfg(unix)] + impl std::os::unix::io::AsRawFd for File { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd; + } + + #[cfg(unix)] + impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(h: std::os::unix::io::RawFd) -> Self; + } +} + +impl Read for MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) + } +} + +impl Read for &'_ MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) + } +} + +impl Seek for &'_ MockFile { + fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> { + self.inner_seek(pos) + } +} + +impl Write for &'_ MockFile { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + self.inner_write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner_flush() + } +} + +thread_local! { + static QUEUE: RefCell<VecDeque<Box<dyn FnOnce() + Send>>> = RefCell::new(VecDeque::new()) +} + +#[derive(Debug)] +pub(super) struct JoinHandle<T> { + rx: oneshot::Receiver<T>, +} + +pub(super) fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + JoinHandle { rx } +} + +pub(super) fn spawn_mandatory_blocking<F, R>(f: F) -> Option<JoinHandle<R>> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + Some(JoinHandle { rx }) +} + +impl<T> Future for JoinHandle<T> { + type Output = Result<T, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::task::Poll::*; + + match Pin::new(&mut self.rx).poll(cx) { + Ready(Ok(v)) => Ready(Ok(v)), + Ready(Err(e)) => panic!("error = {:?}", e), + Pending => Pending, + } + } +} + +pub(super) mod pool { + use super::*; + + pub(in super::super) fn len() -> usize { + QUEUE.with(|cell| cell.borrow().len()) + } + + pub(in super::super) fn run_one() { + let task = QUEUE + .with(|cell| cell.borrow_mut().pop_front()) + .expect("expected task to run, but none ready"); + + task(); + } +} diff --git a/third_party/rust/tokio/src/fs/mod.rs b/third_party/rust/tokio/src/fs/mod.rs new file mode 100644 index 0000000000..ca0264b367 --- /dev/null +++ b/third_party/rust/tokio/src/fs/mod.rs @@ -0,0 +1,126 @@ +#![cfg(not(loom))] + +//! Asynchronous file and standard stream adaptation. +//! +//! This module contains utility methods and adapter types for input/output to +//! files or standard streams (`Stdin`, `Stdout`, `Stderr`), and +//! filesystem manipulation, for use within (and only within) a Tokio runtime. +//! +//! Tasks run by *worker* threads should not block, as this could delay +//! servicing reactor events. Portable filesystem operations are blocking, +//! however. This module offers adapters which use a `blocking` annotation +//! to inform the runtime that a blocking operation is required. When +//! necessary, this allows the runtime to convert the current thread from a +//! *worker* to a *backup* thread, where blocking is acceptable. +//! +//! ## Usage +//! +//! Where possible, users should prefer the provided asynchronous-specific +//! traits such as [`AsyncRead`], or methods returning a `Future` or `Poll` +//! type. Adaptions also extend to traits like `std::io::Read` where methods +//! return `std::io::Result`. Be warned that these adapted methods may return +//! `std::io::ErrorKind::WouldBlock` if a *worker* thread can not be converted +//! to a *backup* thread immediately. +//! +//! [`AsyncRead`]: trait@crate::io::AsyncRead + +mod canonicalize; +pub use self::canonicalize::canonicalize; + +mod create_dir; +pub use self::create_dir::create_dir; + +mod create_dir_all; +pub use self::create_dir_all::create_dir_all; + +mod dir_builder; +pub use self::dir_builder::DirBuilder; + +mod file; +pub use self::file::File; + +mod hard_link; +pub use self::hard_link::hard_link; + +mod metadata; +pub use self::metadata::metadata; + +mod open_options; +pub use self::open_options::OpenOptions; + +mod read; +pub use self::read::read; + +mod read_dir; +pub use self::read_dir::{read_dir, DirEntry, ReadDir}; + +mod read_link; +pub use self::read_link::read_link; + +mod read_to_string; +pub use self::read_to_string::read_to_string; + +mod remove_dir; +pub use self::remove_dir::remove_dir; + +mod remove_dir_all; +pub use self::remove_dir_all::remove_dir_all; + +mod remove_file; +pub use self::remove_file::remove_file; + +mod rename; +pub use self::rename::rename; + +mod set_permissions; +pub use self::set_permissions::set_permissions; + +mod symlink_metadata; +pub use self::symlink_metadata::symlink_metadata; + +mod write; +pub use self::write::write; + +mod copy; +pub use self::copy::copy; + +#[cfg(test)] +mod mocks; + +feature! { + #![unix] + + mod symlink; + pub use self::symlink::symlink; +} + +feature! { + #![windows] + + mod symlink_dir; + pub use self::symlink_dir::symlink_dir; + + mod symlink_file; + pub use self::symlink_file::symlink_file; +} + +use std::io; + +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(test)] +use mocks::spawn_blocking; + +pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> +where + F: FnOnce() -> io::Result<T> + Send + 'static, + T: Send + 'static, +{ + match spawn_blocking(f).await { + Ok(res) => res, + Err(_) => Err(io::Error::new( + io::ErrorKind::Other, + "background task failed", + )), + } +} diff --git a/third_party/rust/tokio/src/fs/open_options.rs b/third_party/rust/tokio/src/fs/open_options.rs new file mode 100644 index 0000000000..f3b4654741 --- /dev/null +++ b/third_party/rust/tokio/src/fs/open_options.rs @@ -0,0 +1,665 @@ +use crate::fs::{asyncify, File}; + +use std::io; +use std::path::Path; + +#[cfg(test)] +mod mock_open_options; +#[cfg(test)] +use mock_open_options::MockOpenOptions as StdOpenOptions; +#[cfg(not(test))] +use std::fs::OpenOptions as StdOpenOptions; + +/// Options and flags which can be used to configure how a file is opened. +/// +/// This builder exposes the ability to configure how a [`File`] is opened and +/// what operations are permitted on the open file. The [`File::open`] and +/// [`File::create`] methods are aliases for commonly used options using this +/// builder. +/// +/// Generally speaking, when using `OpenOptions`, you'll first call [`new`], +/// then chain calls to methods to set each option, then call [`open`], passing +/// the path of the file you're trying to open. This will give you a +/// [`io::Result`][result] with a [`File`] inside that you can further operate +/// on. +/// +/// This is a specialized version of [`std::fs::OpenOptions`] for usage from +/// the Tokio runtime. +/// +/// `From<std::fs::OpenOptions>` is implemented for more advanced configuration +/// than the methods provided here. +/// +/// [`new`]: OpenOptions::new +/// [`open`]: OpenOptions::open +/// [result]: std::io::Result +/// [`File`]: File +/// [`File::open`]: File::open +/// [`File::create`]: File::create +/// [`std::fs::OpenOptions`]: std::fs::OpenOptions +/// +/// # Examples +/// +/// Opening a file to read: +/// +/// ```no_run +/// use tokio::fs::OpenOptions; +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let file = OpenOptions::new() +/// .read(true) +/// .open("foo.txt") +/// .await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// Opening a file for both reading and writing, as well as creating it if it +/// doesn't exist: +/// +/// ```no_run +/// use tokio::fs::OpenOptions; +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let file = OpenOptions::new() +/// .read(true) +/// .write(true) +/// .create(true) +/// .open("foo.txt") +/// .await?; +/// +/// Ok(()) +/// } +/// ``` +#[derive(Clone, Debug)] +pub struct OpenOptions(StdOpenOptions); + +impl OpenOptions { + /// Creates a blank new set of options ready for configuration. + /// + /// All options are initially set to `false`. + /// + /// This is an async version of [`std::fs::OpenOptions::new`][std] + /// + /// [std]: std::fs::OpenOptions::new + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// + /// let mut options = OpenOptions::new(); + /// let future = options.read(true).open("foo.txt"); + /// ``` + pub fn new() -> OpenOptions { + OpenOptions(StdOpenOptions::new()) + } + + /// Sets the option for read access. + /// + /// This option, when true, will indicate that the file should be + /// `read`-able if opened. + /// + /// This is an async version of [`std::fs::OpenOptions::read`][std] + /// + /// [std]: std::fs::OpenOptions::read + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .read(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn read(&mut self, read: bool) -> &mut OpenOptions { + self.0.read(read); + self + } + + /// Sets the option for write access. + /// + /// This option, when true, will indicate that the file should be + /// `write`-able if opened. + /// + /// This is an async version of [`std::fs::OpenOptions::write`][std] + /// + /// [std]: std::fs::OpenOptions::write + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn write(&mut self, write: bool) -> &mut OpenOptions { + self.0.write(write); + self + } + + /// Sets the option for the append mode. + /// + /// This option, when true, means that writes will append to a file instead + /// of overwriting previous contents. Note that setting + /// `.write(true).append(true)` has the same effect as setting only + /// `.append(true)`. + /// + /// For most filesystems, the operating system guarantees that all writes are + /// atomic: no writes get mangled because another process writes at the same + /// time. + /// + /// One maybe obvious note when using append-mode: make sure that all data + /// that belongs together is written to the file in one operation. This + /// can be done by concatenating strings before passing them to [`write()`], + /// or using a buffered writer (with a buffer of adequate size), + /// and calling [`flush()`] when the message is complete. + /// + /// If a file is opened with both read and append access, beware that after + /// opening, and after every write, the position for reading may be set at the + /// end of the file. So, before writing, save the current position (using + /// [`seek`]`(`[`SeekFrom`]`::`[`Current`]`(0))`), and restore it before the next read. + /// + /// This is an async version of [`std::fs::OpenOptions::append`][std] + /// + /// [std]: std::fs::OpenOptions::append + /// + /// ## Note + /// + /// This function doesn't create the file if it doesn't exist. Use the [`create`] + /// method to do so. + /// + /// [`write()`]: crate::io::AsyncWriteExt::write + /// [`flush()`]: crate::io::AsyncWriteExt::flush + /// [`seek`]: crate::io::AsyncSeekExt::seek + /// [`SeekFrom`]: std::io::SeekFrom + /// [`Current`]: std::io::SeekFrom::Current + /// [`create`]: OpenOptions::create + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .append(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn append(&mut self, append: bool) -> &mut OpenOptions { + self.0.append(append); + self + } + + /// Sets the option for truncating a previous file. + /// + /// If a file is successfully opened with this option set it will truncate + /// the file to 0 length if it already exists. + /// + /// The file must be opened with write access for truncate to work. + /// + /// This is an async version of [`std::fs::OpenOptions::truncate`][std] + /// + /// [std]: std::fs::OpenOptions::truncate + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .truncate(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions { + self.0.truncate(truncate); + self + } + + /// Sets the option for creating a new file. + /// + /// This option indicates whether a new file will be created if the file + /// does not yet already exist. + /// + /// In order for the file to be created, [`write`] or [`append`] access must + /// be used. + /// + /// This is an async version of [`std::fs::OpenOptions::create`][std] + /// + /// [std]: std::fs::OpenOptions::create + /// [`write`]: OpenOptions::write + /// [`append`]: OpenOptions::append + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn create(&mut self, create: bool) -> &mut OpenOptions { + self.0.create(create); + self + } + + /// Sets the option to always create a new file. + /// + /// This option indicates whether a new file will be created. No file is + /// allowed to exist at the target location, also no (dangling) symlink. + /// + /// This option is useful because it is atomic. Otherwise between checking + /// whether a file exists and creating a new one, the file may have been + /// created by another process (a TOCTOU race condition / attack). + /// + /// If `.create_new(true)` is set, [`.create()`] and [`.truncate()`] are + /// ignored. + /// + /// The file must be opened with write or append access in order to create a + /// new file. + /// + /// This is an async version of [`std::fs::OpenOptions::create_new`][std] + /// + /// [std]: std::fs::OpenOptions::create_new + /// [`.create()`]: OpenOptions::create + /// [`.truncate()`]: OpenOptions::truncate + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create_new(true) + /// .open("foo.txt") + /// .await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions { + self.0.create_new(create_new); + self + } + + /// Opens a file at `path` with the options specified by `self`. + /// + /// This is an async version of [`std::fs::OpenOptions::open`][std] + /// + /// [std]: std::fs::OpenOptions::open + /// + /// # Errors + /// + /// This function will return an error under a number of different + /// circumstances. Some of these error conditions are listed here, together + /// with their [`ErrorKind`]. The mapping to [`ErrorKind`]s is not part of + /// the compatibility contract of the function, especially the `Other` kind + /// might change to more specific kinds in the future. + /// + /// * [`NotFound`]: The specified file does not exist and neither `create` + /// or `create_new` is set. + /// * [`NotFound`]: One of the directory components of the file path does + /// not exist. + /// * [`PermissionDenied`]: The user lacks permission to get the specified + /// access rights for the file. + /// * [`PermissionDenied`]: The user lacks permission to open one of the + /// directory components of the specified path. + /// * [`AlreadyExists`]: `create_new` was specified and the file already + /// exists. + /// * [`InvalidInput`]: Invalid combinations of open options (truncate + /// without write access, no access mode set, etc.). + /// * [`Other`]: One of the directory components of the specified file path + /// was not, in fact, a directory. + /// * [`Other`]: Filesystem-level errors: full disk, write permission + /// requested on a read-only file system, exceeded disk quota, too many + /// open files, too long filename, too many symbolic links in the + /// specified path (Unix-like systems only), etc. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let file = OpenOptions::new().open("foo.txt").await?; + /// Ok(()) + /// } + /// ``` + /// + /// [`ErrorKind`]: std::io::ErrorKind + /// [`AlreadyExists`]: std::io::ErrorKind::AlreadyExists + /// [`InvalidInput`]: std::io::ErrorKind::InvalidInput + /// [`NotFound`]: std::io::ErrorKind::NotFound + /// [`Other`]: std::io::ErrorKind::Other + /// [`PermissionDenied`]: std::io::ErrorKind::PermissionDenied + pub async fn open(&self, path: impl AsRef<Path>) -> io::Result<File> { + let path = path.as_ref().to_owned(); + let opts = self.0.clone(); + + let std = asyncify(move || opts.open(path)).await?; + Ok(File::from_std(std)) + } + + /// Returns a mutable reference to the underlying `std::fs::OpenOptions` + pub(super) fn as_inner_mut(&mut self) -> &mut StdOpenOptions { + &mut self.0 + } +} + +feature! { + #![unix] + + use std::os::unix::fs::OpenOptionsExt; + + impl OpenOptions { + /// Sets the mode bits that a new file will be created with. + /// + /// If a new file is created as part of an `OpenOptions::open` call then this + /// specified `mode` will be used as the permission bits for the new file. + /// If no `mode` is set, the default of `0o666` will be used. + /// The operating system masks out bits with the system's `umask`, to produce + /// the final permissions. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.mode(0o644); // Give read/write for owner and read for others. + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn mode(&mut self, mode: u32) -> &mut OpenOptions { + self.as_inner_mut().mode(mode); + self + } + + /// Passes custom flags to the `flags` argument of `open`. + /// + /// The bits that define the access mode are masked out with `O_ACCMODE`, to + /// ensure they do not interfere with the access mode set by Rusts options. + /// + /// Custom flags can only set flags, not remove flags set by Rusts options. + /// This options overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use libc; + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.write(true); + /// if cfg!(unix) { + /// options.custom_flags(libc::O_NOFOLLOW); + /// } + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } + } +} + +feature! { + #![windows] + + use std::os::windows::fs::OpenOptionsExt; + + impl OpenOptions { + /// Overrides the `dwDesiredAccess` argument to the call to [`CreateFile`] + /// with the specified value. + /// + /// This will override the `read`, `write`, and `append` flags on the + /// `OpenOptions` structure. This method provides fine-grained control over + /// the permissions to read, write and append data, attributes (like hidden + /// and system), and extended attributes. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Open without read and write permission, for example if you only need + /// // to call `stat` on the file + /// let file = OpenOptions::new().access_mode(0).open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + pub fn access_mode(&mut self, access: u32) -> &mut OpenOptions { + self.as_inner_mut().access_mode(access); + self + } + + /// Overrides the `dwShareMode` argument to the call to [`CreateFile`] with + /// the specified value. + /// + /// By default `share_mode` is set to + /// `FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE`. This allows + /// other processes to read, write, and delete/rename the same file + /// while it is open. Removing any of the flags will prevent other + /// processes from performing the corresponding operation until the file + /// handle is closed. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Do not allow others to read or modify this file while we have it open + /// // for writing. + /// let file = OpenOptions::new() + /// .write(true) + /// .share_mode(0) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + pub fn share_mode(&mut self, share: u32) -> &mut OpenOptions { + self.as_inner_mut().share_mode(share); + self + } + + /// Sets extra flags for the `dwFileFlags` argument to the call to + /// [`CreateFile2`] to the specified value (or combines it with + /// `attributes` and `security_qos_flags` to set the `dwFlagsAndAttributes` + /// for [`CreateFile`]). + /// + /// Custom flags can only set flags, not remove flags set by Rust's options. + /// This option overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::FILE_FLAG_DELETE_ON_CLOSE; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .create(true) + /// .write(true) + /// .custom_flags(FILE_FLAG_DELETE_ON_CLOSE) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + pub fn custom_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } + + /// Sets the `dwFileAttributes` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and + /// `security_qos_flags` to set the `dwFlagsAndAttributes` for + /// [`CreateFile`]). + /// + /// If a _new_ file is created because it does not yet exist and + /// `.create(true)` or `.create_new(true)` are specified, the new file is + /// given the attributes declared with `.attributes()`. + /// + /// If an _existing_ file is opened with `.create(true).truncate(true)`, its + /// existing attributes are preserved and combined with the ones declared + /// with `.attributes()`. + /// + /// In all other cases the attributes get ignored. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winnt::FILE_ATTRIBUTE_HIDDEN; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// .attributes(FILE_ATTRIBUTE_HIDDEN) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + pub fn attributes(&mut self, attributes: u32) -> &mut OpenOptions { + self.as_inner_mut().attributes(attributes); + self + } + + /// Sets the `dwSecurityQosFlags` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and `attributes` + /// to set the `dwFlagsAndAttributes` for [`CreateFile`]). + /// + /// By default `security_qos_flags` is not set. It should be specified when + /// opening a named pipe, to control to which degree a server process can + /// act on behalf of a client process (security impersonation level). + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::SECURITY_IDENTIFICATION; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// + /// // Sets the flag value to `SecurityIdentification`. + /// .security_qos_flags(SECURITY_IDENTIFICATION) + /// + /// .open(r"\\.\pipe\MyPipe").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + /// [Impersonation Levels]: + /// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + pub fn security_qos_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().security_qos_flags(flags); + self + } + } +} + +impl From<StdOpenOptions> for OpenOptions { + fn from(options: StdOpenOptions) -> OpenOptions { + OpenOptions(options) + } +} + +impl Default for OpenOptions { + fn default() -> Self { + Self::new() + } +} diff --git a/third_party/rust/tokio/src/fs/open_options/mock_open_options.rs b/third_party/rust/tokio/src/fs/open_options/mock_open_options.rs new file mode 100644 index 0000000000..cbbda0ec25 --- /dev/null +++ b/third_party/rust/tokio/src/fs/open_options/mock_open_options.rs @@ -0,0 +1,38 @@ +//! Mock version of std::fs::OpenOptions; +use mockall::mock; + +use crate::fs::mocks::MockFile; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(windows)] +use std::os::windows::fs::OpenOptionsExt; +use std::{io, path::Path}; + +mock! { + #[derive(Debug)] + pub OpenOptions { + pub fn append(&mut self, append: bool) -> &mut Self; + pub fn create(&mut self, create: bool) -> &mut Self; + pub fn create_new(&mut self, create_new: bool) -> &mut Self; + pub fn open<P: AsRef<Path> + 'static>(&self, path: P) -> io::Result<MockFile>; + pub fn read(&mut self, read: bool) -> &mut Self; + pub fn truncate(&mut self, truncate: bool) -> &mut Self; + pub fn write(&mut self, write: bool) -> &mut Self; + } + impl Clone for OpenOptions { + fn clone(&self) -> Self; + } + #[cfg(unix)] + impl OpenOptionsExt for OpenOptions { + fn custom_flags(&mut self, flags: i32) -> &mut Self; + fn mode(&mut self, mode: u32) -> &mut Self; + } + #[cfg(windows)] + impl OpenOptionsExt for OpenOptions { + fn access_mode(&mut self, access: u32) -> &mut Self; + fn share_mode(&mut self, val: u32) -> &mut Self; + fn custom_flags(&mut self, flags: u32) -> &mut Self; + fn attributes(&mut self, val: u32) -> &mut Self; + fn security_qos_flags(&mut self, flags: u32) -> &mut Self; + } +} diff --git a/third_party/rust/tokio/src/fs/read.rs b/third_party/rust/tokio/src/fs/read.rs new file mode 100644 index 0000000000..ada5ba391b --- /dev/null +++ b/third_party/rust/tokio/src/fs/read.rs @@ -0,0 +1,51 @@ +use crate::fs::asyncify; + +use std::{io, path::Path}; + +/// Reads the entire contents of a file into a bytes vector. +/// +/// This is an async version of [`std::fs::read`][std] +/// +/// [std]: std::fs::read +/// +/// This is a convenience function for using [`File::open`] and [`read_to_end`] +/// with fewer imports and without an intermediate variable. It pre-allocates a +/// buffer based on the file size when available, so it is generally faster than +/// reading into a vector created with `Vec::new()`. +/// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`File::open`]: super::File::open +/// [`read_to_end`]: crate::io::AsyncReadExt::read_to_end +/// [`spawn_blocking`]: crate::task::spawn_blocking +/// +/// # Errors +/// +/// This function will return an error if `path` does not already exist. +/// Other errors may also be returned according to [`OpenOptions::open`]. +/// +/// [`OpenOptions::open`]: super::OpenOptions::open +/// +/// It will also return an error if it encounters while reading an error +/// of a kind other than [`ErrorKind::Interrupted`]. +/// +/// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// use std::net::SocketAddr; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box<dyn std::error::Error + 'static>> { +/// let contents = fs::read("address.txt").await?; +/// let foo: SocketAddr = String::from_utf8_lossy(&contents).parse()?; +/// Ok(()) +/// } +/// ``` +pub async fn read(path: impl AsRef<Path>) -> io::Result<Vec<u8>> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::read(path)).await +} diff --git a/third_party/rust/tokio/src/fs/read_dir.rs b/third_party/rust/tokio/src/fs/read_dir.rs new file mode 100644 index 0000000000..281ea4cd75 --- /dev/null +++ b/third_party/rust/tokio/src/fs/read_dir.rs @@ -0,0 +1,295 @@ +use crate::fs::asyncify; + +use std::ffi::OsString; +use std::fs::{FileType, Metadata}; +use std::future::Future; +use std::io; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; + +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; + +/// Returns a stream over the entries within a directory. +/// +/// This is an async version of [`std::fs::read_dir`](std::fs::read_dir) +/// +/// This operation is implemented by running the equivalent blocking +/// operation on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking +pub async fn read_dir(path: impl AsRef<Path>) -> io::Result<ReadDir> { + let path = path.as_ref().to_owned(); + let std = asyncify(|| std::fs::read_dir(path)).await?; + + Ok(ReadDir(State::Idle(Some(std)))) +} + +/// Reads the the entries in a directory. +/// +/// This struct is returned from the [`read_dir`] function of this module and +/// will yield instances of [`DirEntry`]. Through a [`DirEntry`] information +/// like the entry's path and possibly other metadata can be learned. +/// +/// A `ReadDir` can be turned into a `Stream` with [`ReadDirStream`]. +/// +/// [`ReadDirStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReadDirStream.html +/// +/// # Errors +/// +/// This stream will return an [`Err`] if there's some sort of intermittent +/// IO error during iteration. +/// +/// [`read_dir`]: read_dir +/// [`DirEntry`]: DirEntry +/// [`Err`]: std::result::Result::Err +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct ReadDir(State); + +#[derive(Debug)] +enum State { + Idle(Option<std::fs::ReadDir>), + Pending(JoinHandle<(Option<io::Result<std::fs::DirEntry>>, std::fs::ReadDir)>), +} + +impl ReadDir { + /// Returns the next entry in the directory stream. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. + pub async fn next_entry(&mut self) -> io::Result<Option<DirEntry>> { + use crate::future::poll_fn; + poll_fn(|cx| self.poll_next_entry(cx)).await + } + + /// Polls for the next directory entry in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next directory entry is not yet available. + /// * `Poll::Ready(Ok(Some(entry)))` if the next directory entry is available. + /// * `Poll::Ready(Ok(None))` if there are no more directory entries in this + /// stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the next + /// directory entry. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when the next directory entry + /// becomes available on the underlying IO resource. + /// + /// Note that on multiple calls to `poll_next_entry`, only the `Waker` from + /// the `Context` passed to the most recent call is scheduled to receive a + /// wakeup. + pub fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> { + loop { + match self.0 { + State::Idle(ref mut std) => { + let mut std = std.take().unwrap(); + + self.0 = State::Pending(spawn_blocking(move || { + let ret = std.next(); + (ret, std) + })); + } + State::Pending(ref mut rx) => { + let (ret, std) = ready!(Pin::new(rx).poll(cx))?; + self.0 = State::Idle(Some(std)); + + let ret = match ret { + Some(Ok(std)) => Ok(Some(DirEntry(Arc::new(std)))), + Some(Err(e)) => Err(e), + None => Ok(None), + }; + + return Poll::Ready(ret); + } + } + } + } +} + +feature! { + #![unix] + + use std::os::unix::fs::DirEntryExt; + + impl DirEntry { + /// Returns the underlying `d_ino` field in the contained `dirent` + /// structure. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// while let Some(entry) = entries.next_entry().await? { + /// // Here, `entry` is a `DirEntry`. + /// println!("{:?}: {}", entry.file_name(), entry.ino()); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn ino(&self) -> u64 { + self.as_inner().ino() + } + } +} + +/// Entries returned by the [`ReadDir`] stream. +/// +/// [`ReadDir`]: struct@ReadDir +/// +/// This is a specialized version of [`std::fs::DirEntry`] for usage from the +/// Tokio runtime. +/// +/// An instance of `DirEntry` represents an entry inside of a directory on the +/// filesystem. Each entry can be inspected via methods to learn about the full +/// path or possibly other metadata through per-platform extension traits. +#[derive(Debug)] +pub struct DirEntry(Arc<std::fs::DirEntry>); + +impl DirEntry { + /// Returns the full path to the file that this entry represents. + /// + /// The full path is created by joining the original path to `read_dir` + /// with the filename of this entry. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// + /// while let Some(entry) = entries.next_entry().await? { + /// println!("{:?}", entry.path()); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// This prints output like: + /// + /// ```text + /// "./whatever.txt" + /// "./foo.html" + /// "./hello_world.rs" + /// ``` + /// + /// The exact text, of course, depends on what files you have in `.`. + pub fn path(&self) -> PathBuf { + self.0.path() + } + + /// Returns the bare file name of this directory entry without any other + /// leading path component. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// + /// while let Some(entry) = entries.next_entry().await? { + /// println!("{:?}", entry.file_name()); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn file_name(&self) -> OsString { + self.0.file_name() + } + + /// Returns the metadata for the file that this entry points at. + /// + /// This function will not traverse symlinks if this entry points at a + /// symlink. + /// + /// # Platform-specific behavior + /// + /// On Windows this function is cheap to call (no extra system calls + /// needed), but on Unix platforms this function is the equivalent of + /// calling `symlink_metadata` on the path. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// + /// while let Some(entry) = entries.next_entry().await? { + /// if let Ok(metadata) = entry.metadata().await { + /// // Now let's show our entry's permissions! + /// println!("{:?}: {:?}", entry.path(), metadata.permissions()); + /// } else { + /// println!("Couldn't get file type for {:?}", entry.path()); + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn metadata(&self) -> io::Result<Metadata> { + let std = self.0.clone(); + asyncify(move || std.metadata()).await + } + + /// Returns the file type for the file that this entry points at. + /// + /// This function will not traverse symlinks if this entry points at a + /// symlink. + /// + /// # Platform-specific behavior + /// + /// On Windows and most Unix platforms this function is free (no extra + /// system calls needed), but some Unix platforms may require the equivalent + /// call to `symlink_metadata` to learn about the target file type. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// + /// while let Some(entry) = entries.next_entry().await? { + /// if let Ok(file_type) = entry.file_type().await { + /// // Now let's show our entry's file type! + /// println!("{:?}: {:?}", entry.path(), file_type); + /// } else { + /// println!("Couldn't get file type for {:?}", entry.path()); + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn file_type(&self) -> io::Result<FileType> { + let std = self.0.clone(); + asyncify(move || std.file_type()).await + } + + /// Returns a reference to the underlying `std::fs::DirEntry`. + #[cfg(unix)] + pub(super) fn as_inner(&self) -> &std::fs::DirEntry { + &self.0 + } +} diff --git a/third_party/rust/tokio/src/fs/read_link.rs b/third_party/rust/tokio/src/fs/read_link.rs new file mode 100644 index 0000000000..6c48c5e156 --- /dev/null +++ b/third_party/rust/tokio/src/fs/read_link.rs @@ -0,0 +1,14 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::{Path, PathBuf}; + +/// Reads a symbolic link, returning the file that the link points to. +/// +/// This is an async version of [`std::fs::read_link`][std] +/// +/// [std]: std::fs::read_link +pub async fn read_link(path: impl AsRef<Path>) -> io::Result<PathBuf> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::read_link(path)).await +} diff --git a/third_party/rust/tokio/src/fs/read_to_string.rs b/third_party/rust/tokio/src/fs/read_to_string.rs new file mode 100644 index 0000000000..26228d98c2 --- /dev/null +++ b/third_party/rust/tokio/src/fs/read_to_string.rs @@ -0,0 +1,30 @@ +use crate::fs::asyncify; + +use std::{io, path::Path}; + +/// Creates a future which will open a file for reading and read the entire +/// contents into a string and return said string. +/// +/// This is the async equivalent of [`std::fs::read_to_string`][std]. +/// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking +/// [std]: fn@std::fs::read_to_string +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// +/// # async fn dox() -> std::io::Result<()> { +/// let contents = fs::read_to_string("foo.txt").await?; +/// println!("foo.txt contains {} bytes", contents.len()); +/// # Ok(()) +/// # } +/// ``` +pub async fn read_to_string(path: impl AsRef<Path>) -> io::Result<String> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::read_to_string(path)).await +} diff --git a/third_party/rust/tokio/src/fs/remove_dir.rs b/third_party/rust/tokio/src/fs/remove_dir.rs new file mode 100644 index 0000000000..6e7cbd08f6 --- /dev/null +++ b/third_party/rust/tokio/src/fs/remove_dir.rs @@ -0,0 +1,12 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Removes an existing, empty directory. +/// +/// This is an async version of [`std::fs::remove_dir`](std::fs::remove_dir) +pub async fn remove_dir(path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::remove_dir(path)).await +} diff --git a/third_party/rust/tokio/src/fs/remove_dir_all.rs b/third_party/rust/tokio/src/fs/remove_dir_all.rs new file mode 100644 index 0000000000..0a237550f9 --- /dev/null +++ b/third_party/rust/tokio/src/fs/remove_dir_all.rs @@ -0,0 +1,14 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Removes a directory at this path, after removing all its contents. Use carefully! +/// +/// This is an async version of [`std::fs::remove_dir_all`][std] +/// +/// [std]: fn@std::fs::remove_dir_all +pub async fn remove_dir_all(path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::remove_dir_all(path)).await +} diff --git a/third_party/rust/tokio/src/fs/remove_file.rs b/third_party/rust/tokio/src/fs/remove_file.rs new file mode 100644 index 0000000000..d22a5bfc88 --- /dev/null +++ b/third_party/rust/tokio/src/fs/remove_file.rs @@ -0,0 +1,18 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Removes a file from the filesystem. +/// +/// Note that there is no guarantee that the file is immediately deleted (e.g. +/// depending on platform, other open file descriptors may prevent immediate +/// removal). +/// +/// This is an async version of [`std::fs::remove_file`][std] +/// +/// [std]: std::fs::remove_file +pub async fn remove_file(path: impl AsRef<Path>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(move || std::fs::remove_file(path)).await +} diff --git a/third_party/rust/tokio/src/fs/rename.rs b/third_party/rust/tokio/src/fs/rename.rs new file mode 100644 index 0000000000..4f980821d2 --- /dev/null +++ b/third_party/rust/tokio/src/fs/rename.rs @@ -0,0 +1,17 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Renames a file or directory to a new name, replacing the original file if +/// `to` already exists. +/// +/// This will not work if the new name is on a different mount point. +/// +/// This is an async version of [`std::fs::rename`](std::fs::rename) +pub async fn rename(from: impl AsRef<Path>, to: impl AsRef<Path>) -> io::Result<()> { + let from = from.as_ref().to_owned(); + let to = to.as_ref().to_owned(); + + asyncify(move || std::fs::rename(from, to)).await +} diff --git a/third_party/rust/tokio/src/fs/set_permissions.rs b/third_party/rust/tokio/src/fs/set_permissions.rs new file mode 100644 index 0000000000..09be02ea01 --- /dev/null +++ b/third_party/rust/tokio/src/fs/set_permissions.rs @@ -0,0 +1,15 @@ +use crate::fs::asyncify; + +use std::fs::Permissions; +use std::io; +use std::path::Path; + +/// Changes the permissions found on a file or a directory. +/// +/// This is an async version of [`std::fs::set_permissions`][std] +/// +/// [std]: fn@std::fs::set_permissions +pub async fn set_permissions(path: impl AsRef<Path>, perm: Permissions) -> io::Result<()> { + let path = path.as_ref().to_owned(); + asyncify(|| std::fs::set_permissions(path, perm)).await +} diff --git a/third_party/rust/tokio/src/fs/symlink.rs b/third_party/rust/tokio/src/fs/symlink.rs new file mode 100644 index 0000000000..22ece7250f --- /dev/null +++ b/third_party/rust/tokio/src/fs/symlink.rs @@ -0,0 +1,18 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Creates a new symbolic link on the filesystem. +/// +/// The `dst` path will be a symbolic link pointing to the `src` path. +/// +/// This is an async version of [`std::os::unix::fs::symlink`][std] +/// +/// [std]: std::os::unix::fs::symlink +pub async fn symlink(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> { + let src = src.as_ref().to_owned(); + let dst = dst.as_ref().to_owned(); + + asyncify(move || std::os::unix::fs::symlink(src, dst)).await +} diff --git a/third_party/rust/tokio/src/fs/symlink_dir.rs b/third_party/rust/tokio/src/fs/symlink_dir.rs new file mode 100644 index 0000000000..736e762b48 --- /dev/null +++ b/third_party/rust/tokio/src/fs/symlink_dir.rs @@ -0,0 +1,19 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Creates a new directory symlink on the filesystem. +/// +/// The `dst` path will be a directory symbolic link pointing to the `src` +/// path. +/// +/// This is an async version of [`std::os::windows::fs::symlink_dir`][std] +/// +/// [std]: std::os::windows::fs::symlink_dir +pub async fn symlink_dir(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> { + let src = src.as_ref().to_owned(); + let dst = dst.as_ref().to_owned(); + + asyncify(move || std::os::windows::fs::symlink_dir(src, dst)).await +} diff --git a/third_party/rust/tokio/src/fs/symlink_file.rs b/third_party/rust/tokio/src/fs/symlink_file.rs new file mode 100644 index 0000000000..07d8e60419 --- /dev/null +++ b/third_party/rust/tokio/src/fs/symlink_file.rs @@ -0,0 +1,19 @@ +use crate::fs::asyncify; + +use std::io; +use std::path::Path; + +/// Creates a new file symbolic link on the filesystem. +/// +/// The `dst` path will be a file symbolic link pointing to the `src` +/// path. +/// +/// This is an async version of [`std::os::windows::fs::symlink_file`][std] +/// +/// [std]: std::os::windows::fs::symlink_file +pub async fn symlink_file(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> { + let src = src.as_ref().to_owned(); + let dst = dst.as_ref().to_owned(); + + asyncify(move || std::os::windows::fs::symlink_file(src, dst)).await +} diff --git a/third_party/rust/tokio/src/fs/symlink_metadata.rs b/third_party/rust/tokio/src/fs/symlink_metadata.rs new file mode 100644 index 0000000000..1d0df12576 --- /dev/null +++ b/third_party/rust/tokio/src/fs/symlink_metadata.rs @@ -0,0 +1,15 @@ +use crate::fs::asyncify; + +use std::fs::Metadata; +use std::io; +use std::path::Path; + +/// Queries the file system metadata for a path. +/// +/// This is an async version of [`std::fs::symlink_metadata`][std] +/// +/// [std]: fn@std::fs::symlink_metadata +pub async fn symlink_metadata(path: impl AsRef<Path>) -> io::Result<Metadata> { + let path = path.as_ref().to_owned(); + asyncify(|| std::fs::symlink_metadata(path)).await +} diff --git a/third_party/rust/tokio/src/fs/write.rs b/third_party/rust/tokio/src/fs/write.rs new file mode 100644 index 0000000000..28606fb363 --- /dev/null +++ b/third_party/rust/tokio/src/fs/write.rs @@ -0,0 +1,31 @@ +use crate::fs::asyncify; + +use std::{io, path::Path}; + +/// Creates a future that will open a file for writing and write the entire +/// contents of `contents` to it. +/// +/// This is the async equivalent of [`std::fs::write`][std]. +/// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking +/// [std]: fn@std::fs::write +/// +/// # Examples +/// +/// ```no_run +/// use tokio::fs; +/// +/// # async fn dox() -> std::io::Result<()> { +/// fs::write("foo.txt", b"Hello world!").await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn write(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) -> io::Result<()> { + let path = path.as_ref().to_owned(); + let contents = contents.as_ref().to_owned(); + + asyncify(move || std::fs::write(path, contents)).await +} diff --git a/third_party/rust/tokio/src/future/block_on.rs b/third_party/rust/tokio/src/future/block_on.rs new file mode 100644 index 0000000000..91f9cc0055 --- /dev/null +++ b/third_party/rust/tokio/src/future/block_on.rs @@ -0,0 +1,15 @@ +use std::future::Future; + +cfg_rt! { + pub(crate) fn block_on<F: Future>(f: F) -> F::Output { + let mut e = crate::runtime::enter::enter(false); + e.block_on(f).unwrap() + } +} + +cfg_not_rt! { + pub(crate) fn block_on<F: Future>(f: F) -> F::Output { + let mut park = crate::park::thread::CachedParkThread::new(); + park.block_on(f).unwrap() + } +} diff --git a/third_party/rust/tokio/src/future/maybe_done.rs b/third_party/rust/tokio/src/future/maybe_done.rs new file mode 100644 index 0000000000..486efbe01a --- /dev/null +++ b/third_party/rust/tokio/src/future/maybe_done.rs @@ -0,0 +1,76 @@ +//! Definition of the MaybeDone combinator. + +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A future that may have completed. +#[derive(Debug)] +pub enum MaybeDone<Fut: Future> { + /// A not-yet-completed future. + Future(Fut), + /// The output of the completed future. + Done(Fut::Output), + /// The empty variant after the result of a [`MaybeDone`] has been + /// taken using the [`take_output`](MaybeDone::take_output) method. + Gone, +} + +// Safe because we never generate `Pin<&mut Fut::Output>` +impl<Fut: Future + Unpin> Unpin for MaybeDone<Fut> {} + +/// Wraps a future into a `MaybeDone`. +pub fn maybe_done<Fut: Future>(future: Fut) -> MaybeDone<Fut> { + MaybeDone::Future(future) +} + +impl<Fut: Future> MaybeDone<Fut> { + /// Returns an [`Option`] containing a mutable reference to the output of the future. + /// The output of this method will be [`Some`] if and only if the inner + /// future has been completed and [`take_output`](MaybeDone::take_output) + /// has not yet been called. + pub fn output_mut(self: Pin<&mut Self>) -> Option<&mut Fut::Output> { + unsafe { + let this = self.get_unchecked_mut(); + match this { + MaybeDone::Done(res) => Some(res), + _ => None, + } + } + } + + /// Attempts to take the output of a `MaybeDone` without driving it + /// towards completion. + #[inline] + pub fn take_output(self: Pin<&mut Self>) -> Option<Fut::Output> { + unsafe { + let this = self.get_unchecked_mut(); + match this { + MaybeDone::Done(_) => {} + MaybeDone::Future(_) | MaybeDone::Gone => return None, + }; + if let MaybeDone::Done(output) = mem::replace(this, MaybeDone::Gone) { + Some(output) + } else { + unreachable!() + } + } + } +} + +impl<Fut: Future> Future for MaybeDone<Fut> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let res = unsafe { + match self.as_mut().get_unchecked_mut() { + MaybeDone::Future(a) => ready!(Pin::new_unchecked(a).poll(cx)), + MaybeDone::Done(_) => return Poll::Ready(()), + MaybeDone::Gone => panic!("MaybeDone polled after value taken"), + } + }; + self.set(MaybeDone::Done(res)); + Poll::Ready(()) + } +} diff --git a/third_party/rust/tokio/src/future/mod.rs b/third_party/rust/tokio/src/future/mod.rs new file mode 100644 index 0000000000..084ddc571f --- /dev/null +++ b/third_party/rust/tokio/src/future/mod.rs @@ -0,0 +1,30 @@ +#![cfg_attr(not(feature = "macros"), allow(unreachable_pub))] + +//! Asynchronous values. + +#[cfg(any(feature = "macros", feature = "process"))] +pub(crate) mod maybe_done; + +mod poll_fn; +pub use poll_fn::poll_fn; + +cfg_process! { + mod try_join; + pub(crate) use try_join::try_join3; +} + +cfg_sync! { + mod block_on; + pub(crate) use block_on::block_on; +} + +cfg_trace! { + mod trace; + pub(crate) use trace::InstrumentedFuture as Future; +} + +cfg_not_trace! { + cfg_rt! { + pub(crate) use std::future::Future; + } +} diff --git a/third_party/rust/tokio/src/future/poll_fn.rs b/third_party/rust/tokio/src/future/poll_fn.rs new file mode 100644 index 0000000000..d82ce8961d --- /dev/null +++ b/third_party/rust/tokio/src/future/poll_fn.rs @@ -0,0 +1,40 @@ +#![allow(dead_code)] + +//! Definition of the `PollFn` adapter combinator. + +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Future for the [`poll_fn`] function. +pub struct PollFn<F> { + f: F, +} + +impl<F> Unpin for PollFn<F> {} + +/// Creates a new future wrapping around a function returning [`Poll`]. +pub fn poll_fn<T, F>(f: F) -> PollFn<F> +where + F: FnMut(&mut Context<'_>) -> Poll<T>, +{ + PollFn { f } +} + +impl<F> fmt::Debug for PollFn<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollFn").finish() + } +} + +impl<T, F> Future for PollFn<F> +where + F: FnMut(&mut Context<'_>) -> Poll<T>, +{ + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + (&mut self.f)(cx) + } +} diff --git a/third_party/rust/tokio/src/future/trace.rs b/third_party/rust/tokio/src/future/trace.rs new file mode 100644 index 0000000000..28789a604d --- /dev/null +++ b/third_party/rust/tokio/src/future/trace.rs @@ -0,0 +1,11 @@ +use std::future::Future; + +pub(crate) trait InstrumentedFuture: Future { + fn id(&self) -> Option<tracing::Id>; +} + +impl<F: Future> InstrumentedFuture for tracing::instrument::Instrumented<F> { + fn id(&self) -> Option<tracing::Id> { + self.span().id() + } +} diff --git a/third_party/rust/tokio/src/future/try_join.rs b/third_party/rust/tokio/src/future/try_join.rs new file mode 100644 index 0000000000..8943f61a1e --- /dev/null +++ b/third_party/rust/tokio/src/future/try_join.rs @@ -0,0 +1,82 @@ +use crate::future::maybe_done::{maybe_done, MaybeDone}; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) fn try_join3<T1, F1, T2, F2, T3, F3, E>( + future1: F1, + future2: F2, + future3: F3, +) -> TryJoin3<F1, F2, F3> +where + F1: Future<Output = Result<T1, E>>, + F2: Future<Output = Result<T2, E>>, + F3: Future<Output = Result<T3, E>>, +{ + TryJoin3 { + future1: maybe_done(future1), + future2: maybe_done(future2), + future3: maybe_done(future3), + } +} + +pin_project! { + pub(crate) struct TryJoin3<F1, F2, F3> + where + F1: Future, + F2: Future, + F3: Future, + { + #[pin] + future1: MaybeDone<F1>, + #[pin] + future2: MaybeDone<F2>, + #[pin] + future3: MaybeDone<F3>, + } +} + +impl<T1, F1, T2, F2, T3, F3, E> Future for TryJoin3<F1, F2, F3> +where + F1: Future<Output = Result<T1, E>>, + F2: Future<Output = Result<T2, E>>, + F3: Future<Output = Result<T3, E>>, +{ + type Output = Result<(T1, T2, T3), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut all_done = true; + + let mut me = self.project(); + + if me.future1.as_mut().poll(cx).is_pending() { + all_done = false; + } else if me.future1.as_mut().output_mut().unwrap().is_err() { + return Poll::Ready(Err(me.future1.take_output().unwrap().err().unwrap())); + } + + if me.future2.as_mut().poll(cx).is_pending() { + all_done = false; + } else if me.future2.as_mut().output_mut().unwrap().is_err() { + return Poll::Ready(Err(me.future2.take_output().unwrap().err().unwrap())); + } + + if me.future3.as_mut().poll(cx).is_pending() { + all_done = false; + } else if me.future3.as_mut().output_mut().unwrap().is_err() { + return Poll::Ready(Err(me.future3.take_output().unwrap().err().unwrap())); + } + + if all_done { + Poll::Ready(Ok(( + me.future1.take_output().unwrap().ok().unwrap(), + me.future2.take_output().unwrap().ok().unwrap(), + me.future3.take_output().unwrap().ok().unwrap(), + ))) + } else { + Poll::Pending + } + } +} diff --git a/third_party/rust/tokio/src/io/async_buf_read.rs b/third_party/rust/tokio/src/io/async_buf_read.rs new file mode 100644 index 0000000000..ecaafba4c2 --- /dev/null +++ b/third_party/rust/tokio/src/io/async_buf_read.rs @@ -0,0 +1,117 @@ +use crate::io::AsyncRead; + +use std::io; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Reads bytes asynchronously. +/// +/// This trait is analogous to [`std::io::BufRead`], but integrates with +/// the asynchronous task system. In particular, the [`poll_fill_buf`] method, +/// unlike [`BufRead::fill_buf`], will automatically queue the current task for wakeup +/// and return if data is not yet available, rather than blocking the calling +/// thread. +/// +/// Utilities for working with `AsyncBufRead` values are provided by +/// [`AsyncBufReadExt`]. +/// +/// [`std::io::BufRead`]: std::io::BufRead +/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf +/// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf +/// [`AsyncBufReadExt`]: crate::io::AsyncBufReadExt +pub trait AsyncBufRead: AsyncRead { + /// Attempts to return the contents of the internal buffer, filling it with more data + /// from the inner reader if it is empty. + /// + /// On success, returns `Poll::Ready(Ok(buf))`. + /// + /// If no data is available for reading, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// readable or is closed. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`consume`] method to function properly. When calling this + /// method, none of the contents will be "read" in the sense that later + /// calling [`poll_read`] may return the same contents. As such, [`consume`] must + /// be called with the number of bytes that are consumed from this buffer to + /// ensure that the bytes are never returned twice. + /// + /// An empty buffer returned indicates that the stream has reached EOF. + /// + /// [`poll_read`]: AsyncRead::poll_read + /// [`consume`]: AsyncBufRead::consume + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>; + + /// Tells this buffer that `amt` bytes have been consumed from the buffer, + /// so they should no longer be returned in calls to [`poll_read`]. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`poll_fill_buf`] method to function properly. This function does + /// not perform any I/O, it simply informs this object that some amount of + /// its buffer, returned from [`poll_fill_buf`], has been consumed and should + /// no longer be returned. As such, this function may do odd things if + /// [`poll_fill_buf`] isn't called before calling it. + /// + /// The `amt` must be `<=` the number of bytes in the buffer returned by + /// [`poll_fill_buf`]. + /// + /// [`poll_read`]: AsyncRead::poll_read + /// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf + fn consume(self: Pin<&mut Self>, amt: usize); +} + +macro_rules! deref_async_buf_read { + () => { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + Pin::new(&mut **self.get_mut()).poll_fill_buf(cx) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut **self).consume(amt) + } + }; +} + +impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for Box<T> { + deref_async_buf_read!(); +} + +impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for &mut T { + deref_async_buf_read!(); +} + +impl<P> AsyncBufRead for Pin<P> +where + P: DerefMut + Unpin, + P::Target: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.get_mut().as_mut().poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_mut().as_mut().consume(amt) + } +} + +impl AsyncBufRead for &[u8] { + fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + Poll::Ready(Ok(*self)) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + *self = &self[amt..]; + } +} + +impl<T: AsRef<[u8]> + Unpin> AsyncBufRead for io::Cursor<T> { + fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + Poll::Ready(io::BufRead::fill_buf(self.get_mut())) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + io::BufRead::consume(self.get_mut(), amt) + } +} diff --git a/third_party/rust/tokio/src/io/async_fd.rs b/third_party/rust/tokio/src/io/async_fd.rs new file mode 100644 index 0000000000..93f9cb458a --- /dev/null +++ b/third_party/rust/tokio/src/io/async_fd.rs @@ -0,0 +1,660 @@ +use crate::io::driver::{Handle, Interest, ReadyEvent, Registration}; + +use mio::unix::SourceFd; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::{task::Context, task::Poll}; + +/// Associates an IO object backed by a Unix file descriptor with the tokio +/// reactor, allowing for readiness to be polled. The file descriptor must be of +/// a type that can be used with the OS polling facilities (ie, `poll`, `epoll`, +/// `kqueue`, etc), such as a network socket or pipe, and the file descriptor +/// must have the nonblocking mode set to true. +/// +/// Creating an AsyncFd registers the file descriptor with the current tokio +/// Reactor, allowing you to directly await the file descriptor being readable +/// or writable. Once registered, the file descriptor remains registered until +/// the AsyncFd is dropped. +/// +/// The AsyncFd takes ownership of an arbitrary object to represent the IO +/// object. It is intended that this object will handle closing the file +/// descriptor when it is dropped, avoiding resource leaks and ensuring that the +/// AsyncFd can clean up the registration before closing the file descriptor. +/// The [`AsyncFd::into_inner`] function can be used to extract the inner object +/// to retake control from the tokio IO reactor. +/// +/// The inner object is required to implement [`AsRawFd`]. This file descriptor +/// must not change while [`AsyncFd`] owns the inner object, i.e. the +/// [`AsRawFd::as_raw_fd`] method on the inner type must always return the same +/// file descriptor when called multiple times. Failure to uphold this results +/// in unspecified behavior in the IO driver, which may include breaking +/// notifications for other sockets/etc. +/// +/// Polling for readiness is done by calling the async functions [`readable`] +/// and [`writable`]. These functions complete when the associated readiness +/// condition is observed. Any number of tasks can query the same `AsyncFd` in +/// parallel, on the same or different conditions. +/// +/// On some platforms, the readiness detecting mechanism relies on +/// edge-triggered notifications. This means that the OS will only notify Tokio +/// when the file descriptor transitions from not-ready to ready. For this to +/// work you should first try to read or write and only poll for readiness +/// if that fails with an error of [`std::io::ErrorKind::WouldBlock`]. +/// +/// Tokio internally tracks when it has received a ready notification, and when +/// readiness checking functions like [`readable`] and [`writable`] are called, +/// if the readiness flag is set, these async functions will complete +/// immediately. This however does mean that it is critical to ensure that this +/// ready flag is cleared when (and only when) the file descriptor ceases to be +/// ready. The [`AsyncFdReadyGuard`] returned from readiness checking functions +/// serves this function; after calling a readiness-checking async function, +/// you must use this [`AsyncFdReadyGuard`] to signal to tokio whether the file +/// descriptor is no longer in a ready state. +/// +/// ## Use with to a poll-based API +/// +/// In some cases it may be desirable to use `AsyncFd` from APIs similar to +/// [`TcpStream::poll_read_ready`]. The [`AsyncFd::poll_read_ready`] and +/// [`AsyncFd::poll_write_ready`] functions are provided for this purpose. +/// Because these functions don't create a future to hold their state, they have +/// the limitation that only one task can wait on each direction (read or write) +/// at a time. +/// +/// # Examples +/// +/// This example shows how to turn [`std::net::TcpStream`] asynchronous using +/// `AsyncFd`. It implements `read` as an async fn, and `AsyncWrite` as a trait +/// to show how to implement both approaches. +/// +/// ```no_run +/// use futures::ready; +/// use std::io::{self, Read, Write}; +/// use std::net::TcpStream; +/// use std::pin::Pin; +/// use std::task::{Context, Poll}; +/// use tokio::io::AsyncWrite; +/// use tokio::io::unix::AsyncFd; +/// +/// pub struct AsyncTcpStream { +/// inner: AsyncFd<TcpStream>, +/// } +/// +/// impl AsyncTcpStream { +/// pub fn new(tcp: TcpStream) -> io::Result<Self> { +/// tcp.set_nonblocking(true)?; +/// Ok(Self { +/// inner: AsyncFd::new(tcp)?, +/// }) +/// } +/// +/// pub async fn read(&self, out: &mut [u8]) -> io::Result<usize> { +/// loop { +/// let mut guard = self.inner.readable().await?; +/// +/// match guard.try_io(|inner| inner.get_ref().read(out)) { +/// Ok(result) => return result, +/// Err(_would_block) => continue, +/// } +/// } +/// } +/// } +/// +/// impl AsyncWrite for AsyncTcpStream { +/// fn poll_write( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// buf: &[u8] +/// ) -> Poll<io::Result<usize>> { +/// loop { +/// let mut guard = ready!(self.inner.poll_write_ready(cx))?; +/// +/// match guard.try_io(|inner| inner.get_ref().write(buf)) { +/// Ok(result) => return Poll::Ready(result), +/// Err(_would_block) => continue, +/// } +/// } +/// } +/// +/// fn poll_flush( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// ) -> Poll<io::Result<()>> { +/// // tcp flush is a no-op +/// Poll::Ready(Ok(())) +/// } +/// +/// fn poll_shutdown( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// ) -> Poll<io::Result<()>> { +/// self.inner.get_ref().shutdown(std::net::Shutdown::Write)?; +/// Poll::Ready(Ok(())) +/// } +/// } +/// ``` +/// +/// [`readable`]: method@Self::readable +/// [`writable`]: method@Self::writable +/// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard +/// [`TcpStream::poll_read_ready`]: struct@crate::net::TcpStream +pub struct AsyncFd<T: AsRawFd> { + registration: Registration, + inner: Option<T>, +} + +/// Represents an IO-ready event detected on a particular file descriptor that +/// has not yet been acknowledged. This is a `must_use` structure to help ensure +/// that you do not forget to explicitly clear (or not clear) the event. +/// +/// This type exposes an immutable reference to the underlying IO object. +#[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on ReadyGuard"] +pub struct AsyncFdReadyGuard<'a, T: AsRawFd> { + async_fd: &'a AsyncFd<T>, + event: Option<ReadyEvent>, +} + +/// Represents an IO-ready event detected on a particular file descriptor that +/// has not yet been acknowledged. This is a `must_use` structure to help ensure +/// that you do not forget to explicitly clear (or not clear) the event. +/// +/// This type exposes a mutable reference to the underlying IO object. +#[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on ReadyGuard"] +pub struct AsyncFdReadyMutGuard<'a, T: AsRawFd> { + async_fd: &'a mut AsyncFd<T>, + event: Option<ReadyEvent>, +} + +const ALL_INTEREST: Interest = Interest::READABLE.add(Interest::WRITABLE); + +impl<T: AsRawFd> AsyncFd<T> { + #[inline] + /// Creates an AsyncFd backed by (and taking ownership of) an object + /// implementing [`AsRawFd`]. The backing file descriptor is cached at the + /// time of creation. + /// + /// This method must be called in the context of a tokio runtime. + pub fn new(inner: T) -> io::Result<Self> + where + T: AsRawFd, + { + Self::with_interest(inner, ALL_INTEREST) + } + + #[inline] + /// Creates new instance as `new` with additional ability to customize interest, + /// allowing to specify whether file descriptor will be polled for read, write or both. + pub fn with_interest(inner: T, interest: Interest) -> io::Result<Self> + where + T: AsRawFd, + { + Self::new_with_handle_and_interest(inner, Handle::current(), interest) + } + + pub(crate) fn new_with_handle_and_interest( + inner: T, + handle: Handle, + interest: Interest, + ) -> io::Result<Self> { + let fd = inner.as_raw_fd(); + + let registration = + Registration::new_with_interest_and_handle(&mut SourceFd(&fd), interest, handle)?; + + Ok(AsyncFd { + registration, + inner: Some(inner), + }) + } + + /// Returns a shared reference to the backing object of this [`AsyncFd`]. + #[inline] + pub fn get_ref(&self) -> &T { + self.inner.as_ref().unwrap() + } + + /// Returns a mutable reference to the backing object of this [`AsyncFd`]. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + self.inner.as_mut().unwrap() + } + + fn take_inner(&mut self) -> Option<T> { + let fd = self.inner.as_ref().map(AsRawFd::as_raw_fd); + + if let Some(fd) = fd { + let _ = self.registration.deregister(&mut SourceFd(&fd)); + } + + self.inner.take() + } + + /// Deregisters this file descriptor and returns ownership of the backing + /// object. + pub fn into_inner(mut self) -> T { + self.take_inner().unwrap() + } + + /// Polls for read readiness. + /// + /// If the file descriptor is not currently ready for reading, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for reading, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_read_ready`] or + /// [`poll_read_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_write_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_read_ready_mut`]: method@Self::poll_read_ready_mut + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`readable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_read_ready<'a>( + &'a self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyGuard<'a, T>>> { + let event = ready!(self.registration.poll_read_ready(cx))?; + + Ok(AsyncFdReadyGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + + /// Polls for read readiness. + /// + /// If the file descriptor is not currently ready for reading, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for reading, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_read_ready`] or + /// [`poll_read_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_write_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_read_ready_mut`]: method@Self::poll_read_ready_mut + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`readable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_read_ready_mut<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyMutGuard<'a, T>>> { + let event = ready!(self.registration.poll_read_ready(cx))?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + + /// Polls for write readiness. + /// + /// If the file descriptor is not currently ready for writing, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for writing, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_write_ready`] or + /// [`poll_write_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_read_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`poll_write_ready_mut`]: method@Self::poll_write_ready_mut + /// [`writable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_write_ready<'a>( + &'a self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyGuard<'a, T>>> { + let event = ready!(self.registration.poll_write_ready(cx))?; + + Ok(AsyncFdReadyGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + + /// Polls for write readiness. + /// + /// If the file descriptor is not currently ready for writing, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for writing, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_write_ready`] or + /// [`poll_write_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_read_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`poll_write_ready_mut`]: method@Self::poll_write_ready_mut + /// [`writable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_write_ready_mut<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyMutGuard<'a, T>>> { + let event = ready!(self.registration.poll_write_ready(cx))?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + + async fn readiness(&self, interest: Interest) -> io::Result<AsyncFdReadyGuard<'_, T>> { + let event = self.registration.readiness(interest).await?; + + Ok(AsyncFdReadyGuard { + async_fd: self, + event: Some(event), + }) + } + + async fn readiness_mut( + &mut self, + interest: Interest, + ) -> io::Result<AsyncFdReadyMutGuard<'_, T>> { + let event = self.registration.readiness(interest).await?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + } + + /// Waits for the file descriptor to become readable, returning a + /// [`AsyncFdReadyGuard`] that must be dropped to resume read-readiness + /// polling. + /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn readable<'a>(&'a self) -> io::Result<AsyncFdReadyGuard<'a, T>> { + self.readiness(Interest::READABLE).await + } + + /// Waits for the file descriptor to become readable, returning a + /// [`AsyncFdReadyMutGuard`] that must be dropped to resume read-readiness + /// polling. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn readable_mut<'a>(&'a mut self) -> io::Result<AsyncFdReadyMutGuard<'a, T>> { + self.readiness_mut(Interest::READABLE).await + } + + /// Waits for the file descriptor to become writable, returning a + /// [`AsyncFdReadyGuard`] that must be dropped to resume write-readiness + /// polling. + /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn writable<'a>(&'a self) -> io::Result<AsyncFdReadyGuard<'a, T>> { + self.readiness(Interest::WRITABLE).await + } + + /// Waits for the file descriptor to become writable, returning a + /// [`AsyncFdReadyMutGuard`] that must be dropped to resume write-readiness + /// polling. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn writable_mut<'a>(&'a mut self) -> io::Result<AsyncFdReadyMutGuard<'a, T>> { + self.readiness_mut(Interest::WRITABLE).await + } +} + +impl<T: AsRawFd> AsRawFd for AsyncFd<T> { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_ref().unwrap().as_raw_fd() + } +} + +impl<T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFd<T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncFd") + .field("inner", &self.inner) + .finish() + } +} + +impl<T: AsRawFd> Drop for AsyncFd<T> { + fn drop(&mut self) { + let _ = self.take_inner(); + } +} + +impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { + /// Indicates to tokio that the file descriptor is no longer ready. The + /// internal readiness flag will be cleared, and tokio will wait for the + /// next edge-triggered readiness notification from the OS. + /// + /// It is critical that this function not be called unless your code + /// _actually observes_ that the file descriptor is _not_ ready. Do not call + /// it simply because, for example, a read succeeded; it should be called + /// when a read is observed to block. + /// + /// [`drop`]: method@std::mem::drop + pub fn clear_ready(&mut self) { + if let Some(event) = self.event.take() { + self.async_fd.registration.clear_readiness(event); + } + } + + /// This method should be invoked when you intentionally want to keep the + /// ready flag asserted. + /// + /// While this function is itself a no-op, it satisfies the `#[must_use]` + /// constraint on the [`AsyncFdReadyGuard`] type. + pub fn retain_ready(&mut self) { + // no-op + } + + /// Performs the provided IO operation. + /// + /// If `f` returns a [`WouldBlock`] error, the readiness state associated + /// with this file descriptor is cleared, and the method returns + /// `Err(TryIoError::WouldBlock)`. You will typically need to poll the + /// `AsyncFd` again when this happens. + /// + /// This method helps ensure that the readiness state of the underlying file + /// descriptor remains in sync with the tokio-side readiness state, by + /// clearing the tokio-side state only when a [`WouldBlock`] condition + /// occurs. It is the responsibility of the caller to ensure that `f` + /// returns [`WouldBlock`] only if the file descriptor that originated this + /// `AsyncFdReadyGuard` no longer expresses the readiness state that was queried to + /// create this `AsyncFdReadyGuard`. + /// + /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock + // Alias for old name in 0.x + #[cfg_attr(docsrs, doc(alias = "with_io"))] + pub fn try_io<R>( + &mut self, + f: impl FnOnce(&'a AsyncFd<Inner>) -> io::Result<R>, + ) -> Result<io::Result<R>, TryIoError> { + let result = f(self.async_fd); + + if let Err(e) = result.as_ref() { + if e.kind() == io::ErrorKind::WouldBlock { + self.clear_ready(); + } + } + + match result { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Err(TryIoError(())), + result => Ok(result), + } + } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &'a AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner(&self) -> &'a Inner { + self.get_ref().get_ref() + } +} + +impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { + /// Indicates to tokio that the file descriptor is no longer ready. The + /// internal readiness flag will be cleared, and tokio will wait for the + /// next edge-triggered readiness notification from the OS. + /// + /// It is critical that this function not be called unless your code + /// _actually observes_ that the file descriptor is _not_ ready. Do not call + /// it simply because, for example, a read succeeded; it should be called + /// when a read is observed to block. + /// + /// [`drop`]: method@std::mem::drop + pub fn clear_ready(&mut self) { + if let Some(event) = self.event.take() { + self.async_fd.registration.clear_readiness(event); + } + } + + /// This method should be invoked when you intentionally want to keep the + /// ready flag asserted. + /// + /// While this function is itself a no-op, it satisfies the `#[must_use]` + /// constraint on the [`AsyncFdReadyGuard`] type. + pub fn retain_ready(&mut self) { + // no-op + } + + /// Performs the provided IO operation. + /// + /// If `f` returns a [`WouldBlock`] error, the readiness state associated + /// with this file descriptor is cleared, and the method returns + /// `Err(TryIoError::WouldBlock)`. You will typically need to poll the + /// `AsyncFd` again when this happens. + /// + /// This method helps ensure that the readiness state of the underlying file + /// descriptor remains in sync with the tokio-side readiness state, by + /// clearing the tokio-side state only when a [`WouldBlock`] condition + /// occurs. It is the responsibility of the caller to ensure that `f` + /// returns [`WouldBlock`] only if the file descriptor that originated this + /// `AsyncFdReadyGuard` no longer expresses the readiness state that was queried to + /// create this `AsyncFdReadyGuard`. + /// + /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock + pub fn try_io<R>( + &mut self, + f: impl FnOnce(&mut AsyncFd<Inner>) -> io::Result<R>, + ) -> Result<io::Result<R>, TryIoError> { + let result = f(self.async_fd); + + if let Err(e) = result.as_ref() { + if e.kind() == io::ErrorKind::WouldBlock { + self.clear_ready(); + } + } + + match result { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Err(TryIoError(())), + result => Ok(result), + } + } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &AsyncFd<Inner> { + self.async_fd + } + + /// Returns a mutable reference to the inner [`AsyncFd`]. + pub fn get_mut(&mut self) -> &mut AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner(&self) -> &Inner { + self.get_ref().get_ref() + } + + /// Returns a mutable reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner_mut(&mut self) -> &mut Inner { + self.get_mut().get_mut() + } +} + +impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyGuard<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ReadyGuard") + .field("async_fd", &self.async_fd) + .finish() + } +} + +impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyMutGuard<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MutReadyGuard") + .field("async_fd", &self.async_fd) + .finish() + } +} + +/// The error type returned by [`try_io`]. +/// +/// This error indicates that the IO resource returned a [`WouldBlock`] error. +/// +/// [`WouldBlock`]: std::io::ErrorKind::WouldBlock +/// [`try_io`]: method@AsyncFdReadyGuard::try_io +#[derive(Debug)] +pub struct TryIoError(()); diff --git a/third_party/rust/tokio/src/io/async_read.rs b/third_party/rust/tokio/src/io/async_read.rs new file mode 100644 index 0000000000..93e5d3e66e --- /dev/null +++ b/third_party/rust/tokio/src/io/async_read.rs @@ -0,0 +1,131 @@ +use super::ReadBuf; +use std::io; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Reads bytes from a source. +/// +/// This trait is analogous to the [`std::io::Read`] trait, but integrates with +/// the asynchronous task system. In particular, the [`poll_read`] method, +/// unlike [`Read::read`], will automatically queue the current task for wakeup +/// and return if data is not yet available, rather than blocking the calling +/// thread. +/// +/// Specifically, this means that the `poll_read` function will return one of +/// the following: +/// +/// * `Poll::Ready(Ok(()))` means that data was immediately read and placed into +/// the output buffer. The amount of data read can be determined by the +/// increase in the length of the slice returned by `ReadBuf::filled`. If the +/// difference is 0, EOF has been reached. +/// +/// * `Poll::Pending` means that no data was read into the buffer +/// provided. The I/O object is not currently readable but may become readable +/// in the future. Most importantly, **the current future's task is scheduled +/// to get unparked when the object is readable**. This means that like +/// `Future::poll` you'll receive a notification when the I/O object is +/// readable again. +/// +/// * `Poll::Ready(Err(e))` for other errors are standard I/O errors coming from the +/// underlying object. +/// +/// This trait importantly means that the `read` method only works in the +/// context of a future's task. The object may panic if used outside of a task. +/// +/// Utilities for working with `AsyncRead` values are provided by +/// [`AsyncReadExt`]. +/// +/// [`poll_read`]: AsyncRead::poll_read +/// [`std::io::Read`]: std::io::Read +/// [`Read::read`]: std::io::Read::read +/// [`AsyncReadExt`]: crate::io::AsyncReadExt +pub trait AsyncRead { + /// Attempts to read from the `AsyncRead` into `buf`. + /// + /// On success, returns `Poll::Ready(Ok(()))` and places data in the + /// unfilled portion of `buf`. If no data was read (`buf.filled().len()` is + /// unchanged), it implies that EOF has been reached. + /// + /// If no data is available for reading, the method returns `Poll::Pending` + /// and arranges for the current task (via `cx.waker()`) to receive a + /// notification when the object becomes readable or is closed. + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>>; +} + +macro_rules! deref_async_read { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut **self).poll_read(cx, buf) + } + }; +} + +impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for Box<T> { + deref_async_read!(); +} + +impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for &mut T { + deref_async_read!(); +} + +impl<P> AsyncRead for Pin<P> +where + P: DerefMut + Unpin, + P::Target: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.get_mut().as_mut().poll_read(cx, buf) + } +} + +impl AsyncRead for &[u8] { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let amt = std::cmp::min(self.len(), buf.remaining()); + let (a, b) = self.split_at(amt); + buf.put_slice(a); + *self = b; + Poll::Ready(Ok(())) + } +} + +impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let pos = self.position(); + let slice: &[u8] = (*self).get_ref().as_ref(); + + // The position could technically be out of bounds, so don't panic... + if pos > slice.len() as u64 { + return Poll::Ready(Ok(())); + } + + let start = pos as usize; + let amt = std::cmp::min(slice.len() - start, buf.remaining()); + // Add won't overflow because of pos check above. + let end = start + amt; + buf.put_slice(&slice[start..end]); + self.set_position(end as u64); + + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio/src/io/async_seek.rs b/third_party/rust/tokio/src/io/async_seek.rs new file mode 100644 index 0000000000..bd7a992e4d --- /dev/null +++ b/third_party/rust/tokio/src/io/async_seek.rs @@ -0,0 +1,90 @@ +use std::io::{self, SeekFrom}; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Seek bytes asynchronously. +/// +/// This trait is analogous to the [`std::io::Seek`] trait, but integrates +/// with the asynchronous task system. In particular, the `start_seek` +/// method, unlike [`Seek::seek`], will not block the calling thread. +/// +/// Utilities for working with `AsyncSeek` values are provided by +/// [`AsyncSeekExt`]. +/// +/// [`std::io::Seek`]: std::io::Seek +/// [`Seek::seek`]: std::io::Seek::seek() +/// [`AsyncSeekExt`]: crate::io::AsyncSeekExt +pub trait AsyncSeek { + /// Attempts to seek to an offset, in bytes, in a stream. + /// + /// A seek beyond the end of a stream is allowed, but behavior is defined + /// by the implementation. + /// + /// If this function returns successfully, then the job has been submitted. + /// To find out when it completes, call `poll_complete`. + /// + /// # Errors + /// + /// This function can return [`io::ErrorKind::Other`] in case there is + /// another seek in progress. To avoid this, it is advisable that any call + /// to `start_seek` is preceded by a call to `poll_complete` to ensure all + /// pending seeks have completed. + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()>; + + /// Waits for a seek operation to complete. + /// + /// If the seek operation completed successfully, + /// this method returns the new position from the start of the stream. + /// That position can be used later with [`SeekFrom::Start`]. Repeatedly + /// calling this function without calling `start_seek` might return the + /// same result. + /// + /// # Errors + /// + /// Seeking to a negative offset is considered an error. + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>; +} + +macro_rules! deref_async_seek { + () => { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + Pin::new(&mut **self).start_seek(pos) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + Pin::new(&mut **self).poll_complete(cx) + } + }; +} + +impl<T: ?Sized + AsyncSeek + Unpin> AsyncSeek for Box<T> { + deref_async_seek!(); +} + +impl<T: ?Sized + AsyncSeek + Unpin> AsyncSeek for &mut T { + deref_async_seek!(); +} + +impl<P> AsyncSeek for Pin<P> +where + P: DerefMut + Unpin, + P::Target: AsyncSeek, +{ + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.get_mut().as_mut().start_seek(pos) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + self.get_mut().as_mut().poll_complete(cx) + } +} + +impl<T: AsRef<[u8]> + Unpin> AsyncSeek for io::Cursor<T> { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + io::Seek::seek(&mut *self, pos).map(drop) + } + fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> { + Poll::Ready(Ok(self.get_mut().position())) + } +} diff --git a/third_party/rust/tokio/src/io/async_write.rs b/third_party/rust/tokio/src/io/async_write.rs new file mode 100644 index 0000000000..7ec1a302ef --- /dev/null +++ b/third_party/rust/tokio/src/io/async_write.rs @@ -0,0 +1,408 @@ +use std::io::{self, IoSlice}; +use std::ops::DerefMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Writes bytes asynchronously. +/// +/// The trait inherits from [`std::io::Write`] and indicates that an I/O object is +/// **nonblocking**. All non-blocking I/O objects must return an error when +/// bytes cannot be written instead of blocking the current thread. +/// +/// Specifically, this means that the [`poll_write`] function will return one of +/// the following: +/// +/// * `Poll::Ready(Ok(n))` means that `n` bytes of data was immediately +/// written. +/// +/// * `Poll::Pending` means that no data was written from the buffer +/// provided. The I/O object is not currently writable but may become writable +/// in the future. Most importantly, **the current future's task is scheduled +/// to get unparked when the object is writable**. This means that like +/// `Future::poll` you'll receive a notification when the I/O object is +/// writable again. +/// +/// * `Poll::Ready(Err(e))` for other errors are standard I/O errors coming from the +/// underlying object. +/// +/// This trait importantly means that the [`write`][stdwrite] method only works in +/// the context of a future's task. The object may panic if used outside of a task. +/// +/// Note that this trait also represents that the [`Write::flush`][stdflush] method +/// works very similarly to the `write` method, notably that `Ok(())` means that the +/// writer has successfully been flushed, a "would block" error means that the +/// current task is ready to receive a notification when flushing can make more +/// progress, and otherwise normal errors can happen as well. +/// +/// Utilities for working with `AsyncWrite` values are provided by +/// [`AsyncWriteExt`]. +/// +/// [`std::io::Write`]: std::io::Write +/// [`poll_write`]: AsyncWrite::poll_write() +/// [stdwrite]: std::io::Write::write() +/// [stdflush]: std::io::Write::flush() +/// [`AsyncWriteExt`]: crate::io::AsyncWriteExt +pub trait AsyncWrite { + /// Attempt to write bytes from `buf` into the object. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. If successful, + /// then it must be guaranteed that `n <= buf.len()`. A return value of `0` + /// typically means that the underlying object is no longer able to accept + /// bytes and will likely not be able to in the future as well, or that the + /// buffer provided is empty. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker()`) to receive a notification when the object becomes + /// writable or is closed. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>>; + + /// Attempts to flush the object, ensuring that any buffered data reach + /// their destination. + /// + /// On success, returns `Poll::Ready(Ok(()))`. + /// + /// If flushing cannot immediately complete, this method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker()`) to receive a notification when the object can make + /// progress towards flushing. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>; + + /// Initiates or attempts to shut down this writer, returning success when + /// the I/O connection has completely shut down. + /// + /// This method is intended to be used for asynchronous shutdown of I/O + /// connections. For example this is suitable for implementing shutdown of a + /// TLS connection or calling `TcpStream::shutdown` on a proxied connection. + /// Protocols sometimes need to flush out final pieces of data or otherwise + /// perform a graceful shutdown handshake, reading/writing more data as + /// appropriate. This method is the hook for such protocols to implement the + /// graceful shutdown logic. + /// + /// This `shutdown` method is required by implementers of the + /// `AsyncWrite` trait. Wrappers typically just want to proxy this call + /// through to the wrapped type, and base types will typically implement + /// shutdown logic here or just return `Ok(().into())`. Note that if you're + /// wrapping an underlying `AsyncWrite` a call to `shutdown` implies that + /// transitively the entire stream has been shut down. After your wrapper's + /// shutdown logic has been executed you should shut down the underlying + /// stream. + /// + /// Invocation of a `shutdown` implies an invocation of `flush`. Once this + /// method returns `Ready` it implies that a flush successfully happened + /// before the shutdown happened. That is, callers don't need to call + /// `flush` before calling `shutdown`. They can rely that by calling + /// `shutdown` any pending buffered data will be written out. + /// + /// # Return value + /// + /// This function returns a `Poll<io::Result<()>>` classified as such: + /// + /// * `Poll::Ready(Ok(()))` - indicates that the connection was + /// successfully shut down and is now safe to deallocate/drop/close + /// resources associated with it. This method means that the current task + /// will no longer receive any notifications due to this method and the + /// I/O object itself is likely no longer usable. + /// + /// * `Poll::Pending` - indicates that shutdown is initiated but could + /// not complete just yet. This may mean that more I/O needs to happen to + /// continue this shutdown operation. The current task is scheduled to + /// receive a notification when it's otherwise ready to continue the + /// shutdown operation. When woken up this method should be called again. + /// + /// * `Poll::Ready(Err(e))` - indicates a fatal error has happened with shutdown, + /// indicating that the shutdown operation did not complete successfully. + /// This typically means that the I/O object is no longer usable. + /// + /// # Errors + /// + /// This function can return normal I/O errors through `Err`, described + /// above. Additionally this method may also render the underlying + /// `Write::write` method no longer usable (e.g. will return errors in the + /// future). It's recommended that once `shutdown` is called the + /// `write` method is no longer called. + /// + /// # Panics + /// + /// This function will panic if not called within the context of a future's + /// task. + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>; + + /// Like [`poll_write`], except that it writes from a slice of buffers. + /// + /// Data is copied from each buffer in order, with the final buffer + /// read from possibly being only partially consumed. This method must + /// behave as a call to [`write`] with the buffers concatenated would. + /// + /// The default implementation calls [`poll_write`] with either the first nonempty + /// buffer provided, or an empty one if none exists. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker()`) to receive a notification when the object becomes + /// writable or is closed. + /// + /// # Note + /// + /// This should be implemented as a single "atomic" write action. If any + /// data has been partially written, it is wrong to return an error or + /// pending. + /// + /// [`poll_write`]: AsyncWrite::poll_write + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<Result<usize, io::Error>> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, buf) + } + + /// Determines if this writer has an efficient [`poll_write_vectored`] + /// implementation. + /// + /// If a writer does not override the default [`poll_write_vectored`] + /// implementation, code using it may want to avoid the method all together + /// and coalesce writes into a single buffer for higher performance. + /// + /// The default implementation returns `false`. + /// + /// [`poll_write_vectored`]: AsyncWrite::poll_write_vectored + fn is_write_vectored(&self) -> bool { + false + } +} + +macro_rules! deref_async_write { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut **self).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut **self).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut **self).poll_shutdown(cx) + } + }; +} + +impl<T: ?Sized + AsyncWrite + Unpin> AsyncWrite for Box<T> { + deref_async_write!(); +} + +impl<T: ?Sized + AsyncWrite + Unpin> AsyncWrite for &mut T { + deref_async_write!(); +} + +impl<P> AsyncWrite for Pin<P> +where + P: DerefMut + Unpin, + P::Target: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.get_mut().as_mut().poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.get_mut().as_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.get_mut().as_mut().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.get_mut().as_mut().poll_shutdown(cx) + } +} + +impl AsyncWrite for Vec<u8> { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.get_mut().extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for io::Cursor<&mut [u8]> { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write(&mut *self, buf)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(io::Write::flush(&mut *self)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for io::Cursor<&mut Vec<u8>> { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write(&mut *self, buf)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(io::Write::flush(&mut *self)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for io::Cursor<Vec<u8>> { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write(&mut *self, buf)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(io::Write::flush(&mut *self)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for io::Cursor<Box<[u8]>> { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write(&mut *self, buf)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(io::Write::flush(&mut *self)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} diff --git a/third_party/rust/tokio/src/io/blocking.rs b/third_party/rust/tokio/src/io/blocking.rs new file mode 100644 index 0000000000..1d79ee7a27 --- /dev/null +++ b/third_party/rust/tokio/src/io/blocking.rs @@ -0,0 +1,279 @@ +use crate::io::sys; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::cmp; +use std::future::Future; +use std::io; +use std::io::prelude::*; +use std::pin::Pin; +use std::task::Poll::*; +use std::task::{Context, Poll}; + +use self::State::*; + +/// `T` should not implement _both_ Read and Write. +#[derive(Debug)] +pub(crate) struct Blocking<T> { + inner: Option<T>, + state: State<T>, + /// `true` if the lower IO layer needs flushing. + need_flush: bool, +} + +#[derive(Debug)] +pub(crate) struct Buf { + buf: Vec<u8>, + pos: usize, +} + +pub(crate) const MAX_BUF: usize = 16 * 1024; + +#[derive(Debug)] +enum State<T> { + Idle(Option<Buf>), + Busy(sys::Blocking<(io::Result<usize>, Buf, T)>), +} + +cfg_io_std! { + impl<T> Blocking<T> { + pub(crate) fn new(inner: T) -> Blocking<T> { + Blocking { + inner: Some(inner), + state: State::Idle(Some(Buf::with_capacity(0))), + need_flush: false, + } + } + } +} + +impl<T> AsyncRead for Blocking<T> +where + T: Read + Unpin + Send + 'static, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + loop { + match self.state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + if !buf.is_empty() { + buf.copy_to(dst); + *buf_cell = Some(buf); + return Ready(Ok(())); + } + + buf.ensure_capacity_for(dst); + let mut inner = self.inner.take().unwrap(); + + self.state = Busy(sys::run(move || { + let res = buf.read_from(&mut inner); + (res, buf, inner) + })); + } + Busy(ref mut rx) => { + let (res, mut buf, inner) = ready!(Pin::new(rx).poll(cx))?; + self.inner = Some(inner); + + match res { + Ok(_) => { + buf.copy_to(dst); + self.state = Idle(Some(buf)); + return Ready(Ok(())); + } + Err(e) => { + assert!(buf.is_empty()); + + self.state = Idle(Some(buf)); + return Ready(Err(e)); + } + } + } + } + } + } +} + +impl<T> AsyncWrite for Blocking<T> +where + T: Write + Unpin + Send + 'static, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + src: &[u8], + ) -> Poll<io::Result<usize>> { + loop { + match self.state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + assert!(buf.is_empty()); + + let n = buf.copy_from(src); + let mut inner = self.inner.take().unwrap(); + + self.state = Busy(sys::run(move || { + let n = buf.len(); + let res = buf.write_to(&mut inner).map(|_| n); + + (res, buf, inner) + })); + self.need_flush = true; + + return Ready(Ok(n)); + } + Busy(ref mut rx) => { + let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?; + self.state = Idle(Some(buf)); + self.inner = Some(inner); + + // If error, return + res?; + } + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + loop { + let need_flush = self.need_flush; + match self.state { + // The buffer is not used here + Idle(ref mut buf_cell) => { + if need_flush { + let buf = buf_cell.take().unwrap(); + let mut inner = self.inner.take().unwrap(); + + self.state = Busy(sys::run(move || { + let res = inner.flush().map(|_| 0); + (res, buf, inner) + })); + + self.need_flush = false; + } else { + return Ready(Ok(())); + } + } + Busy(ref mut rx) => { + let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?; + self.state = Idle(Some(buf)); + self.inner = Some(inner); + + // If error, return + res?; + } + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } +} + +/// Repeats operations that are interrupted. +macro_rules! uninterruptibly { + ($e:expr) => {{ + loop { + match $e { + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + res => break res, + } + } + }}; +} + +impl Buf { + pub(crate) fn with_capacity(n: usize) -> Buf { + Buf { + buf: Vec::with_capacity(n), + pos: 0, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub(crate) fn len(&self) -> usize { + self.buf.len() - self.pos + } + + pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize { + let n = cmp::min(self.len(), dst.remaining()); + dst.put_slice(&self.bytes()[..n]); + self.pos += n; + + if self.pos == self.buf.len() { + self.buf.truncate(0); + self.pos = 0; + } + + n + } + + pub(crate) fn copy_from(&mut self, src: &[u8]) -> usize { + assert!(self.is_empty()); + + let n = cmp::min(src.len(), MAX_BUF); + + self.buf.extend_from_slice(&src[..n]); + n + } + + pub(crate) fn bytes(&self) -> &[u8] { + &self.buf[self.pos..] + } + + pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) { + assert!(self.is_empty()); + + let len = cmp::min(bytes.remaining(), MAX_BUF); + + if self.buf.len() < len { + self.buf.reserve(len - self.buf.len()); + } + + unsafe { + self.buf.set_len(len); + } + } + + pub(crate) fn read_from<T: Read>(&mut self, rd: &mut T) -> io::Result<usize> { + let res = uninterruptibly!(rd.read(&mut self.buf)); + + if let Ok(n) = res { + self.buf.truncate(n); + } else { + self.buf.clear(); + } + + assert_eq!(self.pos, 0); + + res + } + + pub(crate) fn write_to<T: Write>(&mut self, wr: &mut T) -> io::Result<()> { + assert_eq!(self.pos, 0); + + // `write_all` already ignores interrupts + let res = wr.write_all(&self.buf); + self.buf.clear(); + res + } +} + +cfg_fs! { + impl Buf { + pub(crate) fn discard_read(&mut self) -> i64 { + let ret = -(self.bytes().len() as i64); + self.pos = 0; + self.buf.truncate(0); + ret + } + } +} diff --git a/third_party/rust/tokio/src/io/bsd/poll_aio.rs b/third_party/rust/tokio/src/io/bsd/poll_aio.rs new file mode 100644 index 0000000000..f1ac4b2d77 --- /dev/null +++ b/third_party/rust/tokio/src/io/bsd/poll_aio.rs @@ -0,0 +1,195 @@ +//! Use POSIX AIO futures with Tokio. + +use crate::io::driver::{Handle, Interest, ReadyEvent, Registration}; +use mio::event::Source; +use mio::Registry; +use mio::Token; +use std::fmt; +use std::io; +use std::ops::{Deref, DerefMut}; +use std::os::unix::io::AsRawFd; +use std::os::unix::prelude::RawFd; +use std::task::{Context, Poll}; + +/// Like [`mio::event::Source`], but for POSIX AIO only. +/// +/// Tokio's consumer must pass an implementor of this trait to create a +/// [`Aio`] object. +pub trait AioSource { + /// Registers this AIO event source with Tokio's reactor. + fn register(&mut self, kq: RawFd, token: usize); + + /// Deregisters this AIO event source with Tokio's reactor. + fn deregister(&mut self); +} + +/// Wraps the user's AioSource in order to implement mio::event::Source, which +/// is what the rest of the crate wants. +struct MioSource<T>(T); + +impl<T: AioSource> Source for MioSource<T> { + fn register( + &mut self, + registry: &Registry, + token: Token, + interests: mio::Interest, + ) -> io::Result<()> { + assert!(interests.is_aio() || interests.is_lio()); + self.0.register(registry.as_raw_fd(), usize::from(token)); + Ok(()) + } + + fn deregister(&mut self, _registry: &Registry) -> io::Result<()> { + self.0.deregister(); + Ok(()) + } + + fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: mio::Interest, + ) -> io::Result<()> { + assert!(interests.is_aio() || interests.is_lio()); + self.0.register(registry.as_raw_fd(), usize::from(token)); + Ok(()) + } +} + +/// Associates a POSIX AIO control block with the reactor that drives it. +/// +/// `Aio`'s wrapped type must implement [`AioSource`] to be driven +/// by the reactor. +/// +/// The wrapped source may be accessed through the `Aio` via the `Deref` and +/// `DerefMut` traits. +/// +/// ## Clearing readiness +/// +/// If [`Aio::poll_ready`] returns ready, but the consumer determines that the +/// Source is not completely ready and must return to the Pending state, +/// [`Aio::clear_ready`] may be used. This can be useful with +/// [`lio_listio`], which may generate a kevent when only a portion of the +/// operations have completed. +/// +/// ## Platforms +/// +/// Only FreeBSD implements POSIX AIO with kqueue notification, so +/// `Aio` is only available for that operating system. +/// +/// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html +// Note: Unlike every other kqueue event source, POSIX AIO registers events not +// via kevent(2) but when the aiocb is submitted to the kernel via aio_read, +// aio_write, etc. It needs the kqueue's file descriptor to do that. So +// AsyncFd can't be used for POSIX AIO. +// +// Note that Aio doesn't implement Drop. There's no need. Unlike other +// kqueue sources, simply dropping the object effectively deregisters it. +pub struct Aio<E> { + io: MioSource<E>, + registration: Registration, +} + +// ===== impl Aio ===== + +impl<E: AioSource> Aio<E> { + /// Creates a new `Aio` suitable for use with POSIX AIO functions. + /// + /// It will be associated with the default reactor. The runtime is usually + /// set implicitly when this function is called from a future driven by a + /// Tokio runtime, otherwise runtime can be set explicitly with + /// [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn new_for_aio(io: E) -> io::Result<Self> { + Self::new_with_interest(io, Interest::AIO) + } + + /// Creates a new `Aio` suitable for use with [`lio_listio`]. + /// + /// It will be associated with the default reactor. The runtime is usually + /// set implicitly when this function is called from a future driven by a + /// Tokio runtime, otherwise runtime can be set explicitly with + /// [`Runtime::enter`](crate::runtime::Runtime::enter) function. + /// + /// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html + pub fn new_for_lio(io: E) -> io::Result<Self> { + Self::new_with_interest(io, Interest::LIO) + } + + fn new_with_interest(io: E, interest: Interest) -> io::Result<Self> { + let mut io = MioSource(io); + let handle = Handle::current(); + let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?; + Ok(Self { io, registration }) + } + + /// Indicates to Tokio that the source is no longer ready. The internal + /// readiness flag will be cleared, and tokio will wait for the next + /// edge-triggered readiness notification from the OS. + /// + /// It is critical that this method not be called unless your code + /// _actually observes_ that the source is _not_ ready. The OS must + /// deliver a subsequent notification, or this source will block + /// forever. It is equally critical that you `do` call this method if you + /// resubmit the same structure to the kernel and poll it again. + /// + /// This method is not very useful with AIO readiness, since each `aiocb` + /// structure is typically only used once. It's main use with + /// [`lio_listio`], which will sometimes send notification when only a + /// portion of its elements are complete. In that case, the caller must + /// call `clear_ready` before resubmitting it. + /// + /// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html + pub fn clear_ready(&self, ev: AioEvent) { + self.registration.clear_readiness(ev.0) + } + + /// Destroy the [`Aio`] and return its inner source. + pub fn into_inner(self) -> E { + self.io.0 + } + + /// Polls for readiness. Either AIO or LIO counts. + /// + /// This method returns: + /// * `Poll::Pending` if the underlying operation is not complete, whether + /// or not it completed successfully. This will be true if the OS is + /// still processing it, or if it has not yet been submitted to the OS. + /// * `Poll::Ready(Ok(_))` if the underlying operation is complete. + /// * `Poll::Ready(Err(_))` if the reactor has been shutdown. This does + /// _not_ indicate that the underlying operation encountered an error. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` + /// is scheduled to receive a wakeup when the underlying operation + /// completes. Note that on multiple calls to `poll_ready`, only the `Waker` from the + /// `Context` passed to the most recent call is scheduled to receive a wakeup. + pub fn poll_ready<'a>(&'a self, cx: &mut Context<'_>) -> Poll<io::Result<AioEvent>> { + let ev = ready!(self.registration.poll_read_ready(cx))?; + Poll::Ready(Ok(AioEvent(ev))) + } +} + +impl<E: AioSource> Deref for Aio<E> { + type Target = E; + + fn deref(&self) -> &E { + &self.io.0 + } +} + +impl<E: AioSource> DerefMut for Aio<E> { + fn deref_mut(&mut self) -> &mut E { + &mut self.io.0 + } +} + +impl<E: AioSource + fmt::Debug> fmt::Debug for Aio<E> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Aio").field("io", &self.io.0).finish() + } +} + +/// Opaque data returned by [`Aio::poll_ready`]. +/// +/// It can be fed back to [`Aio::clear_ready`]. +#[derive(Debug)] +pub struct AioEvent(ReadyEvent); diff --git a/third_party/rust/tokio/src/io/driver/interest.rs b/third_party/rust/tokio/src/io/driver/interest.rs new file mode 100644 index 0000000000..d6b46dfb7c --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/interest.rs @@ -0,0 +1,132 @@ +#![cfg_attr(not(feature = "net"), allow(dead_code, unreachable_pub))] + +use crate::io::driver::Ready; + +use std::fmt; +use std::ops; + +/// Readiness event interest. +/// +/// Specifies the readiness events the caller is interested in when awaiting on +/// I/O resource readiness states. +#[cfg_attr(docsrs, doc(cfg(feature = "net")))] +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct Interest(mio::Interest); + +impl Interest { + // The non-FreeBSD definitions in this block are active only when + // building documentation. + cfg_aio! { + /// Interest for POSIX AIO. + #[cfg(target_os = "freebsd")] + pub const AIO: Interest = Interest(mio::Interest::AIO); + + /// Interest for POSIX AIO. + #[cfg(not(target_os = "freebsd"))] + pub const AIO: Interest = Interest(mio::Interest::READABLE); + + /// Interest for POSIX AIO lio_listio events. + #[cfg(target_os = "freebsd")] + pub const LIO: Interest = Interest(mio::Interest::LIO); + + /// Interest for POSIX AIO lio_listio events. + #[cfg(not(target_os = "freebsd"))] + pub const LIO: Interest = Interest(mio::Interest::READABLE); + } + + /// Interest in all readable events. + /// + /// Readable interest includes read-closed events. + pub const READABLE: Interest = Interest(mio::Interest::READABLE); + + /// Interest in all writable events. + /// + /// Writable interest includes write-closed events. + pub const WRITABLE: Interest = Interest(mio::Interest::WRITABLE); + + /// Returns true if the value includes readable interest. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Interest; + /// + /// assert!(Interest::READABLE.is_readable()); + /// assert!(!Interest::WRITABLE.is_readable()); + /// + /// let both = Interest::READABLE | Interest::WRITABLE; + /// assert!(both.is_readable()); + /// ``` + pub const fn is_readable(self) -> bool { + self.0.is_readable() + } + + /// Returns true if the value includes writable interest. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Interest; + /// + /// assert!(!Interest::READABLE.is_writable()); + /// assert!(Interest::WRITABLE.is_writable()); + /// + /// let both = Interest::READABLE | Interest::WRITABLE; + /// assert!(both.is_writable()); + /// ``` + pub const fn is_writable(self) -> bool { + self.0.is_writable() + } + + /// Add together two `Interest` values. + /// + /// This function works from a `const` context. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Interest; + /// + /// const BOTH: Interest = Interest::READABLE.add(Interest::WRITABLE); + /// + /// assert!(BOTH.is_readable()); + /// assert!(BOTH.is_writable()); + pub const fn add(self, other: Interest) -> Interest { + Interest(self.0.add(other.0)) + } + + // This function must be crate-private to avoid exposing a `mio` dependency. + pub(crate) const fn to_mio(self) -> mio::Interest { + self.0 + } + + pub(super) fn mask(self) -> Ready { + match self { + Interest::READABLE => Ready::READABLE | Ready::READ_CLOSED, + Interest::WRITABLE => Ready::WRITABLE | Ready::WRITE_CLOSED, + _ => Ready::EMPTY, + } + } +} + +impl ops::BitOr for Interest { + type Output = Self; + + #[inline] + fn bitor(self, other: Self) -> Self { + self.add(other) + } +} + +impl ops::BitOrAssign for Interest { + #[inline] + fn bitor_assign(&mut self, other: Self) { + self.0 = (*self | other).0; + } +} + +impl fmt::Debug for Interest { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/io/driver/mod.rs b/third_party/rust/tokio/src/io/driver/mod.rs new file mode 100644 index 0000000000..19f67a24e7 --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/mod.rs @@ -0,0 +1,354 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +mod interest; +#[allow(unreachable_pub)] +pub use interest::Interest; + +mod ready; +#[allow(unreachable_pub)] +pub use ready::Ready; + +mod registration; +pub(crate) use registration::Registration; + +mod scheduled_io; +use scheduled_io::ScheduledIo; + +use crate::park::{Park, Unpark}; +use crate::util::slab::{self, Slab}; +use crate::{loom::sync::Mutex, util::bit}; + +use std::fmt; +use std::io; +use std::sync::{Arc, Weak}; +use std::time::Duration; + +/// I/O driver, backed by Mio. +pub(crate) struct Driver { + /// Tracks the number of times `turn` is called. It is safe for this to wrap + /// as it is mostly used to determine when to call `compact()`. + tick: u8, + + /// Reuse the `mio::Events` value across calls to poll. + events: Option<mio::Events>, + + /// Primary slab handle containing the state for each resource registered + /// with this driver. During Drop this is moved into the Inner structure, so + /// this is an Option to allow it to be vacated (until Drop this is always + /// Some). + resources: Option<Slab<ScheduledIo>>, + + /// The system event queue. + poll: mio::Poll, + + /// State shared between the reactor and the handles. + inner: Arc<Inner>, +} + +/// A reference to an I/O driver. +#[derive(Clone)] +pub(crate) struct Handle { + inner: Weak<Inner>, +} + +#[derive(Debug)] +pub(crate) struct ReadyEvent { + tick: u8, + pub(crate) ready: Ready, +} + +pub(super) struct Inner { + /// Primary slab handle containing the state for each resource registered + /// with this driver. + /// + /// The ownership of this slab is moved into this structure during + /// `Driver::drop`, so that `Inner::drop` can notify all outstanding handles + /// without risking new ones being registered in the meantime. + resources: Mutex<Option<Slab<ScheduledIo>>>, + + /// Registers I/O resources. + registry: mio::Registry, + + /// Allocates `ScheduledIo` handles when creating new resources. + pub(super) io_dispatch: slab::Allocator<ScheduledIo>, + + /// Used to wake up the reactor from a call to `turn`. + waker: mio::Waker, +} + +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +enum Direction { + Read, + Write, +} + +enum Tick { + Set(u8), + Clear(u8), +} + +// TODO: Don't use a fake token. Instead, reserve a slot entry for the wakeup +// token. +const TOKEN_WAKEUP: mio::Token = mio::Token(1 << 31); + +const ADDRESS: bit::Pack = bit::Pack::least_significant(24); + +// Packs the generation value in the `readiness` field. +// +// The generation prevents a race condition where a slab slot is reused for a +// new socket while the I/O driver is about to apply a readiness event. The +// generation value is checked when setting new readiness. If the generation do +// not match, then the readiness event is discarded. +const GENERATION: bit::Pack = ADDRESS.then(7); + +fn _assert_kinds() { + fn _assert<T: Send + Sync>() {} + + _assert::<Handle>(); +} + +// ===== impl Driver ===== + +impl Driver { + /// Creates a new event loop, returning any error that happened during the + /// creation. + pub(crate) fn new() -> io::Result<Driver> { + let poll = mio::Poll::new()?; + let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?; + let registry = poll.registry().try_clone()?; + + let slab = Slab::new(); + let allocator = slab.allocator(); + + Ok(Driver { + tick: 0, + events: Some(mio::Events::with_capacity(1024)), + poll, + resources: Some(slab), + inner: Arc::new(Inner { + resources: Mutex::new(None), + registry, + io_dispatch: allocator, + waker, + }), + }) + } + + /// Returns a handle to this event loop which can be sent across threads + /// and can be used as a proxy to the event loop itself. + /// + /// Handles are cloneable and clones always refer to the same event loop. + /// This handle is typically passed into functions that create I/O objects + /// to bind them to this event loop. + pub(crate) fn handle(&self) -> Handle { + Handle { + inner: Arc::downgrade(&self.inner), + } + } + + fn turn(&mut self, max_wait: Option<Duration>) -> io::Result<()> { + // How often to call `compact()` on the resource slab + const COMPACT_INTERVAL: u8 = 255; + + self.tick = self.tick.wrapping_add(1); + + if self.tick == COMPACT_INTERVAL { + self.resources.as_mut().unwrap().compact() + } + + let mut events = self.events.take().expect("i/o driver event store missing"); + + // Block waiting for an event to happen, peeling out how many events + // happened. + match self.poll.poll(&mut events, max_wait) { + Ok(_) => {} + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + + // Process all the events that came in, dispatching appropriately + for event in events.iter() { + let token = event.token(); + + if token != TOKEN_WAKEUP { + self.dispatch(token, Ready::from_mio(event)); + } + } + + self.events = Some(events); + + Ok(()) + } + + fn dispatch(&mut self, token: mio::Token, ready: Ready) { + let addr = slab::Address::from_usize(ADDRESS.unpack(token.0)); + + let resources = self.resources.as_mut().unwrap(); + + let io = match resources.get(addr) { + Some(io) => io, + None => return, + }; + + let res = io.set_readiness(Some(token.0), Tick::Set(self.tick), |curr| curr | ready); + + if res.is_err() { + // token no longer valid! + return; + } + + io.wake(ready); + } +} + +impl Drop for Driver { + fn drop(&mut self) { + (*self.inner.resources.lock()) = self.resources.take(); + } +} + +impl Drop for Inner { + fn drop(&mut self) { + let resources = self.resources.lock().take(); + + if let Some(mut slab) = resources { + slab.for_each(|io| { + // If a task is waiting on the I/O resource, notify it. The task + // will then attempt to use the I/O resource and fail due to the + // driver being shutdown. + io.shutdown(); + }); + } + } +} + +impl Park for Driver { + type Unpark = Handle; + type Error = io::Error; + + fn unpark(&self) -> Self::Unpark { + self.handle() + } + + fn park(&mut self) -> io::Result<()> { + self.turn(None)?; + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> io::Result<()> { + self.turn(Some(duration))?; + Ok(()) + } + + fn shutdown(&mut self) {} +} + +impl fmt::Debug for Driver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Driver") + } +} + +// ===== impl Handle ===== + +cfg_rt! { + impl Handle { + /// Returns a handle to the current reactor. + /// + /// # Panics + /// + /// This function panics if there is no current reactor set and `rt` feature + /// flag is not enabled. + pub(super) fn current() -> Self { + crate::runtime::context::io_handle().expect("A Tokio 1.x context was found, but IO is disabled. Call `enable_io` on the runtime builder to enable IO.") + } + } +} + +cfg_not_rt! { + impl Handle { + /// Returns a handle to the current reactor. + /// + /// # Panics + /// + /// This function panics if there is no current reactor set, or if the `rt` + /// feature flag is not enabled. + pub(super) fn current() -> Self { + panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) + } + } +} + +impl Handle { + /// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise + /// makes the next call to `turn` return immediately. + /// + /// This method is intended to be used in situations where a notification + /// needs to otherwise be sent to the main reactor. If the reactor is + /// currently blocked inside of `turn` then it will wake up and soon return + /// after this method has been called. If the reactor is not currently + /// blocked in `turn`, then the next call to `turn` will not block and + /// return immediately. + fn wakeup(&self) { + if let Some(inner) = self.inner() { + inner.waker.wake().expect("failed to wake I/O driver"); + } + } + + pub(super) fn inner(&self) -> Option<Arc<Inner>> { + self.inner.upgrade() + } +} + +impl Unpark for Handle { + fn unpark(&self) { + self.wakeup(); + } +} + +impl fmt::Debug for Handle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Handle") + } +} + +// ===== impl Inner ===== + +impl Inner { + /// Registers an I/O resource with the reactor for a given `mio::Ready` state. + /// + /// The registration token is returned. + pub(super) fn add_source( + &self, + source: &mut impl mio::event::Source, + interest: Interest, + ) -> io::Result<slab::Ref<ScheduledIo>> { + let (address, shared) = self.io_dispatch.allocate().ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "reactor at max registered I/O resources", + ) + })?; + + let token = GENERATION.pack(shared.generation(), ADDRESS.pack(address.as_usize(), 0)); + + self.registry + .register(source, mio::Token(token), interest.to_mio())?; + + Ok(shared) + } + + /// Deregisters an I/O resource from the reactor. + pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> { + self.registry.deregister(source) + } +} + +impl Direction { + pub(super) fn mask(self) -> Ready { + match self { + Direction::Read => Ready::READABLE | Ready::READ_CLOSED, + Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED, + } + } +} diff --git a/third_party/rust/tokio/src/io/driver/platform.rs b/third_party/rust/tokio/src/io/driver/platform.rs new file mode 100644 index 0000000000..6b27988ce6 --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/platform.rs @@ -0,0 +1,44 @@ +pub(crate) use self::sys::*; + +#[cfg(unix)] +mod sys { + use mio::unix::UnixReady; + use mio::Ready; + + pub(crate) fn hup() -> Ready { + UnixReady::hup().into() + } + + pub(crate) fn is_hup(ready: Ready) -> bool { + UnixReady::from(ready).is_hup() + } + + pub(crate) fn error() -> Ready { + UnixReady::error().into() + } + + pub(crate) fn is_error(ready: Ready) -> bool { + UnixReady::from(ready).is_error() + } +} + +#[cfg(windows)] +mod sys { + use mio::Ready; + + pub(crate) fn hup() -> Ready { + Ready::empty() + } + + pub(crate) fn is_hup(_: Ready) -> bool { + false + } + + pub(crate) fn error() -> Ready { + Ready::empty() + } + + pub(crate) fn is_error(_: Ready) -> bool { + false + } +} diff --git a/third_party/rust/tokio/src/io/driver/ready.rs b/third_party/rust/tokio/src/io/driver/ready.rs new file mode 100644 index 0000000000..2430d3022f --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/ready.rs @@ -0,0 +1,250 @@ +#![cfg_attr(not(feature = "net"), allow(unreachable_pub))] + +use std::fmt; +use std::ops; + +const READABLE: usize = 0b0_01; +const WRITABLE: usize = 0b0_10; +const READ_CLOSED: usize = 0b0_0100; +const WRITE_CLOSED: usize = 0b0_1000; + +/// Describes the readiness state of an I/O resources. +/// +/// `Ready` tracks which operation an I/O resource is ready to perform. +#[cfg_attr(docsrs, doc(cfg(feature = "net")))] +#[derive(Clone, Copy, PartialEq, PartialOrd)] +pub struct Ready(usize); + +impl Ready { + /// Returns the empty `Ready` set. + pub const EMPTY: Ready = Ready(0); + + /// Returns a `Ready` representing readable readiness. + pub const READABLE: Ready = Ready(READABLE); + + /// Returns a `Ready` representing writable readiness. + pub const WRITABLE: Ready = Ready(WRITABLE); + + /// Returns a `Ready` representing read closed readiness. + pub const READ_CLOSED: Ready = Ready(READ_CLOSED); + + /// Returns a `Ready` representing write closed readiness. + pub const WRITE_CLOSED: Ready = Ready(WRITE_CLOSED); + + /// Returns a `Ready` representing readiness for all operations. + pub const ALL: Ready = Ready(READABLE | WRITABLE | READ_CLOSED | WRITE_CLOSED); + + // Must remain crate-private to avoid adding a public dependency on Mio. + pub(crate) fn from_mio(event: &mio::event::Event) -> Ready { + let mut ready = Ready::EMPTY; + + #[cfg(all(target_os = "freebsd", feature = "net"))] + { + if event.is_aio() { + ready |= Ready::READABLE; + } + + if event.is_lio() { + ready |= Ready::READABLE; + } + } + + if event.is_readable() { + ready |= Ready::READABLE; + } + + if event.is_writable() { + ready |= Ready::WRITABLE; + } + + if event.is_read_closed() { + ready |= Ready::READ_CLOSED; + } + + if event.is_write_closed() { + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + /// Returns true if `Ready` is the empty set. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Ready; + /// + /// assert!(Ready::EMPTY.is_empty()); + /// assert!(!Ready::READABLE.is_empty()); + /// ``` + pub fn is_empty(self) -> bool { + self == Ready::EMPTY + } + + /// Returns `true` if the value includes `readable`. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Ready; + /// + /// assert!(!Ready::EMPTY.is_readable()); + /// assert!(Ready::READABLE.is_readable()); + /// assert!(Ready::READ_CLOSED.is_readable()); + /// assert!(!Ready::WRITABLE.is_readable()); + /// ``` + pub fn is_readable(self) -> bool { + self.contains(Ready::READABLE) || self.is_read_closed() + } + + /// Returns `true` if the value includes writable `readiness`. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Ready; + /// + /// assert!(!Ready::EMPTY.is_writable()); + /// assert!(!Ready::READABLE.is_writable()); + /// assert!(Ready::WRITABLE.is_writable()); + /// assert!(Ready::WRITE_CLOSED.is_writable()); + /// ``` + pub fn is_writable(self) -> bool { + self.contains(Ready::WRITABLE) || self.is_write_closed() + } + + /// Returns `true` if the value includes read-closed `readiness`. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Ready; + /// + /// assert!(!Ready::EMPTY.is_read_closed()); + /// assert!(!Ready::READABLE.is_read_closed()); + /// assert!(Ready::READ_CLOSED.is_read_closed()); + /// ``` + pub fn is_read_closed(self) -> bool { + self.contains(Ready::READ_CLOSED) + } + + /// Returns `true` if the value includes write-closed `readiness`. + /// + /// # Examples + /// + /// ``` + /// use tokio::io::Ready; + /// + /// assert!(!Ready::EMPTY.is_write_closed()); + /// assert!(!Ready::WRITABLE.is_write_closed()); + /// assert!(Ready::WRITE_CLOSED.is_write_closed()); + /// ``` + pub fn is_write_closed(self) -> bool { + self.contains(Ready::WRITE_CLOSED) + } + + /// Returns true if `self` is a superset of `other`. + /// + /// `other` may represent more than one readiness operations, in which case + /// the function only returns true if `self` contains all readiness + /// specified in `other`. + pub(crate) fn contains<T: Into<Self>>(self, other: T) -> bool { + let other = other.into(); + (self & other) == other + } + + /// Creates a `Ready` instance using the given `usize` representation. + /// + /// The `usize` representation must have been obtained from a call to + /// `Readiness::as_usize`. + /// + /// This function is mainly provided to allow the caller to get a + /// readiness value from an `AtomicUsize`. + pub(crate) fn from_usize(val: usize) -> Ready { + Ready(val & Ready::ALL.as_usize()) + } + + /// Returns a `usize` representation of the `Ready` value. + /// + /// This function is mainly provided to allow the caller to store a + /// readiness value in an `AtomicUsize`. + pub(crate) fn as_usize(self) -> usize { + self.0 + } +} + +cfg_io_readiness! { + use crate::io::Interest; + + impl Ready { + pub(crate) fn from_interest(interest: Interest) -> Ready { + let mut ready = Ready::EMPTY; + + if interest.is_readable() { + ready |= Ready::READABLE; + ready |= Ready::READ_CLOSED; + } + + if interest.is_writable() { + ready |= Ready::WRITABLE; + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + pub(crate) fn intersection(self, interest: Interest) -> Ready { + Ready(self.0 & Ready::from_interest(interest).0) + } + + pub(crate) fn satisfies(self, interest: Interest) -> bool { + self.0 & Ready::from_interest(interest).0 != 0 + } + } +} + +impl ops::BitOr<Ready> for Ready { + type Output = Ready; + + #[inline] + fn bitor(self, other: Ready) -> Ready { + Ready(self.0 | other.0) + } +} + +impl ops::BitOrAssign<Ready> for Ready { + #[inline] + fn bitor_assign(&mut self, other: Ready) { + self.0 |= other.0; + } +} + +impl ops::BitAnd<Ready> for Ready { + type Output = Ready; + + #[inline] + fn bitand(self, other: Ready) -> Ready { + Ready(self.0 & other.0) + } +} + +impl ops::Sub<Ready> for Ready { + type Output = Ready; + + #[inline] + fn sub(self, other: Ready) -> Ready { + Ready(self.0 & !other.0) + } +} + +impl fmt::Debug for Ready { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Ready") + .field("is_readable", &self.is_readable()) + .field("is_writable", &self.is_writable()) + .field("is_read_closed", &self.is_read_closed()) + .field("is_write_closed", &self.is_write_closed()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/io/driver/registration.rs b/third_party/rust/tokio/src/io/driver/registration.rs new file mode 100644 index 0000000000..7350be6345 --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/registration.rs @@ -0,0 +1,262 @@ +#![cfg_attr(not(feature = "net"), allow(dead_code))] + +use crate::io::driver::{Direction, Handle, Interest, ReadyEvent, ScheduledIo}; +use crate::util::slab; + +use mio::event::Source; +use std::io; +use std::task::{Context, Poll}; + +cfg_io_driver! { + /// Associates an I/O resource with the reactor instance that drives it. + /// + /// A registration represents an I/O resource registered with a Reactor such + /// that it will receive task notifications on readiness. This is the lowest + /// level API for integrating with a reactor. + /// + /// The association between an I/O resource is made by calling + /// [`new_with_interest_and_handle`]. + /// Once the association is established, it remains established until the + /// registration instance is dropped. + /// + /// A registration instance represents two separate readiness streams. One + /// for the read readiness and one for write readiness. These streams are + /// independent and can be consumed from separate tasks. + /// + /// **Note**: while `Registration` is `Sync`, the caller must ensure that + /// there are at most two tasks that use a registration instance + /// concurrently. One task for [`poll_read_ready`] and one task for + /// [`poll_write_ready`]. While violating this requirement is "safe" from a + /// Rust memory safety point of view, it will result in unexpected behavior + /// in the form of lost notifications and tasks hanging. + /// + /// ## Platform-specific events + /// + /// `Registration` also allows receiving platform-specific `mio::Ready` + /// events. These events are included as part of the read readiness event + /// stream. The write readiness event stream is only for `Ready::writable()` + /// events. + /// + /// [`new_with_interest_and_handle`]: method@Self::new_with_interest_and_handle + /// [`poll_read_ready`]: method@Self::poll_read_ready` + /// [`poll_write_ready`]: method@Self::poll_write_ready` + #[derive(Debug)] + pub(crate) struct Registration { + /// Handle to the associated driver. + handle: Handle, + + /// Reference to state stored by the driver. + shared: slab::Ref<ScheduledIo>, + } +} + +unsafe impl Send for Registration {} +unsafe impl Sync for Registration {} + +// ===== impl Registration ===== + +impl Registration { + /// Registers the I/O resource with the default reactor, for a specific + /// `Interest`. `new_with_interest` should be used over `new` when you need + /// control over the readiness state, such as when a file descriptor only + /// allows reads. This does not add `hup` or `error` so if you are + /// interested in those states, you will need to add them to the readiness + /// state passed to this function. + /// + /// # Return + /// + /// - `Ok` if the registration happened successfully + /// - `Err` if an error was encountered during registration + pub(crate) fn new_with_interest_and_handle( + io: &mut impl Source, + interest: Interest, + handle: Handle, + ) -> io::Result<Registration> { + let shared = if let Some(inner) = handle.inner() { + inner.add_source(io, interest)? + } else { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to find event loop", + )); + }; + + Ok(Registration { handle, shared }) + } + + /// Deregisters the I/O resource from the reactor it is associated with. + /// + /// This function must be called before the I/O resource associated with the + /// registration is dropped. + /// + /// Note that deregistering does not guarantee that the I/O resource can be + /// registered with a different reactor. Some I/O resource types can only be + /// associated with a single reactor instance for their lifetime. + /// + /// # Return + /// + /// If the deregistration was successful, `Ok` is returned. Any calls to + /// `Reactor::turn` that happen after a successful call to `deregister` will + /// no longer result in notifications getting sent for this registration. + /// + /// `Err` is returned if an error is encountered. + pub(crate) fn deregister(&mut self, io: &mut impl Source) -> io::Result<()> { + let inner = match self.handle.inner() { + Some(inner) => inner, + None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")), + }; + inner.deregister_source(io) + } + + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + self.shared.clear_readiness(event); + } + + // Uses the poll path, requiring the caller to ensure mutual exclusion for + // correctness. Only the last task to call this function is notified. + pub(crate) fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.poll_ready(cx, Direction::Read) + } + + // Uses the poll path, requiring the caller to ensure mutual exclusion for + // correctness. Only the last task to call this function is notified. + pub(crate) fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.poll_ready(cx, Direction::Write) + } + + // Uses the poll path, requiring the caller to ensure mutual exclusion for + // correctness. Only the last task to call this function is notified. + pub(crate) fn poll_read_io<R>( + &self, + cx: &mut Context<'_>, + f: impl FnMut() -> io::Result<R>, + ) -> Poll<io::Result<R>> { + self.poll_io(cx, Direction::Read, f) + } + + // Uses the poll path, requiring the caller to ensure mutual exclusion for + // correctness. Only the last task to call this function is notified. + pub(crate) fn poll_write_io<R>( + &self, + cx: &mut Context<'_>, + f: impl FnMut() -> io::Result<R>, + ) -> Poll<io::Result<R>> { + self.poll_io(cx, Direction::Write, f) + } + + /// Polls for events on the I/O resource's `direction` readiness stream. + /// + /// If called with a task context, notify the task when a new event is + /// received. + fn poll_ready( + &self, + cx: &mut Context<'_>, + direction: Direction, + ) -> Poll<io::Result<ReadyEvent>> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + let ev = ready!(self.shared.poll_readiness(cx, direction)); + + if self.handle.inner().is_none() { + return Poll::Ready(Err(gone())); + } + + coop.made_progress(); + Poll::Ready(Ok(ev)) + } + + fn poll_io<R>( + &self, + cx: &mut Context<'_>, + direction: Direction, + mut f: impl FnMut() -> io::Result<R>, + ) -> Poll<io::Result<R>> { + loop { + let ev = ready!(self.poll_ready(cx, direction))?; + + match f() { + Ok(ret) => { + return Poll::Ready(Ok(ret)); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + pub(crate) fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + let ev = self.shared.ready_event(interest); + + // Don't attempt the operation if the resource is not ready. + if ev.ready.is_empty() { + return Err(io::ErrorKind::WouldBlock.into()); + } + + match f() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.clear_readiness(ev); + Err(io::ErrorKind::WouldBlock.into()) + } + res => res, + } + } +} + +impl Drop for Registration { + fn drop(&mut self) { + // It is possible for a cycle to be created between wakers stored in + // `ScheduledIo` instances and `Arc<driver::Inner>`. To break this + // cycle, wakers are cleared. This is an imperfect solution as it is + // possible to store a `Registration` in a waker. In this case, the + // cycle would remain. + // + // See tokio-rs/tokio#3481 for more details. + self.shared.clear_wakers(); + } +} + +fn gone() -> io::Error { + io::Error::new(io::ErrorKind::Other, "IO driver has terminated") +} + +cfg_io_readiness! { + impl Registration { + pub(crate) async fn readiness(&self, interest: Interest) -> io::Result<ReadyEvent> { + use std::future::Future; + use std::pin::Pin; + + let fut = self.shared.readiness(interest); + pin!(fut); + + crate::future::poll_fn(|cx| { + if self.handle.inner().is_none() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR + ))); + } + + Pin::new(&mut fut).poll(cx).map(Ok) + }).await + } + + pub(crate) async fn async_io<R>(&self, interest: Interest, mut f: impl FnMut() -> io::Result<R>) -> io::Result<R> { + loop { + let event = self.readiness(interest).await?; + + match f() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.clear_readiness(event); + } + x => return x, + } + } + } + } +} diff --git a/third_party/rust/tokio/src/io/driver/scheduled_io.rs b/third_party/rust/tokio/src/io/driver/scheduled_io.rs new file mode 100644 index 0000000000..76f93431ba --- /dev/null +++ b/third_party/rust/tokio/src/io/driver/scheduled_io.rs @@ -0,0 +1,533 @@ +use super::{Interest, Ready, ReadyEvent, Tick}; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::util::bit; +use crate::util::slab::Entry; +use crate::util::WakeList; + +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::{Context, Poll, Waker}; + +use super::Direction; + +cfg_io_readiness! { + use crate::util::linked_list::{self, LinkedList}; + + use std::cell::UnsafeCell; + use std::future::Future; + use std::marker::PhantomPinned; + use std::pin::Pin; + use std::ptr::NonNull; +} + +/// Stored in the I/O driver resource slab. +#[derive(Debug)] +pub(crate) struct ScheduledIo { + /// Packs the resource's readiness with the resource's generation. + readiness: AtomicUsize, + + waiters: Mutex<Waiters>, +} + +cfg_io_readiness! { + type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; +} + +#[derive(Debug, Default)] +struct Waiters { + #[cfg(feature = "net")] + /// List of all current waiters. + list: WaitList, + + /// Waker used for AsyncRead. + reader: Option<Waker>, + + /// Waker used for AsyncWrite. + writer: Option<Waker>, + + /// True if this ScheduledIo has been killed due to IO driver shutdown. + is_shutdown: bool, +} + +cfg_io_readiness! { + #[derive(Debug)] + struct Waiter { + pointers: linked_list::Pointers<Waiter>, + + /// The waker for this task. + waker: Option<Waker>, + + /// The interest this waiter is waiting on. + interest: Interest, + + is_ready: bool, + + /// Should never be `!Unpin`. + _p: PhantomPinned, + } + + /// Future returned by `readiness()`. + struct Readiness<'a> { + scheduled_io: &'a ScheduledIo, + + state: State, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, + } + + enum State { + Init, + Waiting, + Done, + } +} + +// The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness. +// +// | reserved | generation | driver tick | readiness | +// |----------+------------+--------------+-----------| +// | 1 bit | 7 bits + 8 bits + 16 bits | + +const READINESS: bit::Pack = bit::Pack::least_significant(16); + +const TICK: bit::Pack = READINESS.then(8); + +const GENERATION: bit::Pack = TICK.then(7); + +#[test] +fn test_generations_assert_same() { + assert_eq!(super::GENERATION, GENERATION); +} + +// ===== impl ScheduledIo ===== + +impl Entry for ScheduledIo { + fn reset(&self) { + let state = self.readiness.load(Acquire); + + let generation = GENERATION.unpack(state); + let next = GENERATION.pack_lossy(generation + 1, 0); + + self.readiness.store(next, Release); + } +} + +impl Default for ScheduledIo { + fn default() -> ScheduledIo { + ScheduledIo { + readiness: AtomicUsize::new(0), + waiters: Mutex::new(Default::default()), + } + } +} + +impl ScheduledIo { + pub(crate) fn generation(&self) -> usize { + GENERATION.unpack(self.readiness.load(Acquire)) + } + + /// Invoked when the IO driver is shut down; forces this ScheduledIo into a + /// permanently ready state. + pub(super) fn shutdown(&self) { + self.wake0(Ready::ALL, true) + } + + /// Sets the readiness on this `ScheduledIo` by invoking the given closure on + /// the current value, returning the previous readiness value. + /// + /// # Arguments + /// - `token`: the token for this `ScheduledIo`. + /// - `tick`: whether setting the tick or trying to clear readiness for a + /// specific tick. + /// - `f`: a closure returning a new readiness value given the previous + /// readiness. + /// + /// # Returns + /// + /// If the given token's generation no longer matches the `ScheduledIo`'s + /// generation, then the corresponding IO resource has been removed and + /// replaced with a new resource. In that case, this method returns `Err`. + /// Otherwise, this returns the previous readiness. + pub(super) fn set_readiness( + &self, + token: Option<usize>, + tick: Tick, + f: impl Fn(Ready) -> Ready, + ) -> Result<(), ()> { + let mut current = self.readiness.load(Acquire); + + loop { + let current_generation = GENERATION.unpack(current); + + if let Some(token) = token { + // Check that the generation for this access is still the + // current one. + if GENERATION.unpack(token) != current_generation { + return Err(()); + } + } + + // Mask out the tick/generation bits so that the modifying + // function doesn't see them. + let current_readiness = Ready::from_usize(current); + let new = f(current_readiness); + + let packed = match tick { + Tick::Set(t) => TICK.pack(t as usize, new.as_usize()), + Tick::Clear(t) => { + if TICK.unpack(current) as u8 != t { + // Trying to clear readiness with an old event! + return Err(()); + } + + TICK.pack(t as usize, new.as_usize()) + } + }; + + let next = GENERATION.pack(current_generation, packed); + + match self + .readiness + .compare_exchange(current, next, AcqRel, Acquire) + { + Ok(_) => return Ok(()), + // we lost the race, retry! + Err(actual) => current = actual, + } + } + } + + /// Notifies all pending waiters that have registered interest in `ready`. + /// + /// There may be many waiters to notify. Waking the pending task **must** be + /// done from outside of the lock otherwise there is a potential for a + /// deadlock. + /// + /// A stack array of wakers is created and filled with wakers to notify, the + /// lock is released, and the wakers are notified. Because there may be more + /// than 32 wakers to notify, if the stack array fills up, the lock is + /// released, the array is cleared, and the iteration continues. + pub(super) fn wake(&self, ready: Ready) { + self.wake0(ready, false); + } + + fn wake0(&self, ready: Ready, shutdown: bool) { + let mut wakers = WakeList::new(); + + let mut waiters = self.waiters.lock(); + + waiters.is_shutdown |= shutdown; + + // check for AsyncRead slot + if ready.is_readable() { + if let Some(waker) = waiters.reader.take() { + wakers.push(waker); + } + } + + // check for AsyncWrite slot + if ready.is_writable() { + if let Some(waker) = waiters.writer.take() { + wakers.push(waker); + } + } + + #[cfg(feature = "net")] + 'outer: loop { + let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest)); + + while wakers.can_push() { + match iter.next() { + Some(waiter) => { + let waiter = unsafe { &mut *waiter.as_ptr() }; + + if let Some(waker) = waiter.waker.take() { + waiter.is_ready = true; + wakers.push(waker); + } + } + None => { + break 'outer; + } + } + } + + drop(waiters); + + wakers.wake_all(); + + // Acquire the lock again. + waiters = self.waiters.lock(); + } + + // Release the lock before notifying + drop(waiters); + + wakers.wake_all(); + } + + pub(super) fn ready_event(&self, interest: Interest) -> ReadyEvent { + let curr = self.readiness.load(Acquire); + + ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready: interest.mask() & Ready::from_usize(READINESS.unpack(curr)), + } + } + + /// Polls for readiness events in a given direction. + /// + /// These are to support `AsyncRead` and `AsyncWrite` polling methods, + /// which cannot use the `async fn` version. This uses reserved reader + /// and writer slots. + pub(super) fn poll_readiness( + &self, + cx: &mut Context<'_>, + direction: Direction, + ) -> Poll<ReadyEvent> { + let curr = self.readiness.load(Acquire); + + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + + if ready.is_empty() { + // Update the task info + let mut waiters = self.waiters.lock(); + let slot = match direction { + Direction::Read => &mut waiters.reader, + Direction::Write => &mut waiters.writer, + }; + + // Avoid cloning the waker if one is already stored that matches the + // current task. + match slot { + Some(existing) => { + if !existing.will_wake(cx.waker()) { + *existing = cx.waker().clone(); + } + } + None => { + *slot = Some(cx.waker().clone()); + } + } + + // Try again, in case the readiness was changed while we were + // taking the waiters lock + let curr = self.readiness.load(Acquire); + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + if waiters.is_shutdown { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready: direction.mask(), + }) + } else if ready.is_empty() { + Poll::Pending + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } + + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + // This consumes the current readiness state **except** for closed + // states. Closed states are excluded because they are final states. + let mask_no_closed = event.ready - Ready::READ_CLOSED - Ready::WRITE_CLOSED; + + // result isn't important + let _ = self.set_readiness(None, Tick::Clear(event.tick), |curr| curr - mask_no_closed); + } + + pub(crate) fn clear_wakers(&self) { + let mut waiters = self.waiters.lock(); + waiters.reader.take(); + waiters.writer.take(); + } +} + +impl Drop for ScheduledIo { + fn drop(&mut self) { + self.wake(Ready::ALL); + } +} + +unsafe impl Send for ScheduledIo {} +unsafe impl Sync for ScheduledIo {} + +cfg_io_readiness! { + impl ScheduledIo { + /// An async version of `poll_readiness` which uses a linked list of wakers. + pub(crate) async fn readiness(&self, interest: Interest) -> ReadyEvent { + self.readiness_fut(interest).await + } + + // This is in a separate function so that the borrow checker doesn't think + // we are borrowing the `UnsafeCell` possibly over await boundaries. + // + // Go figure. + fn readiness_fut(&self, interest: Interest) -> Readiness<'_> { + Readiness { + scheduled_io: self, + state: State::Init, + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + is_ready: false, + interest, + _p: PhantomPinned, + }), + } + } + } + + unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } + } + + // ===== impl Readiness ===== + + impl Future for Readiness<'_> { + type Output = ReadyEvent; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::sync::atomic::Ordering::SeqCst; + + let (scheduled_io, state, waiter) = unsafe { + let me = self.get_unchecked_mut(); + (&me.scheduled_io, &mut me.state, &me.waiter) + }; + + loop { + match *state { + State::Init => { + // Optimistically check existing readiness + let curr = scheduled_io.readiness.load(SeqCst); + let ready = Ready::from_usize(READINESS.unpack(curr)); + + // Safety: `waiter.interest` never changes + let interest = unsafe { (*waiter.get()).interest }; + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { tick, ready }); + } + + // Wasn't ready, take the lock (and check again while locked). + let mut waiters = scheduled_io.waiters.lock(); + + let curr = scheduled_io.readiness.load(SeqCst); + let mut ready = Ready::from_usize(READINESS.unpack(curr)); + + if waiters.is_shutdown { + ready = Ready::ALL; + } + + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { tick, ready }); + } + + // Not ready even after locked, insert into list... + + // Safety: called while locked + unsafe { + (*waiter.get()).waker = Some(cx.waker().clone()); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters + .list + .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + *state = State::Waiting; + } + State::Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = scheduled_io.waiters.lock(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.is_ready { + // Our waker has been notified. + *state = State::Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + State::Done => { + let tick = TICK.unpack(scheduled_io.readiness.load(Acquire)) as u8; + + // Safety: State::Done means it is no longer shared + let w = unsafe { &mut *waiter.get() }; + + return Poll::Ready(ReadyEvent { + tick, + ready: Ready::from_interest(w.interest), + }); + } + } + } + } + } + + impl Drop for Readiness<'_> { + fn drop(&mut self) { + let mut waiters = self.scheduled_io.waiters.lock(); + + // Safety: `waiter` is only ever stored in `waiters` + unsafe { + waiters + .list + .remove(NonNull::new_unchecked(self.waiter.get())) + }; + } + } + + unsafe impl Send for Readiness<'_> {} + unsafe impl Sync for Readiness<'_> {} +} diff --git a/third_party/rust/tokio/src/io/mod.rs b/third_party/rust/tokio/src/io/mod.rs new file mode 100644 index 0000000000..cfdda61f69 --- /dev/null +++ b/third_party/rust/tokio/src/io/mod.rs @@ -0,0 +1,276 @@ +#![cfg_attr(loom, allow(dead_code, unreachable_pub))] + +//! Traits, helpers, and type definitions for asynchronous I/O functionality. +//! +//! This module is the asynchronous version of `std::io`. Primarily, it +//! defines two traits, [`AsyncRead`] and [`AsyncWrite`], which are asynchronous +//! versions of the [`Read`] and [`Write`] traits in the standard library. +//! +//! # AsyncRead and AsyncWrite +//! +//! Like the standard library's [`Read`] and [`Write`] traits, [`AsyncRead`] and +//! [`AsyncWrite`] provide the most general interface for reading and writing +//! input and output. Unlike the standard library's traits, however, they are +//! _asynchronous_ — meaning that reading from or writing to a `tokio::io` +//! type will _yield_ to the Tokio scheduler when IO is not ready, rather than +//! blocking. This allows other tasks to run while waiting on IO. +//! +//! Another difference is that `AsyncRead` and `AsyncWrite` only contain +//! core methods needed to provide asynchronous reading and writing +//! functionality. Instead, utility methods are defined in the [`AsyncReadExt`] +//! and [`AsyncWriteExt`] extension traits. These traits are automatically +//! implemented for all values that implement `AsyncRead` and `AsyncWrite` +//! respectively. +//! +//! End users will rarely interact directly with `AsyncRead` and +//! `AsyncWrite`. Instead, they will use the async functions defined in the +//! extension traits. Library authors are expected to implement `AsyncRead` +//! and `AsyncWrite` in order to provide types that behave like byte streams. +//! +//! Even with these differences, Tokio's `AsyncRead` and `AsyncWrite` traits +//! can be used in almost exactly the same manner as the standard library's +//! `Read` and `Write`. Most types in the standard library that implement `Read` +//! and `Write` have asynchronous equivalents in `tokio` that implement +//! `AsyncRead` and `AsyncWrite`, such as [`File`] and [`TcpStream`]. +//! +//! For example, the standard library documentation introduces `Read` by +//! [demonstrating][std_example] reading some bytes from a [`std::fs::File`]. We +//! can do the same with [`tokio::fs::File`][`File`]: +//! +//! ```no_run +//! use tokio::io::{self, AsyncReadExt}; +//! use tokio::fs::File; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let mut f = File::open("foo.txt").await?; +//! let mut buffer = [0; 10]; +//! +//! // read up to 10 bytes +//! let n = f.read(&mut buffer).await?; +//! +//! println!("The bytes: {:?}", &buffer[..n]); +//! Ok(()) +//! } +//! ``` +//! +//! [`File`]: crate::fs::File +//! [`TcpStream`]: crate::net::TcpStream +//! [`std::fs::File`]: std::fs::File +//! [std_example]: std::io#read-and-write +//! +//! ## Buffered Readers and Writers +//! +//! Byte-based interfaces are unwieldy and can be inefficient, as we'd need to be +//! making near-constant calls to the operating system. To help with this, +//! `std::io` comes with [support for _buffered_ readers and writers][stdbuf], +//! and therefore, `tokio::io` does as well. +//! +//! Tokio provides an async version of the [`std::io::BufRead`] trait, +//! [`AsyncBufRead`]; and async [`BufReader`] and [`BufWriter`] structs, which +//! wrap readers and writers. These wrappers use a buffer, reducing the number +//! of calls and providing nicer methods for accessing exactly what you want. +//! +//! For example, [`BufReader`] works with the [`AsyncBufRead`] trait to add +//! extra methods to any async reader: +//! +//! ```no_run +//! use tokio::io::{self, BufReader, AsyncBufReadExt}; +//! use tokio::fs::File; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let f = File::open("foo.txt").await?; +//! let mut reader = BufReader::new(f); +//! let mut buffer = String::new(); +//! +//! // read a line into buffer +//! reader.read_line(&mut buffer).await?; +//! +//! println!("{}", buffer); +//! Ok(()) +//! } +//! ``` +//! +//! [`BufWriter`] doesn't add any new ways of writing; it just buffers every call +//! to [`write`](crate::io::AsyncWriteExt::write). However, you **must** flush +//! [`BufWriter`] to ensure that any buffered data is written. +//! +//! ```no_run +//! use tokio::io::{self, BufWriter, AsyncWriteExt}; +//! use tokio::fs::File; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let f = File::create("foo.txt").await?; +//! { +//! let mut writer = BufWriter::new(f); +//! +//! // Write a byte to the buffer. +//! writer.write(&[42u8]).await?; +//! +//! // Flush the buffer before it goes out of scope. +//! writer.flush().await?; +//! +//! } // Unless flushed or shut down, the contents of the buffer is discarded on drop. +//! +//! Ok(()) +//! } +//! ``` +//! +//! [stdbuf]: std::io#bufreader-and-bufwriter +//! [`std::io::BufRead`]: std::io::BufRead +//! [`AsyncBufRead`]: crate::io::AsyncBufRead +//! [`BufReader`]: crate::io::BufReader +//! [`BufWriter`]: crate::io::BufWriter +//! +//! ## Implementing AsyncRead and AsyncWrite +//! +//! Because they are traits, we can implement [`AsyncRead`] and [`AsyncWrite`] for +//! our own types, as well. Note that these traits must only be implemented for +//! non-blocking I/O types that integrate with the futures type system. In +//! other words, these types must never block the thread, and instead the +//! current task is notified when the I/O resource is ready. +//! +//! ## Conversion to and from Sink/Stream +//! +//! It is often convenient to encapsulate the reading and writing of +//! bytes and instead work with a [`Sink`] or [`Stream`] of some data +//! type that is encoded as bytes and/or decoded from bytes. Tokio +//! provides some utility traits in the [tokio-util] crate that +//! abstract the asynchronous buffering that is required and allows +//! you to write [`Encoder`] and [`Decoder`] functions working with a +//! buffer of bytes, and then use that ["codec"] to transform anything +//! that implements [`AsyncRead`] and [`AsyncWrite`] into a `Sink`/`Stream` of +//! your structured data. +//! +//! [tokio-util]: https://docs.rs/tokio-util/0.6/tokio_util/codec/index.html +//! +//! # Standard input and output +//! +//! Tokio provides asynchronous APIs to standard [input], [output], and [error]. +//! These APIs are very similar to the ones provided by `std`, but they also +//! implement [`AsyncRead`] and [`AsyncWrite`]. +//! +//! Note that the standard input / output APIs **must** be used from the +//! context of the Tokio runtime, as they require Tokio-specific features to +//! function. Calling these functions outside of a Tokio runtime will panic. +//! +//! [input]: fn@stdin +//! [output]: fn@stdout +//! [error]: fn@stderr +//! +//! # `std` re-exports +//! +//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are +//! re-exported from `std::io` for ease of use. +//! +//! [`AsyncRead`]: trait@AsyncRead +//! [`AsyncWrite`]: trait@AsyncWrite +//! [`AsyncReadExt`]: trait@AsyncReadExt +//! [`AsyncWriteExt`]: trait@AsyncWriteExt +//! ["codec"]: https://docs.rs/tokio-util/0.6/tokio_util/codec/index.html +//! [`Encoder`]: https://docs.rs/tokio-util/0.6/tokio_util/codec/trait.Encoder.html +//! [`Decoder`]: https://docs.rs/tokio-util/0.6/tokio_util/codec/trait.Decoder.html +//! [`Error`]: struct@Error +//! [`ErrorKind`]: enum@ErrorKind +//! [`Result`]: type@Result +//! [`Read`]: std::io::Read +//! [`SeekFrom`]: enum@SeekFrom +//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html +//! [`Stream`]: https://docs.rs/futures/0.3/futures/stream/trait.Stream.html +//! [`Write`]: std::io::Write +cfg_io_blocking! { + pub(crate) mod blocking; +} + +mod async_buf_read; +pub use self::async_buf_read::AsyncBufRead; + +mod async_read; +pub use self::async_read::AsyncRead; + +mod async_seek; +pub use self::async_seek::AsyncSeek; + +mod async_write; +pub use self::async_write::AsyncWrite; + +mod read_buf; +pub use self::read_buf::ReadBuf; + +// Re-export some types from `std::io` so that users don't have to deal +// with conflicts when `use`ing `tokio::io` and `std::io`. +#[doc(no_inline)] +pub use std::io::{Error, ErrorKind, Result, SeekFrom}; + +cfg_io_driver_impl! { + pub(crate) mod driver; + + cfg_net! { + pub use driver::{Interest, Ready}; + } + + mod poll_evented; + + #[cfg(not(loom))] + pub(crate) use poll_evented::PollEvented; +} + +cfg_aio! { + /// BSD-specific I/O types. + pub mod bsd { + mod poll_aio; + + pub use poll_aio::{Aio, AioEvent, AioSource}; + } +} + +cfg_net_unix! { + mod async_fd; + + pub mod unix { + //! Asynchronous IO structures specific to Unix-like operating systems. + pub use super::async_fd::{AsyncFd, AsyncFdReadyGuard, AsyncFdReadyMutGuard, TryIoError}; + } +} + +cfg_io_std! { + mod stdio_common; + + mod stderr; + pub use stderr::{stderr, Stderr}; + + mod stdin; + pub use stdin::{stdin, Stdin}; + + mod stdout; + pub use stdout::{stdout, Stdout}; +} + +cfg_io_util! { + mod split; + pub use split::{split, ReadHalf, WriteHalf}; + + pub(crate) mod seek; + pub(crate) mod util; + pub use util::{ + copy, copy_bidirectional, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take, + }; +} + +cfg_not_io_util! { + cfg_process! { + pub(crate) mod util; + } +} + +cfg_io_blocking! { + /// Types in this module can be mocked out in tests. + mod sys { + // TODO: don't rename + pub(crate) use crate::blocking::spawn_blocking as run; + pub(crate) use crate::blocking::JoinHandle as Blocking; + } +} diff --git a/third_party/rust/tokio/src/io/poll_evented.rs b/third_party/rust/tokio/src/io/poll_evented.rs new file mode 100644 index 0000000000..ce4c1426ac --- /dev/null +++ b/third_party/rust/tokio/src/io/poll_evented.rs @@ -0,0 +1,214 @@ +use crate::io::driver::{Handle, Interest, Registration}; + +use mio::event::Source; +use std::fmt; +use std::io; +use std::ops::Deref; +use std::panic::{RefUnwindSafe, UnwindSafe}; + +cfg_io_driver! { + /// Associates an I/O resource that implements the [`std::io::Read`] and/or + /// [`std::io::Write`] traits with the reactor that drives it. + /// + /// `PollEvented` uses [`Registration`] internally to take a type that + /// implements [`mio::event::Source`] as well as [`std::io::Read`] and or + /// [`std::io::Write`] and associate it with a reactor that will drive it. + /// + /// Once the [`mio::event::Source`] type is wrapped by `PollEvented`, it can be + /// used from within the future's execution model. As such, the + /// `PollEvented` type provides [`AsyncRead`] and [`AsyncWrite`] + /// implementations using the underlying I/O resource as well as readiness + /// events provided by the reactor. + /// + /// **Note**: While `PollEvented` is `Sync` (if the underlying I/O type is + /// `Sync`), the caller must ensure that there are at most two tasks that + /// use a `PollEvented` instance concurrently. One for reading and one for + /// writing. While violating this requirement is "safe" from a Rust memory + /// model point of view, it will result in unexpected behavior in the form + /// of lost notifications and tasks hanging. + /// + /// ## Readiness events + /// + /// Besides just providing [`AsyncRead`] and [`AsyncWrite`] implementations, + /// this type also supports access to the underlying readiness event stream. + /// While similar in function to what [`Registration`] provides, the + /// semantics are a bit different. + /// + /// Two functions are provided to access the readiness events: + /// [`poll_read_ready`] and [`poll_write_ready`]. These functions return the + /// current readiness state of the `PollEvented` instance. If + /// [`poll_read_ready`] indicates read readiness, immediately calling + /// [`poll_read_ready`] again will also indicate read readiness. + /// + /// When the operation is attempted and is unable to succeed due to the I/O + /// resource not being ready, the caller must call `clear_readiness`. + /// This clears the readiness state until a new readiness event is received. + /// + /// This allows the caller to implement additional functions. For example, + /// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and + /// `clear_read_ready`. + /// + /// ## Platform-specific events + /// + /// `PollEvented` also allows receiving platform-specific `mio::Ready` events. + /// These events are included as part of the read readiness event stream. The + /// write readiness event stream is only for `Ready::writable()` events. + /// + /// [`AsyncRead`]: crate::io::AsyncRead + /// [`AsyncWrite`]: crate::io::AsyncWrite + /// [`TcpListener`]: crate::net::TcpListener + /// [`poll_read_ready`]: Registration::poll_read_ready + /// [`poll_write_ready`]: Registration::poll_write_ready + pub(crate) struct PollEvented<E: Source> { + io: Option<E>, + registration: Registration, + } +} + +// ===== impl PollEvented ===== + +impl<E: Source> PollEvented<E> { + /// Creates a new `PollEvented` associated with the default reactor. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new(io: E) -> io::Result<Self> { + PollEvented::new_with_interest(io, Interest::READABLE | Interest::WRITABLE) + } + + /// Creates a new `PollEvented` associated with the default reactor, for + /// specific `Interest` state. `new_with_interest` should be used over `new` + /// when you need control over the readiness state, such as when a file + /// descriptor only allows reads. This does not add `hup` or `error` so if + /// you are interested in those states, you will need to add them to the + /// readiness state passed to this function. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called from + /// a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) + /// function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new_with_interest(io: E, interest: Interest) -> io::Result<Self> { + Self::new_with_interest_and_handle(io, interest, Handle::current()) + } + + pub(crate) fn new_with_interest_and_handle( + mut io: E, + interest: Interest, + handle: Handle, + ) -> io::Result<Self> { + let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?; + Ok(Self { + io: Some(io), + registration, + }) + } + + /// Returns a reference to the registration. + #[cfg(any( + feature = "net", + all(unix, feature = "process"), + all(unix, feature = "signal"), + ))] + pub(crate) fn registration(&self) -> &Registration { + &self.registration + } + + /// Deregisters the inner io from the registration and returns a Result containing the inner io. + #[cfg(any(feature = "net", feature = "process"))] + pub(crate) fn into_inner(mut self) -> io::Result<E> { + let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here. + self.registration.deregister(&mut inner)?; + Ok(inner) + } +} + +feature! { + #![any(feature = "net", feature = "process")] + + use crate::io::ReadBuf; + use std::task::{Context, Poll}; + + impl<E: Source> PollEvented<E> { + // Safety: The caller must ensure that `E` can read into uninitialized memory + pub(crate) unsafe fn poll_read<'a>( + &'a self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> + where + &'a E: io::Read + 'a, + { + use std::io::Read; + + let n = ready!(self.registration.poll_read_io(cx, || { + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); + self.io.as_ref().unwrap().read(b) + }))?; + + // Safety: We trust `TcpStream::read` to have filled up `n` bytes in the + // buffer. + buf.assume_init(n); + buf.advance(n); + Poll::Ready(Ok(())) + } + + pub(crate) fn poll_write<'a>(&'a self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> + where + &'a E: io::Write + 'a, + { + use std::io::Write; + self.registration.poll_write_io(cx, || self.io.as_ref().unwrap().write(buf)) + } + + #[cfg(feature = "net")] + pub(crate) fn poll_write_vectored<'a>( + &'a self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> + where + &'a E: io::Write + 'a, + { + use std::io::Write; + self.registration.poll_write_io(cx, || self.io.as_ref().unwrap().write_vectored(bufs)) + } + } +} + +impl<E: Source> UnwindSafe for PollEvented<E> {} + +impl<E: Source> RefUnwindSafe for PollEvented<E> {} + +impl<E: Source> Deref for PollEvented<E> { + type Target = E; + + fn deref(&self) -> &E { + self.io.as_ref().unwrap() + } +} + +impl<E: Source + fmt::Debug> fmt::Debug for PollEvented<E> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollEvented").field("io", &self.io).finish() + } +} + +impl<E: Source> Drop for PollEvented<E> { + fn drop(&mut self) { + if let Some(mut io) = self.io.take() { + // Ignore errors + let _ = self.registration.deregister(&mut io); + } + } +} diff --git a/third_party/rust/tokio/src/io/read_buf.rs b/third_party/rust/tokio/src/io/read_buf.rs new file mode 100644 index 0000000000..8c34ae6c81 --- /dev/null +++ b/third_party/rust/tokio/src/io/read_buf.rs @@ -0,0 +1,291 @@ +use std::fmt; +use std::mem::MaybeUninit; + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that may be uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +/// +/// It is undefined behavior to de-initialize any bytes from the uninitialized +/// region, since it is merely unknown whether this region is uninitialized or +/// not, and if part of it turns out to be initialized, it must stay initialized. +pub struct ReadBuf<'a> { + buf: &'a mut [MaybeUninit<u8>], + filled: usize, + initialized: usize, +} + +impl<'a> ReadBuf<'a> { + /// Creates a new `ReadBuf` from a fully initialized buffer. + #[inline] + pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> { + let initialized = buf.len(); + let buf = unsafe { slice_to_uninit_mut(buf) }; + ReadBuf { + buf, + filled: 0, + initialized, + } + } + + /// Creates a new `ReadBuf` from a fully uninitialized buffer. + /// + /// Use `assume_init` if part of the buffer is known to be already initialized. + #[inline] + pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> { + ReadBuf { + buf, + filled: 0, + initialized: 0, + } + } + + /// Returns the total capacity of the buffer. + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + /// Returns a shared reference to the filled portion of the buffer. + #[inline] + pub fn filled(&self) -> &[u8] { + let slice = &self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + unsafe { slice_assume_init(slice) } + } + + /// Returns a mutable reference to the filled portion of the buffer. + #[inline] + pub fn filled_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + unsafe { slice_assume_init_mut(slice) } + } + + /// Returns a new `ReadBuf` comprised of the unfilled section up to `n`. + #[inline] + pub fn take(&mut self, n: usize) -> ReadBuf<'_> { + let max = std::cmp::min(self.remaining(), n); + // Safety: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } + } + + /// Returns a shared reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized(&self) -> &[u8] { + let slice = &self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + unsafe { slice_assume_init(slice) } + } + + /// Returns a mutable reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + unsafe { slice_assume_init_mut(slice) } + } + + /// Returns a mutable reference to the entire buffer, without ensuring that it has been fully + /// initialized. + /// + /// The elements between 0 and `self.filled().len()` are filled, and those between 0 and + /// `self.initialized().len()` are initialized (and so can be converted to a `&mut [u8]`). + /// + /// The caller of this method must ensure that these invariants are upheld. For example, if the + /// caller initializes some of the uninitialized section of the buffer, it must call + /// [`assume_init`](Self::assume_init) with the number of bytes initialized. + /// + /// # Safety + /// + /// The caller must not de-initialize portions of the buffer that have already been initialized. + /// This includes any bytes in the region marked as uninitialized by `ReadBuf`. + #[inline] + pub unsafe fn inner_mut(&mut self) -> &mut [MaybeUninit<u8>] { + self.buf + } + + /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully + /// initialized. + /// + /// # Safety + /// + /// The caller must not de-initialize portions of the buffer that have already been initialized. + /// This includes any bytes in the region marked as uninitialized by `ReadBuf`. + #[inline] + pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit<u8>] { + &mut self.buf[self.filled..] + } + + /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized. + /// + /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after + /// the first use. + #[inline] + pub fn initialize_unfilled(&mut self) -> &mut [u8] { + self.initialize_unfilled_to(self.remaining()) + } + + /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is + /// fully initialized. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `n`. + #[inline] + pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] { + assert!(self.remaining() >= n, "n overflows remaining"); + + // This can't overflow, otherwise the assert above would have failed. + let end = self.filled + n; + + if self.initialized < end { + unsafe { + self.buf[self.initialized..end] + .as_mut_ptr() + .write_bytes(0, end - self.initialized); + } + self.initialized = end; + } + + let slice = &mut self.buf[self.filled..end]; + // safety: just above, we checked that the end of the buf has + // been initialized to some value. + unsafe { slice_assume_init_mut(slice) } + } + + /// Returns the number of bytes at the end of the slice that have not yet been filled. + #[inline] + pub fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + /// Clears the buffer, resetting the filled region to empty. + /// + /// The number of initialized bytes is not changed, and the contents of the buffer are not modified. + #[inline] + pub fn clear(&mut self) { + self.filled = 0; + } + + /// Advances the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the initialized region. + #[inline] + pub fn advance(&mut self, n: usize) { + let new = self.filled.checked_add(n).expect("filled overflow"); + self.set_filled(new); + } + + /// Sets the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for + /// example, by a `AsyncRead` implementation that compresses data in-place). + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the initialized region. + #[inline] + pub fn set_filled(&mut self, n: usize) { + assert!( + n <= self.initialized, + "filled must not become larger than initialized" + ); + self.filled = n; + } + + /// Asserts that the first `n` unfilled bytes of the buffer are initialized. + /// + /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer + /// bytes than are already known to be initialized. + /// + /// # Safety + /// + /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized. + #[inline] + pub unsafe fn assume_init(&mut self, n: usize) { + let new = self.filled + n; + if new > self.initialized { + self.initialized = new; + } + } + + /// Appends data to the buffer, advancing the written position and possibly also the initialized position. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `buf.len()`. + #[inline] + pub fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf[self.filled..end] + .as_mut_ptr() + .cast::<u8>() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.initialized < end { + self.initialized = end; + } + self.filled = end; + } +} + +impl fmt::Debug for ReadBuf<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("initialized", &self.initialized) + .field("capacity", &self.capacity()) + .finish() + } +} + +unsafe fn slice_to_uninit_mut(slice: &mut [u8]) -> &mut [MaybeUninit<u8>] { + &mut *(slice as *mut [u8] as *mut [MaybeUninit<u8>]) +} + +// TODO: This could use `MaybeUninit::slice_assume_init` when it is stable. +unsafe fn slice_assume_init(slice: &[MaybeUninit<u8>]) -> &[u8] { + &*(slice as *const [MaybeUninit<u8>] as *const [u8]) +} + +// TODO: This could use `MaybeUninit::slice_assume_init_mut` when it is stable. +unsafe fn slice_assume_init_mut(slice: &mut [MaybeUninit<u8>]) -> &mut [u8] { + &mut *(slice as *mut [MaybeUninit<u8>] as *mut [u8]) +} diff --git a/third_party/rust/tokio/src/io/seek.rs b/third_party/rust/tokio/src/io/seek.rs new file mode 100644 index 0000000000..e64205d9cf --- /dev/null +++ b/third_party/rust/tokio/src/io/seek.rs @@ -0,0 +1,57 @@ +use crate::io::AsyncSeek; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io::{self, SeekFrom}; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Seek<'a, S: ?Sized> { + seek: &'a mut S, + pos: Option<SeekFrom>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn seek<S>(seek: &mut S, pos: SeekFrom) -> Seek<'_, S> +where + S: AsyncSeek + ?Sized + Unpin, +{ + Seek { + seek, + pos: Some(pos), + _pin: PhantomPinned, + } +} + +impl<S> Future for Seek<'_, S> +where + S: AsyncSeek + ?Sized + Unpin, +{ + type Output = io::Result<u64>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + match me.pos { + Some(pos) => { + // ensure no seek in progress + ready!(Pin::new(&mut *me.seek).poll_complete(cx))?; + match Pin::new(&mut *me.seek).start_seek(*pos) { + Ok(()) => { + *me.pos = None; + Pin::new(&mut *me.seek).poll_complete(cx) + } + Err(e) => Poll::Ready(Err(e)), + } + } + None => Pin::new(&mut *me.seek).poll_complete(cx), + } + } +} diff --git a/third_party/rust/tokio/src/io/split.rs b/third_party/rust/tokio/src/io/split.rs new file mode 100644 index 0000000000..8258a0f7a0 --- /dev/null +++ b/third_party/rust/tokio/src/io/split.rs @@ -0,0 +1,180 @@ +//! Split a single value implementing `AsyncRead + AsyncWrite` into separate +//! `AsyncRead` and `AsyncWrite` handles. +//! +//! To restore this read/write object from its `split::ReadHalf` and +//! `split::WriteHalf` use `unsplit`. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::cell::UnsafeCell; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// The readable half of a value returned from [`split`](split()). + pub struct ReadHalf<T> { + inner: Arc<Inner<T>>, + } + + /// The writable half of a value returned from [`split`](split()). + pub struct WriteHalf<T> { + inner: Arc<Inner<T>>, + } + + /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate + /// `AsyncRead` and `AsyncWrite` handles. + /// + /// To restore this read/write object from its `ReadHalf` and + /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). + pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) + where + T: AsyncRead + AsyncWrite, + { + let inner = Arc::new(Inner { + locked: AtomicBool::new(false), + stream: UnsafeCell::new(stream), + }); + + let rd = ReadHalf { + inner: inner.clone(), + }; + + let wr = WriteHalf { inner }; + + (rd, wr) + } +} + +struct Inner<T> { + locked: AtomicBool, + stream: UnsafeCell<T>, +} + +struct Guard<'a, T> { + inner: &'a Inner<T>, +} + +impl<T> ReadHalf<T> { + /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same + /// stream. + pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { + other.is_pair_of(self) + } + + /// Reunites with a previously split `WriteHalf`. + /// + /// # Panics + /// + /// If this `ReadHalf` and the given `WriteHalf` do not originate from the + /// same `split` operation this method will panic. + /// This can be checked ahead of time by comparing the stream ID + /// of the two halves. + pub fn unsplit(self, wr: WriteHalf<T>) -> T { + if self.is_pair_of(&wr) { + drop(wr); + + let inner = Arc::try_unwrap(self.inner) + .ok() + .expect("`Arc::try_unwrap` failed"); + + inner.stream.into_inner() + } else { + panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") + } + } +} + +impl<T> WriteHalf<T> { + /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same + /// stream. + pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<T: AsyncRead> AsyncRead for ReadHalf<T> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let mut inner = ready!(self.inner.poll_lock(cx)); + inner.stream_pin().poll_read(cx, buf) + } +} + +impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + let mut inner = ready!(self.inner.poll_lock(cx)); + inner.stream_pin().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + let mut inner = ready!(self.inner.poll_lock(cx)); + inner.stream_pin().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + let mut inner = ready!(self.inner.poll_lock(cx)); + inner.stream_pin().poll_shutdown(cx) + } +} + +impl<T> Inner<T> { + fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> { + if self + .locked + .compare_exchange(false, true, Acquire, Acquire) + .is_ok() + { + Poll::Ready(Guard { inner: self }) + } else { + // Spin... but investigate a better strategy + + std::thread::yield_now(); + cx.waker().wake_by_ref(); + + Poll::Pending + } + } +} + +impl<T> Guard<'_, T> { + fn stream_pin(&mut self) -> Pin<&mut T> { + // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual + // exclusion. + unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } + } +} + +impl<T> Drop for Guard<'_, T> { + fn drop(&mut self) { + self.inner.locked.store(false, Release); + } +} + +unsafe impl<T: Send> Send for ReadHalf<T> {} +unsafe impl<T: Send> Send for WriteHalf<T> {} +unsafe impl<T: Sync> Sync for ReadHalf<T> {} +unsafe impl<T: Sync> Sync for WriteHalf<T> {} + +impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("split::ReadHalf").finish() + } +} + +impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("split::WriteHalf").finish() + } +} diff --git a/third_party/rust/tokio/src/io/stderr.rs b/third_party/rust/tokio/src/io/stderr.rs new file mode 100644 index 0000000000..2f624fba9d --- /dev/null +++ b/third_party/rust/tokio/src/io/stderr.rs @@ -0,0 +1,109 @@ +use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; +use crate::io::AsyncWrite; + +use std::io; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +cfg_io_std! { + /// A handle to the standard error stream of a process. + /// + /// Concurrent writes to stderr must be executed with care: Only individual + /// writes to this [`AsyncWrite`] are guaranteed to be intact. In particular + /// you should be aware that writes using [`write_all`] are not guaranteed + /// to occur as a single write, so multiple threads writing data with + /// [`write_all`] may result in interleaved output. + /// + /// Created by the [`stderr`] function. + /// + /// [`stderr`]: stderr() + /// [`AsyncWrite`]: AsyncWrite + /// [`write_all`]: crate::io::AsyncWriteExt::write_all() + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stderr = io::stdout(); + /// stderr.write_all(b"Print some error here.").await?; + /// Ok(()) + /// } + /// ``` + #[derive(Debug)] + pub struct Stderr { + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stderr>>, + } + + /// Constructs a new handle to the standard error of the current process. + /// + /// The returned handle allows writing to standard error from the within the + /// Tokio runtime. + /// + /// Concurrent writes to stderr must be executed with care: Only individual + /// writes to this [`AsyncWrite`] are guaranteed to be intact. In particular + /// you should be aware that writes using [`write_all`] are not guaranteed + /// to occur as a single write, so multiple threads writing data with + /// [`write_all`] may result in interleaved output. + /// + /// [`AsyncWrite`]: AsyncWrite + /// [`write_all`]: crate::io::AsyncWriteExt::write_all() + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stderr = io::stderr(); + /// stderr.write_all(b"Print some error here.").await?; + /// Ok(()) + /// } + /// ``` + pub fn stderr() -> Stderr { + let std = io::stderr(); + Stderr { + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), + } + } +} + +#[cfg(unix)] +impl std::os::unix::io::AsRawFd for Stderr { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + std::io::stderr().as_raw_fd() + } +} + +#[cfg(windows)] +impl std::os::windows::io::AsRawHandle for Stderr { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + std::io::stderr().as_raw_handle() + } +} + +impl AsyncWrite for Stderr { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.std).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.std).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.std).poll_shutdown(cx) + } +} diff --git a/third_party/rust/tokio/src/io/stdin.rs b/third_party/rust/tokio/src/io/stdin.rs new file mode 100644 index 0000000000..c9578f17b6 --- /dev/null +++ b/third_party/rust/tokio/src/io/stdin.rs @@ -0,0 +1,73 @@ +use crate::io::blocking::Blocking; +use crate::io::{AsyncRead, ReadBuf}; + +use std::io; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +cfg_io_std! { + /// A handle to the standard input stream of a process. + /// + /// The handle implements the [`AsyncRead`] trait, but beware that concurrent + /// reads of `Stdin` must be executed with care. + /// + /// This handle is best used for non-interactive uses, such as when a file + /// is piped into the application. For technical reasons, `stdin` is + /// implemented by using an ordinary blocking read on a separate thread, and + /// it is impossible to cancel that read. This can make shutdown of the + /// runtime hang until the user presses enter. + /// + /// For interactive uses, it is recommended to spawn a thread dedicated to + /// user input and use blocking IO directly in that thread. + /// + /// Created by the [`stdin`] function. + /// + /// [`stdin`]: fn@stdin + /// [`AsyncRead`]: trait@AsyncRead + #[derive(Debug)] + pub struct Stdin { + std: Blocking<std::io::Stdin>, + } + + /// Constructs a new handle to the standard input of the current process. + /// + /// This handle is best used for non-interactive uses, such as when a file + /// is piped into the application. For technical reasons, `stdin` is + /// implemented by using an ordinary blocking read on a separate thread, and + /// it is impossible to cancel that read. This can make shutdown of the + /// runtime hang until the user presses enter. + /// + /// For interactive uses, it is recommended to spawn a thread dedicated to + /// user input and use blocking IO directly in that thread. + pub fn stdin() -> Stdin { + let std = io::stdin(); + Stdin { + std: Blocking::new(std), + } + } +} + +#[cfg(unix)] +impl std::os::unix::io::AsRawFd for Stdin { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + std::io::stdin().as_raw_fd() + } +} + +#[cfg(windows)] +impl std::os::windows::io::AsRawHandle for Stdin { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + std::io::stdin().as_raw_handle() + } +} + +impl AsyncRead for Stdin { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.std).poll_read(cx, buf) + } +} diff --git a/third_party/rust/tokio/src/io/stdio_common.rs b/third_party/rust/tokio/src/io/stdio_common.rs new file mode 100644 index 0000000000..7e4a198a82 --- /dev/null +++ b/third_party/rust/tokio/src/io/stdio_common.rs @@ -0,0 +1,220 @@ +//! Contains utilities for stdout and stderr. +use crate::io::AsyncWrite; +use std::pin::Pin; +use std::task::{Context, Poll}; +/// # Windows +/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest, +/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF. +/// That's why, wrapped writer will always receive well-formed utf-8 bytes. +/// # Other platforms +/// Passes data to `inner` as is. +#[derive(Debug)] +pub(crate) struct SplitByUtf8BoundaryIfWindows<W> { + inner: W, +} + +impl<W> SplitByUtf8BoundaryIfWindows<W> { + pub(crate) fn new(inner: W) -> Self { + Self { inner } + } +} + +// this constant is defined by Unicode standard. +const MAX_BYTES_PER_CHAR: usize = 4; + +// Subject for tweaking here +const MAGIC_CONST: usize = 8; + +impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + // just a closure to avoid repetitive code + let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); + + // 1. Only windows stdio can suffer from non-utf8. + // We also check for `test` so that we can write some tests + // for further code. Since `AsyncWrite` can always shrink + // buffer at its discretion, excessive (i.e. in tests) shrinking + // does not break correctness. + // 2. If buffer is small, it will not be shrinked. + // That's why, it's "textness" will not change, so we don't have + // to fixup it. + if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF + { + return call_inner(buf); + } + + buf = &buf[..crate::io::blocking::MAX_BUF]; + + // Now there are two possibilities. + // If caller gave is binary buffer, we **should not** shrink it + // anymore, because excessive shrinking hits performance. + // If caller gave as binary buffer, we **must** additionally + // shrink it to strip incomplete char at the end of buffer. + // that's why check we will perform now is allowed to have + // false-positive. + + // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. + // if they are (possibly incomplete) utf8, then we can be quite sure + // that input buffer was utf8. + + let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { + Ok(_) => true, + Err(err) => { + let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); + incomplete_bytes < MAX_BYTES_PER_CHAR + } + }; + + if have_to_fix_up { + // We must pop several bytes at the end which form incomplete + // character. To achieve it, we exploit UTF8 encoding: + // for any code point, all bytes except first start with 0b10 prefix. + // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details + let trailing_incomplete_char_size = buf + .iter() + .rev() + .take(MAX_BYTES_PER_CHAR) + .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) + .unwrap_or(0) + + 1; + buf = &buf[..buf.len() - trailing_incomplete_char_size]; + } + + call_inner(buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +#[cfg(not(loom))] +mod tests { + use crate::io::AsyncWriteExt; + use std::io; + use std::pin::Pin; + use std::task::Context; + use std::task::Poll; + + const MAX_BUF: usize = 16 * 1024; + + struct TextMockWriter; + + impl crate::io::AsyncWrite for TextMockWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + assert!(std::str::from_utf8(buf).is_ok()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + struct LoggingMockWriter { + write_history: Vec<usize>, + } + + impl LoggingMockWriter { + fn new() -> Self { + LoggingMockWriter { + write_history: Vec::new(), + } + } + } + + impl crate::io::AsyncWrite for LoggingMockWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + self.write_history.push(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + #[test] + fn test_splitter() { + let data = str::repeat("█", MAX_BUF); + let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter); + let fut = async move { + wr.write_all(data.as_bytes()).await.unwrap(); + }; + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(fut); + } + + #[test] + fn test_pseudo_text() { + // In this test we write a piece of binary data, whose beginning is + // text though. We then validate that even in this corner case buffer + // was not shrinked too much. + let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR; + let mut data: Vec<u8> = str::repeat("a", checked_count).into(); + data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1)); + let mut writer = LoggingMockWriter::new(); + let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer); + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(async { + splitter.write_all(&data).await.unwrap(); + }); + // Check that at most two writes were performed + assert!(writer.write_history.len() <= 2); + // Check that all has been written + assert_eq!( + writer.write_history.iter().copied().sum::<usize>(), + data.len() + ); + // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked + // from the buffer: one because it was outside of MAX_BUF boundary, and + // up to one "utf8 code point". + assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1); + } +} diff --git a/third_party/rust/tokio/src/io/stdout.rs b/third_party/rust/tokio/src/io/stdout.rs new file mode 100644 index 0000000000..a08ed01eed --- /dev/null +++ b/third_party/rust/tokio/src/io/stdout.rs @@ -0,0 +1,108 @@ +use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; +use crate::io::AsyncWrite; +use std::io; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +cfg_io_std! { + /// A handle to the standard output stream of a process. + /// + /// Concurrent writes to stdout must be executed with care: Only individual + /// writes to this [`AsyncWrite`] are guaranteed to be intact. In particular + /// you should be aware that writes using [`write_all`] are not guaranteed + /// to occur as a single write, so multiple threads writing data with + /// [`write_all`] may result in interleaved output. + /// + /// Created by the [`stdout`] function. + /// + /// [`stdout`]: stdout() + /// [`AsyncWrite`]: AsyncWrite + /// [`write_all`]: crate::io::AsyncWriteExt::write_all() + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stdout = io::stdout(); + /// stdout.write_all(b"Hello world!").await?; + /// Ok(()) + /// } + /// ``` + #[derive(Debug)] + pub struct Stdout { + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stdout>>, + } + + /// Constructs a new handle to the standard output of the current process. + /// + /// The returned handle allows writing to standard out from the within the + /// Tokio runtime. + /// + /// Concurrent writes to stdout must be executed with care: Only individual + /// writes to this [`AsyncWrite`] are guaranteed to be intact. In particular + /// you should be aware that writes using [`write_all`] are not guaranteed + /// to occur as a single write, so multiple threads writing data with + /// [`write_all`] may result in interleaved output. + /// + /// [`AsyncWrite`]: AsyncWrite + /// [`write_all`]: crate::io::AsyncWriteExt::write_all() + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stdout = io::stdout(); + /// stdout.write_all(b"Hello world!").await?; + /// Ok(()) + /// } + /// ``` + pub fn stdout() -> Stdout { + let std = io::stdout(); + Stdout { + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), + } + } +} + +#[cfg(unix)] +impl std::os::unix::io::AsRawFd for Stdout { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + std::io::stdout().as_raw_fd() + } +} + +#[cfg(windows)] +impl std::os::windows::io::AsRawHandle for Stdout { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { + std::io::stdout().as_raw_handle() + } +} + +impl AsyncWrite for Stdout { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Pin::new(&mut self.std).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.std).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Pin::new(&mut self.std).poll_shutdown(cx) + } +} diff --git a/third_party/rust/tokio/src/io/util/async_buf_read_ext.rs b/third_party/rust/tokio/src/io/util/async_buf_read_ext.rs new file mode 100644 index 0000000000..b241e354ba --- /dev/null +++ b/third_party/rust/tokio/src/io/util/async_buf_read_ext.rs @@ -0,0 +1,351 @@ +use crate::io::util::fill_buf::{fill_buf, FillBuf}; +use crate::io::util::lines::{lines, Lines}; +use crate::io::util::read_line::{read_line, ReadLine}; +use crate::io::util::read_until::{read_until, ReadUntil}; +use crate::io::util::split::{split, Split}; +use crate::io::AsyncBufRead; + +cfg_io_util! { + /// An extension trait which adds utility methods to [`AsyncBufRead`] types. + /// + /// [`AsyncBufRead`]: crate::io::AsyncBufRead + pub trait AsyncBufReadExt: AsyncBufRead { + /// Reads all bytes into `buf` until the delimiter `byte` or EOF is reached. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>; + /// ``` + /// + /// This function will read bytes from the underlying stream until the + /// delimiter or EOF is found. Once found, all bytes up to, and including, + /// the delimiter (if found) will be appended to `buf`. + /// + /// If successful, this function will return the total number of bytes read. + /// + /// If this function returns `Ok(0)`, the stream has reached EOF. + /// + /// # Errors + /// + /// This function will ignore all instances of [`ErrorKind::Interrupted`] and + /// will otherwise return any errors returned by [`fill_buf`]. + /// + /// If an I/O error is encountered then all bytes read so far will be + /// present in `buf` and its length will have been adjusted appropriately. + /// + /// [`fill_buf`]: AsyncBufRead::poll_fill_buf + /// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted + /// + /// # Cancel safety + /// + /// If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then some data may have been partially read. Any + /// partially read bytes are appended to `buf`, and the method can be + /// called again to continue reading until `byte`. + /// + /// This method returns the total number of bytes read. If you cancel + /// the call to `read_until` and then call it again to continue reading, + /// the counter is reset. + /// + /// # Examples + /// + /// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In + /// this example, we use [`Cursor`] to read all the bytes in a byte slice + /// in hyphen delimited segments: + /// + /// [`Cursor`]: std::io::Cursor + /// + /// ``` + /// use tokio::io::AsyncBufReadExt; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut cursor = Cursor::new(b"lorem-ipsum"); + /// let mut buf = vec![]; + /// + /// // cursor is at 'l' + /// let num_bytes = cursor.read_until(b'-', &mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 6); + /// assert_eq!(buf, b"lorem-"); + /// buf.clear(); + /// + /// // cursor is at 'i' + /// let num_bytes = cursor.read_until(b'-', &mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 5); + /// assert_eq!(buf, b"ipsum"); + /// buf.clear(); + /// + /// // cursor is at EOF + /// let num_bytes = cursor.read_until(b'-', &mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// assert_eq!(num_bytes, 0); + /// assert_eq!(buf, b""); + /// } + /// ``` + fn read_until<'a>(&'a mut self, byte: u8, buf: &'a mut Vec<u8>) -> ReadUntil<'a, Self> + where + Self: Unpin, + { + read_until(self, byte, buf) + } + + /// Reads all bytes until a newline (the 0xA byte) is reached, and append + /// them to the provided buffer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_line(&mut self, buf: &mut String) -> io::Result<usize>; + /// ``` + /// + /// This function will read bytes from the underlying stream until the + /// newline delimiter (the 0xA byte) or EOF is found. Once found, all bytes + /// up to, and including, the delimiter (if found) will be appended to + /// `buf`. + /// + /// If successful, this function will return the total number of bytes read. + /// + /// If this function returns `Ok(0)`, the stream has reached EOF. + /// + /// # Errors + /// + /// This function has the same error semantics as [`read_until`] and will + /// also return an error if the read bytes are not valid UTF-8. If an I/O + /// error is encountered then `buf` may contain some bytes already read in + /// the event that all data read so far was valid UTF-8. + /// + /// [`read_until`]: AsyncBufReadExt::read_until + /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may have been partially + /// read, and this data is lost. There are no guarantees regarding the + /// contents of `buf` when the call is cancelled. The current + /// implementation replaces `buf` with the empty string, but this may + /// change in the future. + /// + /// This function does not behave like [`read_until`] because of the + /// requirement that a string contains only valid utf-8. If you need a + /// cancellation safe `read_line`, there are three options: + /// + /// * Call [`read_until`] with a newline character and manually perform the utf-8 check. + /// * The stream returned by [`lines`] has a cancellation safe + /// [`next_line`] method. + /// * Use [`tokio_util::codec::LinesCodec`][LinesCodec]. + /// + /// [LinesCodec]: https://docs.rs/tokio-util/0.6/tokio_util/codec/struct.LinesCodec.html + /// [`read_until`]: Self::read_until + /// [`lines`]: Self::lines + /// [`next_line`]: crate::io::Lines::next_line + /// + /// # Examples + /// + /// [`std::io::Cursor`][`Cursor`] is a type that implements + /// `AsyncBufRead`. In this example, we use [`Cursor`] to read all the + /// lines in a byte slice: + /// + /// [`Cursor`]: std::io::Cursor + /// + /// ``` + /// use tokio::io::AsyncBufReadExt; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut cursor = Cursor::new(b"foo\nbar"); + /// let mut buf = String::new(); + /// + /// // cursor is at 'f' + /// let num_bytes = cursor.read_line(&mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 4); + /// assert_eq!(buf, "foo\n"); + /// buf.clear(); + /// + /// // cursor is at 'b' + /// let num_bytes = cursor.read_line(&mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 3); + /// assert_eq!(buf, "bar"); + /// buf.clear(); + /// + /// // cursor is at EOF + /// let num_bytes = cursor.read_line(&mut buf) + /// .await + /// .expect("reading from cursor won't fail"); + /// + /// assert_eq!(num_bytes, 0); + /// assert_eq!(buf, ""); + /// } + /// ``` + fn read_line<'a>(&'a mut self, buf: &'a mut String) -> ReadLine<'a, Self> + where + Self: Unpin, + { + read_line(self, buf) + } + + /// Returns a stream of the contents of this reader split on the byte + /// `byte`. + /// + /// This method is the asynchronous equivalent to + /// [`BufRead::split`](std::io::BufRead::split). + /// + /// The stream returned from this function will yield instances of + /// [`io::Result`]`<`[`Option`]`<`[`Vec<u8>`]`>>`. Each vector returned will *not* have + /// the delimiter byte at the end. + /// + /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option + /// [`Vec<u8>`]: std::vec::Vec + /// + /// # Errors + /// + /// Each item of the stream has the same error semantics as + /// [`AsyncBufReadExt::read_until`](AsyncBufReadExt::read_until). + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncBufRead; + /// use tokio::io::AsyncBufReadExt; + /// + /// # async fn dox(my_buf_read: impl AsyncBufRead + Unpin) -> std::io::Result<()> { + /// let mut segments = my_buf_read.split(b'f'); + /// + /// while let Some(segment) = segments.next_segment().await? { + /// println!("length = {}", segment.len()) + /// } + /// # Ok(()) + /// # } + /// ``` + fn split(self, byte: u8) -> Split<Self> + where + Self: Sized + Unpin, + { + split(self, byte) + } + + /// Returns the contents of the internal buffer, filling it with more + /// data from the inner reader if it is empty. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`consume`] method to function properly. When calling this method, + /// none of the contents will be "read" in the sense that later calling + /// `read` may return the same contents. As such, [`consume`] must be + /// called with the number of bytes that are consumed from this buffer + /// to ensure that the bytes are never returned twice. + /// + /// An empty buffer returned indicates that the stream has reached EOF. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn fill_buf(&mut self) -> io::Result<&[u8]>; + /// ``` + /// + /// # Errors + /// + /// This function will return an I/O error if the underlying reader was + /// read, but returned an error. + /// + /// [`consume`]: crate::io::AsyncBufReadExt::consume + fn fill_buf(&mut self) -> FillBuf<'_, Self> + where + Self: Unpin, + { + fill_buf(self) + } + + /// Tells this buffer that `amt` bytes have been consumed from the + /// buffer, so they should no longer be returned in calls to [`read`]. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`fill_buf`] method to function properly. This function does not + /// perform any I/O, it simply informs this object that some amount of + /// its buffer, returned from [`fill_buf`], has been consumed and should + /// no longer be returned. As such, this function may do odd things if + /// [`fill_buf`] isn't called before calling it. + /// + /// The `amt` must be less than the number of bytes in the buffer + /// returned by [`fill_buf`]. + /// + /// [`read`]: crate::io::AsyncReadExt::read + /// [`fill_buf`]: crate::io::AsyncBufReadExt::fill_buf + fn consume(&mut self, amt: usize) + where + Self: Unpin, + { + std::pin::Pin::new(self).consume(amt) + } + + /// Returns a stream over the lines of this reader. + /// This method is the async equivalent to [`BufRead::lines`](std::io::BufRead::lines). + /// + /// The stream returned from this function will yield instances of + /// [`io::Result`]`<`[`Option`]`<`[`String`]`>>`. Each string returned will *not* have a newline + /// byte (the 0xA byte) or CRLF (0xD, 0xA bytes) at the end. + /// + /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option + /// [`String`]: String + /// + /// # Errors + /// + /// Each line of the stream has the same error semantics as [`AsyncBufReadExt::read_line`]. + /// + /// # Examples + /// + /// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In + /// this example, we use [`Cursor`] to iterate over all the lines in a byte + /// slice. + /// + /// [`Cursor`]: std::io::Cursor + /// + /// ``` + /// use tokio::io::AsyncBufReadExt; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() { + /// let cursor = Cursor::new(b"lorem\nipsum\r\ndolor"); + /// + /// let mut lines = cursor.lines(); + /// + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("lorem"))); + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("ipsum"))); + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("dolor"))); + /// assert_eq!(lines.next_line().await.unwrap(), None); + /// } + /// ``` + /// + /// [`AsyncBufReadExt::read_line`]: AsyncBufReadExt::read_line + fn lines(self) -> Lines<Self> + where + Self: Sized, + { + lines(self) + } + } +} + +impl<R: AsyncBufRead + ?Sized> AsyncBufReadExt for R {} diff --git a/third_party/rust/tokio/src/io/util/async_read_ext.rs b/third_party/rust/tokio/src/io/util/async_read_ext.rs new file mode 100644 index 0000000000..df5445c2c6 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/async_read_ext.rs @@ -0,0 +1,1294 @@ +use crate::io::util::chain::{chain, Chain}; +use crate::io::util::read::{read, Read}; +use crate::io::util::read_buf::{read_buf, ReadBuf}; +use crate::io::util::read_exact::{read_exact, ReadExact}; +use crate::io::util::read_int::{ReadF32, ReadF32Le, ReadF64, ReadF64Le}; +use crate::io::util::read_int::{ + ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, +}; +use crate::io::util::read_int::{ + ReadU128, ReadU128Le, ReadU16, ReadU16Le, ReadU32, ReadU32Le, ReadU64, ReadU64Le, ReadU8, +}; +use crate::io::util::read_to_end::{read_to_end, ReadToEnd}; +use crate::io::util::read_to_string::{read_to_string, ReadToString}; +use crate::io::util::take::{take, Take}; +use crate::io::AsyncRead; + +use bytes::BufMut; + +cfg_io_util! { + /// Defines numeric reader + macro_rules! read_impl { + ( + $( + $(#[$outer:meta])* + fn $name:ident(&mut self) -> $($fut:ident)*; + )* + ) => { + $( + $(#[$outer])* + fn $name<'a>(&'a mut self) -> $($fut)*<&'a mut Self> where Self: Unpin { + $($fut)*::new(self) + } + )* + } + } + + /// Reads bytes from a source. + /// + /// Implemented as an extension trait, adding utility methods to all + /// [`AsyncRead`] types. Callers will tend to import this trait instead of + /// [`AsyncRead`]. + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = [0; 10]; + /// + /// // The `read` method is defined by this trait. + /// let n = f.read(&mut buffer[..]).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncRead`]: AsyncRead + pub trait AsyncReadExt: AsyncRead { + /// Creates a new `AsyncRead` instance that chains this stream with + /// `next`. + /// + /// The returned `AsyncRead` instance will first read all bytes from this object + /// until EOF is encountered. Afterwards the output is equivalent to the + /// output of `next`. + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `AsyncRead`: + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let f1 = File::open("foo.txt").await?; + /// let f2 = File::open("bar.txt").await?; + /// + /// let mut handle = f1.chain(f2); + /// let mut buffer = String::new(); + /// + /// // read the value into a String. We could use any AsyncRead + /// // method here, this is just one example. + /// handle.read_to_string(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + fn chain<R>(self, next: R) -> Chain<Self, R> + where + Self: Sized, + R: AsyncRead, + { + chain(self, next) + } + + /// Pulls some bytes from this source into the specified buffer, + /// returning how many bytes were read. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>; + /// ``` + /// + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. + /// + /// # Return + /// + /// If the return value of this method is `Ok(n)`, then it must be + /// guaranteed that `0 <= n <= buf.len()`. A nonzero `n` value indicates + /// that the buffer `buf` has been filled in with `n` bytes of data from + /// this source. If `n` is `0`, then it can indicate one of two + /// scenarios: + /// + /// 1. This reader has reached its "end of file" and will likely no longer + /// be able to produce bytes. Note that this does not mean that the + /// reader will *always* no longer be able to produce bytes. + /// 2. The buffer specified was 0 bytes in length. + /// + /// No guarantees are provided about the contents of `buf` when this + /// function is called, implementations cannot rely on any property of the + /// contents of `buf` being `true`. It is recommended that *implementations* + /// only write data to `buf` instead of reading its contents. + /// + /// Correspondingly, however, *callers* of this method may not assume + /// any guarantees about how the implementation uses `buf`. It is + /// possible that the code that's supposed to write to the buffer might + /// also read from it. It is your responsibility to make sure that `buf` + /// is initialized before calling `read`. + /// + /// # Errors + /// + /// If this function encounters any form of I/O or other error, an error + /// variant will be returned. If an error is returned then it must be + /// guaranteed that no bytes were read. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `Read`: + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = [0; 10]; + /// + /// // read up to 10 bytes + /// let n = f.read(&mut buffer[..]).await?; + /// + /// println!("The bytes: {:?}", &buffer[..n]); + /// Ok(()) + /// } + /// ``` + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Read<'a, Self> + where + Self: Unpin, + { + read(self, buf) + } + + /// Pulls some bytes from this source into the specified buffer, + /// advancing the buffer's internal cursor. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize>; + /// ``` + /// + /// Usually, only a single `read` syscall is issued, even if there is + /// more space in the supplied buffer. + /// + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. + /// + /// # Return + /// + /// A nonzero `n` value indicates that the buffer `buf` has been filled + /// in with `n` bytes of data from this source. If `n` is `0`, then it + /// can indicate one of two scenarios: + /// + /// 1. This reader has reached its "end of file" and will likely no longer + /// be able to produce bytes. Note that this does not mean that the + /// reader will *always* no longer be able to produce bytes. + /// 2. The buffer specified had a remaining capacity of zero. + /// + /// # Errors + /// + /// If this function encounters any form of I/O or other error, an error + /// variant will be returned. If an error is returned then it must be + /// guaranteed that no bytes were read. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// + /// # Examples + /// + /// [`File`] implements `Read` and [`BytesMut`] implements [`BufMut`]: + /// + /// [`File`]: crate::fs::File + /// [`BytesMut`]: bytes::BytesMut + /// [`BufMut`]: bytes::BufMut + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use bytes::BytesMut; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = BytesMut::with_capacity(10); + /// + /// assert!(buffer.is_empty()); + /// + /// // read up to 10 bytes, note that the return value is not needed + /// // to access the data that was read as `buffer`'s internal + /// // cursor is updated. + /// f.read_buf(&mut buffer).await?; + /// + /// println!("The bytes: {:?}", &buffer[..]); + /// Ok(()) + /// } + /// ``` + fn read_buf<'a, B>(&'a mut self, buf: &'a mut B) -> ReadBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: BufMut, + { + read_buf(self, buf) + } + + /// Reads the exact number of bytes required to fill `buf`. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<usize>; + /// ``` + /// + /// This function reads as many bytes as necessary to completely fill + /// the specified buffer `buf`. + /// + /// # Errors + /// + /// If the operation encounters an "end of file" before completely + /// filling the buffer, it returns an error of the kind + /// [`ErrorKind::UnexpectedEof`]. The contents of `buf` are unspecified + /// in this case. + /// + /// If any other read error is encountered then the operation + /// immediately returns. The contents of `buf` are unspecified in this + /// case. + /// + /// If this operation returns an error, it is unspecified how many bytes + /// it has read, but it will never read more than would be necessary to + /// completely fill the buffer. + /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may already have been + /// read into `buf`. + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `Read`: + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = [0; 10]; + /// + /// // read exactly 10 bytes + /// f.read_exact(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + /// + /// [`ErrorKind::UnexpectedEof`]: std::io::ErrorKind::UnexpectedEof + fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self> + where + Self: Unpin, + { + read_exact(self, buf) + } + + read_impl! { + /// Reads an unsigned 8 bit integer from the underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u8(&mut self) -> io::Result<u8>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 8 bit integers from an `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![2, 5]); + /// + /// assert_eq!(2, reader.read_u8().await?); + /// assert_eq!(5, reader.read_u8().await?); + /// + /// Ok(()) + /// } + /// ``` + fn read_u8(&mut self) -> ReadU8; + + /// Reads a signed 8 bit integer from the underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i8(&mut self) -> io::Result<i8>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 8 bit integers from an `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x02, 0xfb]); + /// + /// assert_eq!(2, reader.read_i8().await?); + /// assert_eq!(-5, reader.read_i8().await?); + /// + /// Ok(()) + /// } + /// ``` + fn read_i8(&mut self) -> ReadI8; + + /// Reads an unsigned 16-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u16(&mut self) -> io::Result<u16>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 16 bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![2, 5, 3, 0]); + /// + /// assert_eq!(517, reader.read_u16().await?); + /// assert_eq!(768, reader.read_u16().await?); + /// Ok(()) + /// } + /// ``` + fn read_u16(&mut self) -> ReadU16; + + /// Reads a signed 16-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i16(&mut self) -> io::Result<i16>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 16 bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0xc1, 0xff, 0x7c]); + /// + /// assert_eq!(193, reader.read_i16().await?); + /// assert_eq!(-132, reader.read_i16().await?); + /// Ok(()) + /// } + /// ``` + fn read_i16(&mut self) -> ReadI16; + + /// Reads an unsigned 32-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u32(&mut self) -> io::Result<u32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 32-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0x00, 0x01, 0x0b]); + /// + /// assert_eq!(267, reader.read_u32().await?); + /// Ok(()) + /// } + /// ``` + fn read_u32(&mut self) -> ReadU32; + + /// Reads a signed 32-bit integer in big-endian order from the + /// underlying reader. + /// + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i32(&mut self) -> io::Result<i32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 32-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7a, 0x33]); + /// + /// assert_eq!(-34253, reader.read_i32().await?); + /// Ok(()) + /// } + /// ``` + fn read_i32(&mut self) -> ReadI32; + + /// Reads an unsigned 64-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u64(&mut self) -> io::Result<u64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 64-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(918733457491587, reader.read_u64().await?); + /// Ok(()) + /// } + /// ``` + fn read_u64(&mut self) -> ReadU64; + + /// Reads an signed 64-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i64(&mut self) -> io::Result<i64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 64-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); + /// + /// assert_eq!(i64::MIN, reader.read_i64().await?); + /// Ok(()) + /// } + /// ``` + fn read_i64(&mut self) -> ReadI64; + + /// Reads an unsigned 128-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u128(&mut self) -> io::Result<u128>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 128-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83, + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(16947640962301618749969007319746179, reader.read_u128().await?); + /// Ok(()) + /// } + /// ``` + fn read_u128(&mut self) -> ReadU128; + + /// Reads an signed 128-bit integer in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i128(&mut self) -> io::Result<i128>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 128-bit big-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x80, 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0 + /// ]); + /// + /// assert_eq!(i128::MIN, reader.read_i128().await?); + /// Ok(()) + /// } + /// ``` + fn read_i128(&mut self) -> ReadI128; + + /// Reads an 32-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0x7f, 0xff, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32(&mut self) -> ReadF32; + + /// Reads an 64-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64(&mut self) -> ReadF64; + + /// Reads an unsigned 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u16_le(&mut self) -> io::Result<u16>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![2, 5, 3, 0]); + /// + /// assert_eq!(1282, reader.read_u16_le().await?); + /// assert_eq!(3, reader.read_u16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u16_le(&mut self) -> ReadU16Le; + + /// Reads a signed 16-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i16_le(&mut self) -> io::Result<i16>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 16 bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0xc1, 0xff, 0x7c]); + /// + /// assert_eq!(-16128, reader.read_i16_le().await?); + /// assert_eq!(31999, reader.read_i16_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i16_le(&mut self) -> ReadI16Le; + + /// Reads an unsigned 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u32_le(&mut self) -> io::Result<u32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x00, 0x00, 0x01, 0x0b]); + /// + /// assert_eq!(184614912, reader.read_u32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u32_le(&mut self) -> ReadU32Le; + + /// Reads a signed 32-bit integer in little-endian order from the + /// underlying reader. + /// + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i32_le(&mut self) -> io::Result<i32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 32-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7a, 0x33]); + /// + /// assert_eq!(863698943, reader.read_i32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i32_le(&mut self) -> ReadI32Le; + + /// Reads an unsigned 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u64_le(&mut self) -> io::Result<u64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(9477368352180732672, reader.read_u64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u64_le(&mut self) -> ReadU64Le; + + /// Reads an signed 64-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i64_le(&mut self) -> io::Result<i64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 64-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); + /// + /// assert_eq!(128, reader.read_i64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i64_le(&mut self) -> ReadI64Le; + + /// Reads an unsigned 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_u128_le(&mut self) -> io::Result<u128>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read unsigned 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83, + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// + /// assert_eq!(174826588484952389081207917399662330624, reader.read_u128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_u128_le(&mut self) -> ReadU128Le; + + /// Reads an signed 128-bit integer in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_i128_le(&mut self) -> io::Result<i128>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read signed 128-bit little-endian integers from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0x80, 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0 + /// ]); + /// + /// assert_eq!(128, reader.read_i128_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_i128_le(&mut self) -> ReadI128Le; + + /// Reads an 32-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32_le(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7f, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32_le(&mut self) -> ReadF32Le; + + /// Reads an 64-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64_le(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64_le(&mut self) -> ReadF64Le; + } + + /// Reads all bytes until EOF in this source, placing them into `buf`. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>; + /// ``` + /// + /// All bytes read from this source will be appended to the specified + /// buffer `buf`. This function will continuously call [`read()`] to + /// append more data to `buf` until [`read()`] returns `Ok(0)`. + /// + /// If successful, the total number of bytes read is returned. + /// + /// [`read()`]: AsyncReadExt::read + /// + /// # Errors + /// + /// If a read error is encountered then the `read_to_end` operation + /// immediately completes. Any bytes which have already been read will + /// be appended to `buf`. + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `Read`: + /// + /// ```no_run + /// use tokio::io::{self, AsyncReadExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = Vec::new(); + /// + /// // read the whole file + /// f.read_to_end(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + /// + /// (See also the [`tokio::fs::read`] convenience function for reading from a + /// file.) + /// + /// [`tokio::fs::read`]: fn@crate::fs::read + fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self> + where + Self: Unpin, + { + read_to_end(self, buf) + } + + /// Reads all bytes until EOF in this source, appending them to `buf`. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize>; + /// ``` + /// + /// If successful, the number of bytes which were read and appended to + /// `buf` is returned. + /// + /// # Errors + /// + /// If the data in this stream is *not* valid UTF-8 then an error is + /// returned and `buf` is unchanged. + /// + /// See [`read_to_end`][AsyncReadExt::read_to_end] for other error semantics. + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `Read`: + /// + /// ```no_run + /// use tokio::io::{self, AsyncReadExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt").await?; + /// let mut buffer = String::new(); + /// + /// f.read_to_string(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + /// + /// (See also the [`crate::fs::read_to_string`] convenience function for + /// reading from a file.) + /// + /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string + fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self> + where + Self: Unpin, + { + read_to_string(self, dst) + } + + /// Creates an adaptor which reads at most `limit` bytes from it. + /// + /// This function returns a new instance of `AsyncRead` which will read + /// at most `limit` bytes, after which it will always return EOF + /// (`Ok(0)`). Any read errors will not count towards the number of + /// bytes read and future calls to [`read()`] may succeed. + /// + /// [`read()`]: fn@crate::io::AsyncReadExt::read + /// + /// [read]: AsyncReadExt::read + /// + /// # Examples + /// + /// [`File`][crate::fs::File]s implement `Read`: + /// + /// ```no_run + /// use tokio::io::{self, AsyncReadExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let f = File::open("foo.txt").await?; + /// let mut buffer = [0; 5]; + /// + /// // read at most five bytes + /// let mut handle = f.take(5); + /// + /// handle.read(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + fn take(self, limit: u64) -> Take<Self> + where + Self: Sized, + { + take(self, limit) + } + } +} + +impl<R: AsyncRead + ?Sized> AsyncReadExt for R {} diff --git a/third_party/rust/tokio/src/io/util/async_seek_ext.rs b/third_party/rust/tokio/src/io/util/async_seek_ext.rs new file mode 100644 index 0000000000..46b3e6c0d3 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/async_seek_ext.rs @@ -0,0 +1,93 @@ +use crate::io::seek::{seek, Seek}; +use crate::io::AsyncSeek; +use std::io::SeekFrom; + +cfg_io_util! { + /// An extension trait that adds utility methods to [`AsyncSeek`] types. + /// + /// # Examples + /// + /// ``` + /// use std::io::{self, Cursor, SeekFrom}; + /// use tokio::io::{AsyncSeekExt, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut cursor = Cursor::new(b"abcdefg"); + /// + /// // the `seek` method is defined by this trait + /// cursor.seek(SeekFrom::Start(3)).await?; + /// + /// let mut buf = [0; 1]; + /// let n = cursor.read(&mut buf).await?; + /// assert_eq!(n, 1); + /// assert_eq!(buf, [b'd']); + /// + /// Ok(()) + /// } + /// ``` + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncSeek`]: AsyncSeek + pub trait AsyncSeekExt: AsyncSeek { + /// Creates a future which will seek an IO object, and then yield the + /// new position in the object and the object itself. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>; + /// ``` + /// + /// In the case of an error the buffer and the object will be discarded, with + /// the error yielded. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::io::{AsyncSeekExt, AsyncReadExt}; + /// + /// use std::io::SeekFrom; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::open("foo.txt").await?; + /// file.seek(SeekFrom::Start(6)).await?; + /// + /// let mut contents = vec![0u8; 10]; + /// file.read_exact(&mut contents).await?; + /// # Ok(()) + /// # } + /// ``` + fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> + where + Self: Unpin, + { + seek(self, pos) + } + + /// Creates a future which will rewind to the beginning of the stream. + /// + /// This is convenience method, equivalent to to `self.seek(SeekFrom::Start(0))`. + fn rewind(&mut self) -> Seek<'_, Self> + where + Self: Unpin, + { + self.seek(SeekFrom::Start(0)) + } + + /// Creates a future which will return the current seek position from the + /// start of the stream. + /// + /// This is equivalent to `self.seek(SeekFrom::Current(0))`. + fn stream_position(&mut self) -> Seek<'_, Self> + where + Self: Unpin, + { + self.seek(SeekFrom::Current(0)) + } + } +} + +impl<S: AsyncSeek + ?Sized> AsyncSeekExt for S {} diff --git a/third_party/rust/tokio/src/io/util/async_write_ext.rs b/third_party/rust/tokio/src/io/util/async_write_ext.rs new file mode 100644 index 0000000000..93a318315e --- /dev/null +++ b/third_party/rust/tokio/src/io/util/async_write_ext.rs @@ -0,0 +1,1293 @@ +use crate::io::util::flush::{flush, Flush}; +use crate::io::util::shutdown::{shutdown, Shutdown}; +use crate::io::util::write::{write, Write}; +use crate::io::util::write_all::{write_all, WriteAll}; +use crate::io::util::write_all_buf::{write_all_buf, WriteAllBuf}; +use crate::io::util::write_buf::{write_buf, WriteBuf}; +use crate::io::util::write_int::{WriteF32, WriteF32Le, WriteF64, WriteF64Le}; +use crate::io::util::write_int::{ + WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, + WriteI8, +}; +use crate::io::util::write_int::{ + WriteU128, WriteU128Le, WriteU16, WriteU16Le, WriteU32, WriteU32Le, WriteU64, WriteU64Le, + WriteU8, +}; +use crate::io::util::write_vectored::{write_vectored, WriteVectored}; +use crate::io::AsyncWrite; +use std::io::IoSlice; + +use bytes::Buf; + +cfg_io_util! { + /// Defines numeric writer. + macro_rules! write_impl { + ( + $( + $(#[$outer:meta])* + fn $name:ident(&mut self, n: $ty:ty) -> $($fut:ident)*; + )* + ) => { + $( + $(#[$outer])* + fn $name<'a>(&'a mut self, n: $ty) -> $($fut)*<&'a mut Self> where Self: Unpin { + $($fut)*::new(self, n) + } + )* + } + } + + /// Writes bytes to a sink. + /// + /// Implemented as an extension trait, adding utility methods to all + /// [`AsyncWrite`] types. Callers will tend to import this trait instead of + /// [`AsyncWrite`]. + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let data = b"some bytes"; + /// + /// let mut pos = 0; + /// let mut buffer = File::create("foo.txt").await?; + /// + /// while pos < data.len() { + /// let bytes_written = buffer.write(&data[pos..]).await?; + /// pos += bytes_written; + /// } + /// + /// Ok(()) + /// } + /// ``` + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncWrite`]: AsyncWrite + pub trait AsyncWriteExt: AsyncWrite { + /// Writes a buffer into this writer, returning how many bytes were + /// written. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write(&mut self, buf: &[u8]) -> io::Result<usize>; + /// ``` + /// + /// This function will attempt to write the entire contents of `buf`, but + /// the entire write may not succeed, or the write may also generate an + /// error. A call to `write` represents *at most one* attempt to write to + /// any wrapped object. + /// + /// # Return + /// + /// If the return value is `Ok(n)` then it must be guaranteed that `n <= + /// buf.len()`. A return value of `0` typically means that the + /// underlying object is no longer able to accept bytes and will likely + /// not be able to in the future as well, or that the buffer provided is + /// empty. + /// + /// # Errors + /// + /// Each call to `write` may generate an I/O error indicating that the + /// operation could not be completed. If an error is returned then no bytes + /// in the buffer were written to this writer. + /// + /// It is **not** considered an error if the entire buffer could not be + /// written to this writer. + /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// + /// // Writes some prefix of the byte string, not necessarily all of it. + /// file.write(b"some bytes").await?; + /// Ok(()) + /// } + /// ``` + fn write<'a>(&'a mut self, src: &'a [u8]) -> Write<'a, Self> + where + Self: Unpin, + { + write(self, src) + } + + /// Like [`write`], except that it writes from a slice of buffers. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>; + /// ``` + /// + /// See [`AsyncWrite::poll_write_vectored`] for more details. + /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// use std::io::IoSlice; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// + /// let bufs: &[_] = &[ + /// IoSlice::new(b"hello"), + /// IoSlice::new(b" "), + /// IoSlice::new(b"world"), + /// ]; + /// + /// file.write_vectored(&bufs).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`write`]: AsyncWriteExt::write + fn write_vectored<'a, 'b>(&'a mut self, bufs: &'a [IoSlice<'b>]) -> WriteVectored<'a, 'b, Self> + where + Self: Unpin, + { + write_vectored(self, bufs) + } + + /// Writes a buffer into this writer, advancing the buffer's internal + /// cursor. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_buf<B: Buf>(&mut self, buf: &mut B) -> io::Result<usize>; + /// ``` + /// + /// This function will attempt to write the entire contents of `buf`, but + /// the entire write may not succeed, or the write may also generate an + /// error. After the operation completes, the buffer's + /// internal cursor is advanced by the number of bytes written. A + /// subsequent call to `write_buf` using the **same** `buf` value will + /// resume from the point that the first call to `write_buf` completed. + /// A call to `write_buf` represents *at most one* attempt to write to any + /// wrapped object. + /// + /// # Return + /// + /// If the return value is `Ok(n)` then it must be guaranteed that `n <= + /// buf.len()`. A return value of `0` typically means that the + /// underlying object is no longer able to accept bytes and will likely + /// not be able to in the future as well, or that the buffer provided is + /// empty. + /// + /// # Errors + /// + /// Each call to `write` may generate an I/O error indicating that the + /// operation could not be completed. If an error is returned then no bytes + /// in the buffer were written to this writer. + /// + /// It is **not** considered an error if the entire buffer could not be + /// written to this writer. + /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: + /// + /// [`File`]: crate::fs::File + /// [`Buf`]: bytes::Buf + /// [`Cursor`]: std::io::Cursor + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buffer = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buffer.has_remaining() { + /// // Writes some prefix of the byte string, not necessarily + /// // all of it. + /// file.write_buf(&mut buffer).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + fn write_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: Buf, + { + write_buf(self, src) + } + + /// Attempts to write an entire buffer into this writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_all_buf(&mut self, buf: impl Buf) -> Result<(), io::Error> { + /// while buf.has_remaining() { + /// self.write_buf(&mut buf).await?; + /// } + /// Ok(()) + /// } + /// ``` + /// + /// This method will continuously call [`write`] until + /// [`buf.has_remaining()`](bytes::Buf::has_remaining) returns false. This method will not + /// return until the entire buffer has been successfully written or an error occurs. The + /// first error generated will be returned. + /// + /// The buffer is advanced after each chunk is successfully written. After failure, + /// `src.chunk()` will return the chunk that failed to write. + /// + /// # Cancel safety + /// + /// If `write_all_buf` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then the data in the provided buffer may have been + /// partially written. However, it is guaranteed that the provided + /// buffer has been [advanced] by the amount of bytes that have been + /// partially written. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: + /// + /// [`File`]: crate::fs::File + /// [`Buf`]: bytes::Buf + /// [`Cursor`]: std::io::Cursor + /// [advanced]: bytes::Buf::advance + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buffer = Cursor::new(b"data to write"); + /// + /// file.write_all_buf(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + /// + /// [`write`]: AsyncWriteExt::write + fn write_all_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteAllBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: Buf, + { + write_all_buf(self, src) + } + + /// Attempts to write an entire buffer into this writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_all(&mut self, buf: &[u8]) -> io::Result<()>; + /// ``` + /// + /// This method will continuously call [`write`] until there is no more data + /// to be written. This method will not return until the entire buffer + /// has been successfully written or such an error occurs. The first + /// error generated from this method will be returned. + /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If it is used as the event + /// in a [`tokio::select!`](crate::select) statement and some other + /// branch completes first, then the provided buffer may have been + /// partially written, but future calls to `write_all` will start over + /// from the beginning of the buffer. + /// + /// # Errors + /// + /// This function will return the first error that [`write`] returns. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// + /// file.write_all(b"some bytes").await?; + /// Ok(()) + /// } + /// ``` + /// + /// [`write`]: AsyncWriteExt::write + fn write_all<'a>(&'a mut self, src: &'a [u8]) -> WriteAll<'a, Self> + where + Self: Unpin, + { + write_all(self, src) + } + + write_impl! { + /// Writes an unsigned 8-bit integer to the underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u8(&mut self, n: u8) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 8 bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u8(2).await?; + /// writer.write_u8(5).await?; + /// + /// assert_eq!(writer, b"\x02\x05"); + /// Ok(()) + /// } + /// ``` + fn write_u8(&mut self, n: u8) -> WriteU8; + + /// Writes an unsigned 8-bit integer to the underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i8(&mut self, n: i8) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 8 bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u8(2).await?; + /// writer.write_u8(5).await?; + /// + /// assert_eq!(writer, b"\x02\x05"); + /// Ok(()) + /// } + /// ``` + fn write_i8(&mut self, n: i8) -> WriteI8; + + /// Writes an unsigned 16-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u16(&mut self, n: u16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u16(517).await?; + /// writer.write_u16(768).await?; + /// + /// assert_eq!(writer, b"\x02\x05\x03\x00"); + /// Ok(()) + /// } + /// ``` + fn write_u16(&mut self, n: u16) -> WriteU16; + + /// Writes a signed 16-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i16(&mut self, n: i16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i16(193).await?; + /// writer.write_i16(-132).await?; + /// + /// assert_eq!(writer, b"\x00\xc1\xff\x7c"); + /// Ok(()) + /// } + /// ``` + fn write_i16(&mut self, n: i16) -> WriteI16; + + /// Writes an unsigned 32-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u32(&mut self, n: u32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u32(267).await?; + /// writer.write_u32(1205419366).await?; + /// + /// assert_eq!(writer, b"\x00\x00\x01\x0b\x47\xd9\x3d\x66"); + /// Ok(()) + /// } + /// ``` + fn write_u32(&mut self, n: u32) -> WriteU32; + + /// Writes a signed 32-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i32(&mut self, n: i32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i32(267).await?; + /// writer.write_i32(1205419366).await?; + /// + /// assert_eq!(writer, b"\x00\x00\x01\x0b\x47\xd9\x3d\x66"); + /// Ok(()) + /// } + /// ``` + fn write_i32(&mut self, n: i32) -> WriteI32; + + /// Writes an unsigned 64-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u64(&mut self, n: u64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u64(918733457491587).await?; + /// writer.write_u64(143).await?; + /// + /// assert_eq!(writer, b"\x00\x03\x43\x95\x4d\x60\x86\x83\x00\x00\x00\x00\x00\x00\x00\x8f"); + /// Ok(()) + /// } + /// ``` + fn write_u64(&mut self, n: u64) -> WriteU64; + + /// Writes an signed 64-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i64(&mut self, n: i64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i64(i64::MIN).await?; + /// writer.write_i64(i64::MAX).await?; + /// + /// assert_eq!(writer, b"\x80\x00\x00\x00\x00\x00\x00\x00\x7f\xff\xff\xff\xff\xff\xff\xff"); + /// Ok(()) + /// } + /// ``` + fn write_i64(&mut self, n: i64) -> WriteI64; + + /// Writes an unsigned 128-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u128(&mut self, n: u128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u128(16947640962301618749969007319746179).await?; + /// + /// assert_eq!(writer, vec![ + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83, + /// 0x00, 0x03, 0x43, 0x95, 0x4d, 0x60, 0x86, 0x83 + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_u128(&mut self, n: u128) -> WriteU128; + + /// Writes an signed 128-bit integer in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i128(&mut self, n: i128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i128(i128::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0x80, 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0 + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_i128(&mut self, n: i128) -> WriteI128; + + /// Writes an 32-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0x7f, 0xff, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32(&mut self, n: f32) -> WriteF32; + + /// Writes an 64-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64(&mut self, n: f64) -> WriteF64; + + /// Writes an unsigned 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u16_le(&mut self, n: u16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u16_le(517).await?; + /// writer.write_u16_le(768).await?; + /// + /// assert_eq!(writer, b"\x05\x02\x00\x03"); + /// Ok(()) + /// } + /// ``` + fn write_u16_le(&mut self, n: u16) -> WriteU16Le; + + /// Writes a signed 16-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i16_le(&mut self, n: i16) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 16-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i16_le(193).await?; + /// writer.write_i16_le(-132).await?; + /// + /// assert_eq!(writer, b"\xc1\x00\x7c\xff"); + /// Ok(()) + /// } + /// ``` + fn write_i16_le(&mut self, n: i16) -> WriteI16Le; + + /// Writes an unsigned 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u32_le(&mut self, n: u32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u32_le(267).await?; + /// writer.write_u32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_u32_le(&mut self, n: u32) -> WriteU32Le; + + /// Writes a signed 32-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i32_le(&mut self, n: i32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 32-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i32_le(267).await?; + /// writer.write_i32_le(1205419366).await?; + /// + /// assert_eq!(writer, b"\x0b\x01\x00\x00\x66\x3d\xd9\x47"); + /// Ok(()) + /// } + /// ``` + fn write_i32_le(&mut self, n: i32) -> WriteI32Le; + + /// Writes an unsigned 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u64_le(&mut self, n: u64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u64_le(918733457491587).await?; + /// writer.write_u64_le(143).await?; + /// + /// assert_eq!(writer, b"\x83\x86\x60\x4d\x95\x43\x03\x00\x8f\x00\x00\x00\x00\x00\x00\x00"); + /// Ok(()) + /// } + /// ``` + fn write_u64_le(&mut self, n: u64) -> WriteU64Le; + + /// Writes an signed 64-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i64_le(&mut self, n: i64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 64-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i64_le(i64::MIN).await?; + /// writer.write_i64_le(i64::MAX).await?; + /// + /// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f"); + /// Ok(()) + /// } + /// ``` + fn write_i64_le(&mut self, n: i64) -> WriteI64Le; + + /// Writes an unsigned 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_u128_le(&mut self, n: u128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write unsigned 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_u128_le(16947640962301618749969007319746179).await?; + /// + /// assert_eq!(writer, vec![ + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// 0x83, 0x86, 0x60, 0x4d, 0x95, 0x43, 0x03, 0x00, + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_u128_le(&mut self, n: u128) -> WriteU128Le; + + /// Writes an signed 128-bit integer in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_i128_le(&mut self, n: i128) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write signed 128-bit integers to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_i128_le(i128::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0, 0, 0, 0, 0, 0, 0, + /// 0, 0, 0, 0, 0, 0, 0, 0, 0x80 + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_i128_le(&mut self, n: i128) -> WriteI128Le; + + /// Writes an 32-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32_le(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32_le(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0xff, 0x7f, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32_le(&mut self, n: f32) -> WriteF32Le; + + /// Writes an 64-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64_le(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64_le(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64_le(&mut self, n: f64) -> WriteF64Le; + } + + /// Flushes this output stream, ensuring that all intermediately buffered + /// contents reach their destination. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn flush(&mut self) -> io::Result<()>; + /// ``` + /// + /// # Errors + /// + /// It is considered an error if not all bytes could be written due to + /// I/O errors or EOF being reached. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, BufWriter, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let f = File::create("foo.txt").await?; + /// let mut buffer = BufWriter::new(f); + /// + /// buffer.write_all(b"some bytes").await?; + /// buffer.flush().await?; + /// Ok(()) + /// } + /// ``` + fn flush(&mut self) -> Flush<'_, Self> + where + Self: Unpin, + { + flush(self) + } + + /// Shuts down the output stream, ensuring that the value can be dropped + /// cleanly. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn shutdown(&mut self) -> io::Result<()>; + /// ``` + /// + /// Similar to [`flush`], all intermediately buffered is written to the + /// underlying stream. Once the operation completes, the caller should + /// no longer attempt to write to the stream. For example, the + /// `TcpStream` implementation will issue a `shutdown(Write)` sys call. + /// + /// [`flush`]: fn@crate::io::AsyncWriteExt::flush + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, BufWriter, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let f = File::create("foo.txt").await?; + /// let mut buffer = BufWriter::new(f); + /// + /// buffer.write_all(b"some bytes").await?; + /// buffer.shutdown().await?; + /// Ok(()) + /// } + /// ``` + fn shutdown(&mut self) -> Shutdown<'_, Self> + where + Self: Unpin, + { + shutdown(self) + } + } +} + +impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {} diff --git a/third_party/rust/tokio/src/io/util/buf_reader.rs b/third_party/rust/tokio/src/io/util/buf_reader.rs new file mode 100644 index 0000000000..60879c0fdc --- /dev/null +++ b/third_party/rust/tokio/src/io/util/buf_reader.rs @@ -0,0 +1,311 @@ +use crate::io::util::DEFAULT_BUF_SIZE; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; + +use pin_project_lite::pin_project; +use std::io::{self, IoSlice, SeekFrom}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{cmp, fmt, mem}; + +pin_project! { + /// The `BufReader` struct adds buffering to any reader. + /// + /// It can be excessively inefficient to work directly with a [`AsyncRead`] + /// instance. A `BufReader` performs large, infrequent reads on the underlying + /// [`AsyncRead`] and maintains an in-memory buffer of the results. + /// + /// `BufReader` can improve the speed of programs that make *small* and + /// *repeated* read calls to the same file or network socket. It does not + /// help when reading very large amounts at once, or reading just one or a few + /// times. It also provides no advantage when reading from a source that is + /// already in memory, like a `Vec<u8>`. + /// + /// When the `BufReader` is dropped, the contents of its buffer will be + /// discarded. Creating multiple instances of a `BufReader` on the same + /// stream can cause data loss. + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct BufReader<R> { + #[pin] + pub(super) inner: R, + pub(super) buf: Box<[u8]>, + pub(super) pos: usize, + pub(super) cap: usize, + pub(super) seek_state: SeekState, + } +} + +impl<R: AsyncRead> BufReader<R> { + /// Creates a new `BufReader` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new(inner: R) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, inner) + } + + /// Creates a new `BufReader` with the specified buffer capacity. + pub fn with_capacity(capacity: usize, inner: R) -> Self { + let buffer = vec![0; capacity]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, + seek_state: SeekState::Init, + } + } + + /// Gets a reference to the underlying reader. + /// + /// It is inadvisable to directly read from the underlying reader. + pub fn get_ref(&self) -> &R { + &self.inner + } + + /// Gets a mutable reference to the underlying reader. + /// + /// It is inadvisable to directly read from the underlying reader. + pub fn get_mut(&mut self) -> &mut R { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying reader. + /// + /// It is inadvisable to directly read from the underlying reader. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().inner + } + + /// Consumes this `BufReader`, returning the underlying reader. + /// + /// Note that any leftover data in the internal buffer is lost. + pub fn into_inner(self) -> R { + self.inner + } + + /// Returns a reference to the internally buffered data. + /// + /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is empty. + pub fn buffer(&self) -> &[u8] { + &self.buf[self.pos..self.cap] + } + + /// Invalidates all data in the internal buffer. + #[inline] + fn discard_buffer(self: Pin<&mut Self>) { + let me = self.project(); + *me.pos = 0; + *me.cap = 0; + } +} + +impl<R: AsyncRead> AsyncRead for BufReader<R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // If we don't have any buffered data and we're doing a massive read + // (larger than our internal buffer), bypass our internal buffer + // entirely. + if self.pos == self.cap && buf.remaining() >= self.buf.len() { + let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); + self.discard_buffer(); + return Poll::Ready(res); + } + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) + } +} + +impl<R: AsyncRead> AsyncBufRead for BufReader<R> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + let me = self.project(); + + // If we've reached the end of our internal buffer then we need to fetch + // some more data from the underlying reader. + // Branch using `>=` instead of the more correct `==` + // to tell the compiler that the pos..cap slice is always valid. + if *me.pos >= *me.cap { + debug_assert!(*me.pos == *me.cap); + let mut buf = ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + *me.cap = buf.filled().len(); + *me.pos = 0; + } + Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + *me.pos = cmp::min(*me.pos + amt, *me.cap); + } +} + +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of the first poll_complete in the `n.checked_sub(remainder).is_none()` branch. + PendingOverflowed(i64), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seeks to an offset, in bytes, in the underlying reader. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying reader would be at if the `BufReader` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. +impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We needs to call seek operation multiple times. + // And we should always call both start_seek and poll_complete, + // as start_seek alone cannot guarantee that the operation will be completed. + // poll_complete receives a Context and returns a Poll, so it cannot be called + // inside start_seek. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) { + SeekState::Init => { + // 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + SeekState::Start(SeekFrom::Current(n)) => { + let remainder = (self.cap - self.pos) as i64; + // it should be safe to assume that remainder fits within an i64 as the alternative + // means we managed to allocate 8 exbibytes and that's absurd. + // But it's not out of the realm of possibility for some weird underlying reader to + // support seeking by i64::MIN so we need to handle underflow when subtracting + // remainder. + if let Some(offset) = n.checked_sub(remainder) { + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(offset))?; + } else { + // seek backwards by our remainder, and then by the offset + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(-remainder))?; + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + } + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::PendingOverflowed(n) => { + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::Start(pos) => { + // Seeking with Start/End doesn't care about our buffer length. + self.as_mut().get_pin_mut().start_seek(pos)?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?, + }; + + match res { + Poll::Ready(res) => { + self.discard_buffer(); + Poll::Ready(Ok(res)) + } + Poll::Pending => { + *self.as_mut().project().seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + +impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.get_pin_mut().poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.get_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.get_ref().is_write_vectored() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.get_pin_mut().poll_shutdown(cx) + } +} + +impl<R: fmt::Debug> fmt::Debug for BufReader<R> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BufReader") + .field("reader", &self.inner) + .field( + "buffer", + &format_args!("{}/{}", self.cap - self.pos, self.buf.len()), + ) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<BufReader<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/buf_stream.rs b/third_party/rust/tokio/src/io/util/buf_stream.rs new file mode 100644 index 0000000000..595c142aca --- /dev/null +++ b/third_party/rust/tokio/src/io/util/buf_stream.rs @@ -0,0 +1,207 @@ +use crate::io::util::{BufReader, BufWriter}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; + +use pin_project_lite::pin_project; +use std::io::{self, IoSlice, SeekFrom}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Wraps a type that is [`AsyncWrite`] and [`AsyncRead`], and buffers its input and output. + /// + /// It can be excessively inefficient to work directly with something that implements [`AsyncWrite`] + /// and [`AsyncRead`]. For example, every `write`, however small, has to traverse the syscall + /// interface, and similarly, every read has to do the same. The [`BufWriter`] and [`BufReader`] + /// types aid with these problems respectively, but do so in only one direction. `BufStream` wraps + /// one in the other so that both directions are buffered. See their documentation for details. + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct BufStream<RW> { + #[pin] + inner: BufReader<BufWriter<RW>>, + } +} + +impl<RW: AsyncRead + AsyncWrite> BufStream<RW> { + /// Wraps a type in both [`BufWriter`] and [`BufReader`]. + /// + /// See the documentation for those types and [`BufStream`] for details. + pub fn new(stream: RW) -> BufStream<RW> { + BufStream { + inner: BufReader::new(BufWriter::new(stream)), + } + } + + /// Creates a `BufStream` with the specified [`BufReader`] capacity and [`BufWriter`] + /// capacity. + /// + /// See the documentation for those types and [`BufStream`] for details. + pub fn with_capacity( + reader_capacity: usize, + writer_capacity: usize, + stream: RW, + ) -> BufStream<RW> { + BufStream { + inner: BufReader::with_capacity( + reader_capacity, + BufWriter::with_capacity(writer_capacity, stream), + ), + } + } + + /// Gets a reference to the underlying I/O object. + /// + /// It is inadvisable to directly read from the underlying I/O object. + pub fn get_ref(&self) -> &RW { + self.inner.get_ref().get_ref() + } + + /// Gets a mutable reference to the underlying I/O object. + /// + /// It is inadvisable to directly read from the underlying I/O object. + pub fn get_mut(&mut self) -> &mut RW { + self.inner.get_mut().get_mut() + } + + /// Gets a pinned mutable reference to the underlying I/O object. + /// + /// It is inadvisable to directly read from the underlying I/O object. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut RW> { + self.project().inner.get_pin_mut().get_pin_mut() + } + + /// Consumes this `BufStream`, returning the underlying I/O object. + /// + /// Note that any leftover data in the internal buffer is lost. + pub fn into_inner(self) -> RW { + self.inner.into_inner().into_inner() + } +} + +impl<RW> From<BufReader<BufWriter<RW>>> for BufStream<RW> { + fn from(b: BufReader<BufWriter<RW>>) -> Self { + BufStream { inner: b } + } +} + +impl<RW> From<BufWriter<BufReader<RW>>> for BufStream<RW> { + fn from(b: BufWriter<BufReader<RW>>) -> Self { + // we need to "invert" the reader and writer + let BufWriter { + inner: + BufReader { + inner, + buf: rbuf, + pos, + cap, + seek_state: rseek_state, + }, + buf: wbuf, + written, + seek_state: wseek_state, + } = b; + + BufStream { + inner: BufReader { + inner: BufWriter { + inner, + buf: wbuf, + written, + seek_state: wseek_state, + }, + buf: rbuf, + pos, + cap, + seek_state: rseek_state, + }, + } + } +} + +impl<RW: AsyncRead + AsyncWrite> AsyncWrite for BufStream<RW> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.project().inner.poll_shutdown(cx) + } +} + +impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.project().inner.poll_read(cx, buf) + } +} + +/// Seek to an offset, in bytes, in the underlying stream. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying stream would be at if the `BufStream` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. +impl<RW: AsyncRead + AsyncWrite + AsyncSeek> AsyncSeek for BufStream<RW> { + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> { + self.project().inner.start_seek(position) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + self.project().inner.poll_complete(cx) + } +} + +impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.project().inner.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().inner.consume(amt) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<BufStream<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/buf_writer.rs b/third_party/rust/tokio/src/io/util/buf_writer.rs new file mode 100644 index 0000000000..8dd1bba60a --- /dev/null +++ b/third_party/rust/tokio/src/io/util/buf_writer.rs @@ -0,0 +1,310 @@ +use crate::io::util::DEFAULT_BUF_SIZE; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; + +use pin_project_lite::pin_project; +use std::fmt; +use std::io::{self, IoSlice, SeekFrom, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Wraps a writer and buffers its output. + /// + /// It can be excessively inefficient to work directly with something that + /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and + /// writes it to an underlying writer in large, infrequent batches. + /// + /// `BufWriter` can improve the speed of programs that make *small* and + /// *repeated* write calls to the same file or network socket. It does not + /// help when writing very large amounts at once, or writing just one or a few + /// times. It also provides no advantage when writing to a destination that is + /// in memory, like a `Vec<u8>`. + /// + /// When the `BufWriter` is dropped, the contents of its buffer will be + /// discarded. Creating multiple instances of a `BufWriter` on the same + /// stream can cause data loss. If you need to write out the contents of its + /// buffer, you must manually call flush before the writer is dropped. + /// + /// [`AsyncWrite`]: AsyncWrite + /// [`flush`]: super::AsyncWriteExt::flush + /// + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct BufWriter<W> { + #[pin] + pub(super) inner: W, + pub(super) buf: Vec<u8>, + pub(super) written: usize, + pub(super) seek_state: SeekState, + } +} + +impl<W: AsyncWrite> BufWriter<W> { + /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new(inner: W) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, inner) + } + + /// Creates a new `BufWriter` with the specified buffer capacity. + pub fn with_capacity(cap: usize, inner: W) -> Self { + Self { + inner, + buf: Vec::with_capacity(cap), + written: 0, + seek_state: SeekState::Init, + } + } + + fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let mut me = self.project(); + + let len = me.buf.len(); + let mut ret = Ok(()); + while *me.written < len { + match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) { + Ok(0) => { + ret = Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffered data", + )); + break; + } + Ok(n) => *me.written += n, + Err(e) => { + ret = Err(e); + break; + } + } + } + if *me.written > 0 { + me.buf.drain(..*me.written); + } + *me.written = 0; + Poll::Ready(ret) + } + + /// Gets a reference to the underlying writer. + pub fn get_ref(&self) -> &W { + &self.inner + } + + /// Gets a mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_mut(&mut self) -> &mut W { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().inner + } + + /// Consumes this `BufWriter`, returning the underlying writer. + /// + /// Note that any leftover data in the internal buffer is lost. + pub fn into_inner(self) -> W { + self.inner + } + + /// Returns a reference to the internally buffered data. + pub fn buffer(&self) -> &[u8] { + &self.buf + } +} + +impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if self.buf.len() + buf.len() > self.buf.capacity() { + ready!(self.as_mut().flush_buf(cx))?; + } + + let me = self.project(); + if buf.len() >= me.buf.capacity() { + me.inner.poll_write(cx, buf) + } else { + Poll::Ready(me.buf.write(buf)) + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + if self.inner.is_write_vectored() { + let total_len = bufs + .iter() + .fold(0usize, |acc, b| acc.saturating_add(b.len())); + if total_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if total_len >= me.buf.capacity() { + // It's more efficient to pass the slices directly to the + // underlying writer than to buffer them. + // The case when the total_len calculation saturates at + // usize::MAX is also handled here. + me.inner.poll_write_vectored(cx, bufs) + } else { + bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); + Poll::Ready(Ok(total_len)) + } + } else { + // Remove empty buffers at the beginning of bufs. + while bufs.first().map(|buf| buf.len()) == Some(0) { + bufs = &bufs[1..]; + } + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + // Flush if the first buffer doesn't fit. + let first_len = bufs[0].len(); + if first_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + debug_assert!(self.buf.is_empty()); + } + let me = self.as_mut().project(); + if first_len >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + debug_assert!(me.buf.is_empty()); + return me.inner.poll_write(cx, &bufs[0]); + } else { + me.buf.extend_from_slice(&bufs[0]); + bufs = &bufs[1..]; + } + let mut total_written = first_len; + debug_assert!(total_written != 0); + // Append the buffers that fit in the internal buffer. + for buf in bufs { + if buf.len() > me.buf.capacity() - me.buf.len() { + break; + } else { + me.buf.extend_from_slice(buf); + total_written += buf.len(); + } + } + Poll::Ready(Ok(total_written)) + } + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + ready!(self.as_mut().flush_buf(cx))?; + self.get_pin_mut().poll_shutdown(cx) + } +} + +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seek to the offset, in bytes, in the underlying writer. +/// +/// Seeking always writes out the internal buffer before seeking. +impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We need to flush the internal buffer before seeking. + // It receives a `Context` and returns a `Poll`, so it cannot be called + // inside `start_seek`. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let pos = match self.seek_state { + SeekState::Init => { + return self.project().inner.poll_complete(cx); + } + SeekState::Start(pos) => Some(pos), + SeekState::Pending => None, + }; + + // Flush the internal buffer before seeking. + ready!(self.as_mut().flush_buf(cx))?; + + let mut me = self.project(); + if let Some(pos) = pos { + // Ensure previous seeks have finished before starting a new one + ready!(me.inner.as_mut().poll_complete(cx))?; + if let Err(e) = me.inner.as_mut().start_seek(pos) { + *me.seek_state = SeekState::Init; + return Poll::Ready(Err(e)); + } + } + match me.inner.poll_complete(cx) { + Poll::Ready(res) => { + *me.seek_state = SeekState::Init; + Poll::Ready(res) + } + Poll::Pending => { + *me.seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + +impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.get_pin_mut().poll_read(cx, buf) + } +} + +impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + self.get_pin_mut().poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_pin_mut().consume(amt) + } +} + +impl<W: fmt::Debug> fmt::Debug for BufWriter<W> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BufWriter") + .field("writer", &self.inner) + .field( + "buffer", + &format_args!("{}/{}", self.buf.len(), self.buf.capacity()), + ) + .field("written", &self.written) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<BufWriter<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/chain.rs b/third_party/rust/tokio/src/io/util/chain.rs new file mode 100644 index 0000000000..84f37fc7d4 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/chain.rs @@ -0,0 +1,144 @@ +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Stream for the [`chain`](super::AsyncReadExt::chain) method. + #[must_use = "streams do nothing unless polled"] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct Chain<T, U> { + #[pin] + first: T, + #[pin] + second: U, + done_first: bool, + } +} + +pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U> +where + T: AsyncRead, + U: AsyncRead, +{ + Chain { + first, + second, + done_first: false, + } +} + +impl<T, U> Chain<T, U> +where + T: AsyncRead, + U: AsyncRead, +{ + /// Gets references to the underlying readers in this `Chain`. + pub fn get_ref(&self) -> (&T, &U) { + (&self.first, &self.second) + } + + /// Gets mutable references to the underlying readers in this `Chain`. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying readers as doing so may corrupt the internal state of this + /// `Chain`. + pub fn get_mut(&mut self) -> (&mut T, &mut U) { + (&mut self.first, &mut self.second) + } + + /// Gets pinned mutable references to the underlying readers in this `Chain`. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying readers as doing so may corrupt the internal state of this + /// `Chain`. + pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) { + let me = self.project(); + (me.first, me.second) + } + + /// Consumes the `Chain`, returning the wrapped readers. + pub fn into_inner(self) -> (T, U) { + (self.first, self.second) + } +} + +impl<T, U> fmt::Debug for Chain<T, U> +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Chain") + .field("t", &self.first) + .field("u", &self.second) + .finish() + } +} + +impl<T, U> AsyncRead for Chain<T, U> +where + T: AsyncRead, + U: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let me = self.project(); + + if !*me.done_first { + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); + } + } + me.second.poll_read(cx, buf) + } +} + +impl<T, U> AsyncBufRead for Chain<T, U> +where + T: AsyncBufRead, + U: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + let me = self.project(); + + if !*me.done_first { + match ready!(me.first.poll_fill_buf(cx)?) { + buf if buf.is_empty() => { + *me.done_first = true; + } + buf => return Poll::Ready(Ok(buf)), + } + } + me.second.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + if !*me.done_first { + me.first.consume(amt) + } else { + me.second.consume(amt) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Chain<(), ()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/copy.rs b/third_party/rust/tokio/src/io/util/copy.rs new file mode 100644 index 0000000000..d0ab7cb140 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/copy.rs @@ -0,0 +1,175 @@ +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} + +impl CopyBuffer { + pub(super) fn new() -> Self { + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(), + } + } + + pub(super) fn poll_copy<R, W>( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll<io::Result<u64>> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + + match reader.as_mut().poll_read(cx, &mut buf) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + + let n = buf.filled().len(); + if n == 0 { + self.read_done = true; + } else { + self.pos = 0; + self.cap = n; + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let me = &mut *self; + let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. + if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} + +/// A future that asynchronously copies the entire contents of a reader into a +/// writer. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +struct Copy<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + buf: CopyBuffer, +} + +cfg_io_util! { + /// Asynchronously copies the entire contents of a reader into a writer. + /// + /// This function returns a future that will continuously read data from + /// `reader` and then write it into `writer` in a streaming fashion until + /// `reader` returns EOF. + /// + /// On success, the total number of bytes that were copied from `reader` to + /// `writer` is returned. + /// + /// This is an asynchronous version of [`std::io::copy`][std]. + /// + /// [std]: std::io::copy + /// + /// # Errors + /// + /// The returned future will return an error immediately if any call to + /// `poll_read` or `poll_write` returns an error. + /// + /// # Examples + /// + /// ``` + /// use tokio::io; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut reader: &[u8] = b"hello"; + /// let mut writer: Vec<u8> = vec![]; + /// + /// io::copy(&mut reader, &mut writer).await?; + /// + /// assert_eq!(&b"hello"[..], &writer[..]); + /// # Ok(()) + /// # } + /// ``` + pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> + where + R: AsyncRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, + { + Copy { + reader, + writer, + buf: CopyBuffer::new() + }.await + } +} + +impl<R, W> Future for Copy<'_, R, W> +where + R: AsyncRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<u64>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let me = &mut *self; + + me.buf + .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer)) + } +} diff --git a/third_party/rust/tokio/src/io/util/copy_bidirectional.rs b/third_party/rust/tokio/src/io/util/copy_bidirectional.rs new file mode 100644 index 0000000000..c93060b361 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/copy_bidirectional.rs @@ -0,0 +1,120 @@ +use super::copy::CopyBuffer; + +use crate::io::{AsyncRead, AsyncWrite}; + +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> { + a: &'a mut A, + b: &'a mut B, + a_to_b: TransferState, + b_to_a: TransferState, +} + +fn transfer_one_direction<A, B>( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll<io::Result<u64>> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +impl<'a, A, B> Future for CopyBidirectional<'a, A, B> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<(u64, u64)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // Unpack self into mut refs to each field to avoid borrow check issues. + let CopyBidirectional { + a, + b, + a_to_b, + b_to_a, + } = &mut *self; + + let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?; + let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?; + + // It is not a problem if ready! returns early because transfer_one_direction for the + // other direction will keep returning TransferState::Done(count) in future calls to poll + let a_to_b = ready!(a_to_b); + let b_to_a = ready!(b_to_a); + + Poll::Ready(Ok((a_to_b, b_to_a))) + } +} + +/// Copies data in both directions between `a` and `b`. +/// +/// This function returns a future that will read from both streams, +/// writing any data read to the opposing stream. +/// This happens in both directions concurrently. +/// +/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on +/// the other, and reading from that stream will stop. Copying of data in +/// the other direction will continue. +/// +/// The future will complete successfully once both directions of communication has been shut down. +/// A direction is shut down when the reader reports EOF, +/// at which point [`shutdown()`] is called on the corresponding writer. When finished, +/// it will return a tuple of the number of bytes copied from a to b +/// and the number of bytes copied from b to a, in that order. +/// +/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown +/// +/// # Errors +/// +/// The future will immediately return an error if any IO operation on `a` +/// or `b` returns an error. Some data read from either stream may be lost (not +/// written to the other stream) in this case. +/// +/// # Return value +/// +/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`. +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] +pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + CopyBidirectional { + a, + b, + a_to_b: TransferState::Running(CopyBuffer::new()), + b_to_a: TransferState::Running(CopyBuffer::new()), + } + .await +} diff --git a/third_party/rust/tokio/src/io/util/copy_buf.rs b/third_party/rust/tokio/src/io/util/copy_buf.rs new file mode 100644 index 0000000000..6831580b40 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/copy_buf.rs @@ -0,0 +1,102 @@ +use crate::io::{AsyncBufRead, AsyncWrite}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// A future that asynchronously copies the entire contents of a reader into a + /// writer. + /// + /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please + /// see the documentation of `copy_buf()` for more details. + /// + /// [copy_buf]: copy_buf() + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + struct CopyBuf<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + amt: u64, + } + + /// Asynchronously copies the entire contents of a reader into a writer. + /// + /// This function returns a future that will continuously read data from + /// `reader` and then write it into `writer` in a streaming fashion until + /// `reader` returns EOF. + /// + /// On success, the total number of bytes that were copied from `reader` to + /// `writer` is returned. + /// + /// + /// # Errors + /// + /// The returned future will finish with an error will return an error + /// immediately if any call to `poll_fill_buf` or `poll_write` returns an + /// error. + /// + /// # Examples + /// + /// ``` + /// use tokio::io; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut reader: &[u8] = b"hello"; + /// let mut writer: Vec<u8> = vec![]; + /// + /// io::copy_buf(&mut reader, &mut writer).await?; + /// + /// assert_eq!(b"hello", &writer[..]); + /// # Ok(()) + /// # } + /// ``` + pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> + where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, + { + CopyBuf { + reader, + writer, + amt: 0, + }.await + } +} + +impl<R, W> Future for CopyBuf<'_, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<u64>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let me = &mut *self; + let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?; + if buffer.is_empty() { + ready!(Pin::new(&mut self.writer).poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + + let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?; + if i == 0 { + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + self.amt += i as u64; + Pin::new(&mut *self.reader).consume(i); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + use std::marker::PhantomPinned; + crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/empty.rs b/third_party/rust/tokio/src/io/util/empty.rs new file mode 100644 index 0000000000..77db60e40b --- /dev/null +++ b/third_party/rust/tokio/src/io/util/empty.rs @@ -0,0 +1,100 @@ +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// An async reader which is always at EOF. + /// + /// This struct is generally created by calling [`empty`]. Please see + /// the documentation of [`empty()`][`empty`] for more details. + /// + /// This is an asynchronous version of [`std::io::empty`][std]. + /// + /// [`empty`]: fn@empty + /// [std]: std::io::empty + pub struct Empty { + _p: (), + } + + /// Creates a new empty async reader. + /// + /// All reads from the returned reader will return `Poll::Ready(Ok(0))`. + /// + /// This is an asynchronous version of [`std::io::empty`][std]. + /// + /// [std]: std::io::empty + /// + /// # Examples + /// + /// A slightly sad example of not reading anything into a buffer: + /// + /// ``` + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut buffer = String::new(); + /// io::empty().read_to_string(&mut buffer).await.unwrap(); + /// assert!(buffer.is_empty()); + /// } + /// ``` + pub fn empty() -> Empty { + Empty { _p: () } + } +} + +impl AsyncRead for Empty { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + ready!(poll_proceed_and_make_progress(cx)); + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for Empty { + #[inline] + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + ready!(poll_proceed_and_make_progress(cx)); + Poll::Ready(Ok(&[])) + } + + #[inline] + fn consume(self: Pin<&mut Self>, _: usize) {} +} + +impl fmt::Debug for Empty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Empty { .. }") + } +} + +cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> { + let coop = ready!(crate::coop::poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + } +} + +cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Empty>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/fill_buf.rs b/third_party/rust/tokio/src/io/util/fill_buf.rs new file mode 100644 index 0000000000..bb07c766e2 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/fill_buf.rs @@ -0,0 +1,59 @@ +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`fill_buf`](crate::io::AsyncBufReadExt::fill_buf) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct FillBuf<'a, R: ?Sized> { + reader: Option<&'a mut R>, + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn fill_buf<R>(reader: &mut R) -> FillBuf<'_, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + FillBuf { + reader: Some(reader), + _pin: PhantomPinned, + } +} + +impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for FillBuf<'a, R> { + type Output = io::Result<&'a [u8]>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + let reader = me.reader.take().expect("Polled after completion."); + match Pin::new(&mut *reader).poll_fill_buf(cx) { + Poll::Ready(Ok(slice)) => unsafe { + // Safety: This is necessary only due to a limitation in the + // borrow checker. Once Rust starts using the polonius borrow + // checker, this can be simplified. + // + // The safety of this transmute relies on the fact that the + // value of `reader` is `None` when we return in this branch. + // Otherwise the caller could poll us again after + // completion, and access the mutable reference while the + // returned immutable reference still exists. + let slice = std::mem::transmute::<&[u8], &'a [u8]>(slice); + Poll::Ready(Ok(slice)) + }, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + *me.reader = Some(reader); + Poll::Pending + } + } + } +} diff --git a/third_party/rust/tokio/src/io/util/flush.rs b/third_party/rust/tokio/src/io/util/flush.rs new file mode 100644 index 0000000000..88d60b868d --- /dev/null +++ b/third_party/rust/tokio/src/io/util/flush.rs @@ -0,0 +1,46 @@ +use crate::io::AsyncWrite; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future used to fully flush an I/O object. + /// + /// Created by the [`AsyncWriteExt::flush`][flush] function. + /// [flush]: crate::io::AsyncWriteExt::flush + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Flush<'a, A: ?Sized> { + a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +/// Creates a future which will entirely flush an I/O object. +pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A> +where + A: AsyncWrite + Unpin + ?Sized, +{ + Flush { + a, + _pin: PhantomPinned, + } +} + +impl<A> Future for Flush<'_, A> +where + A: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(&mut *me.a).poll_flush(cx) + } +} diff --git a/third_party/rust/tokio/src/io/util/lines.rs b/third_party/rust/tokio/src/io/util/lines.rs new file mode 100644 index 0000000000..717f633f95 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/lines.rs @@ -0,0 +1,145 @@ +use crate::io::util::read_line::read_line_internal; +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::io; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Reads lines from an [`AsyncBufRead`]. + /// + /// A `Lines` can be turned into a `Stream` with [`LinesStream`]. + /// + /// This type is usually created using the [`lines`] method. + /// + /// [`AsyncBufRead`]: crate::io::AsyncBufRead + /// [`LinesStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.LinesStream.html + /// [`lines`]: crate::io::AsyncBufReadExt::lines + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct Lines<R> { + #[pin] + reader: R, + buf: String, + bytes: Vec<u8>, + read: usize, + } +} + +pub(crate) fn lines<R>(reader: R) -> Lines<R> +where + R: AsyncBufRead, +{ + Lines { + reader, + buf: String::new(), + bytes: Vec::new(), + read: 0, + } +} + +impl<R> Lines<R> +where + R: AsyncBufRead + Unpin, +{ + /// Returns the next line in the stream. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncBufRead; + /// use tokio::io::AsyncBufReadExt; + /// + /// # async fn dox(my_buf_read: impl AsyncBufRead + Unpin) -> std::io::Result<()> { + /// let mut lines = my_buf_read.lines(); + /// + /// while let Some(line) = lines.next_line().await? { + /// println!("length = {}", line.len()) + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn next_line(&mut self) -> io::Result<Option<String>> { + use crate::future::poll_fn; + + poll_fn(|cx| Pin::new(&mut *self).poll_next_line(cx)).await + } + + /// Obtains a mutable reference to the underlying reader. + pub fn get_mut(&mut self) -> &mut R { + &mut self.reader + } + + /// Obtains a reference to the underlying reader. + pub fn get_ref(&mut self) -> &R { + &self.reader + } + + /// Unwraps this `Lines<R>`, returning the underlying reader. + /// + /// Note that any leftover data in the internal buffer is lost. + /// Therefore, a following read from the underlying reader may lead to data loss. + pub fn into_inner(self) -> R { + self.reader + } +} + +impl<R> Lines<R> +where + R: AsyncBufRead, +{ + /// Polls for the next line in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next line is not yet available. + /// * `Poll::Ready(Ok(Some(line)))` if the next line is available. + /// * `Poll::Ready(Ok(None))` if there are no more lines in this stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the next line. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when more bytes become + /// available on the underlying IO resource. Note that on multiple calls to + /// `poll_next_line`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + pub fn poll_next_line( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<Option<String>>> { + let me = self.project(); + + let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?; + debug_assert_eq!(*me.read, 0); + + if n == 0 && me.buf.is_empty() { + return Poll::Ready(Ok(None)); + } + + if me.buf.ends_with('\n') { + me.buf.pop(); + + if me.buf.ends_with('\r') { + me.buf.pop(); + } + } + + Poll::Ready(Ok(Some(mem::take(me.buf)))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Lines<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/mem.rs b/third_party/rust/tokio/src/io/util/mem.rs new file mode 100644 index 0000000000..4019db56ff --- /dev/null +++ b/third_party/rust/tokio/src/io/util/mem.rs @@ -0,0 +1,295 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Closing a `DuplexStream` +/// +/// If one end of the `DuplexStream` channel is dropped, any pending reads on +/// the other side will continue to read data until the buffer is drained, then +/// they will signal EOF by returning 0 bytes. Any writes to the other side, +/// including pending ones (that are waiting for free space in the buffer) will +/// return `Err(BrokenPipe)` immediately. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] +pub struct DuplexStream { + read: Arc<Mutex<Pipe>>, + write: Arc<Mutex<Pipe>>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option<Waker>, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option<Waker>, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.read.lock()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.write.lock()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().close_write(); + self.read.lock().close_read(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close_write(&mut self) { + self.is_closed = true; + // needs to notify any readers that no more data will come + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } + + fn close_read(&mut self) { + self.is_closed = true; + // needs to notify any writers that they have to abort + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + + fn poll_read_internal( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.remaining()); + buf.put_slice(&self.buffer[..max]); + self.buffer.advance(max); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(())) + } else if self.is_closed { + Poll::Ready(Ok(())) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + fn poll_write_internal( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } +} + +impl AsyncRead for Pipe { + cfg_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let coop = ready!(crate::coop::poll_proceed(cx)); + + let ret = self.poll_read_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + self.poll_read_internal(cx, buf) + } + } +} + +impl AsyncWrite for Pipe { + cfg_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + let coop = ready!(crate::coop::poll_proceed(cx)); + + let ret = self.poll_write_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + self.poll_write_internal(cx, buf) + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + self.close_write(); + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio/src/io/util/mod.rs b/third_party/rust/tokio/src/io/util/mod.rs new file mode 100644 index 0000000000..21199d0be8 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/mod.rs @@ -0,0 +1,97 @@ +#![allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 + +cfg_io_util! { + mod async_buf_read_ext; + pub use async_buf_read_ext::AsyncBufReadExt; + + mod async_read_ext; + pub use async_read_ext::AsyncReadExt; + + mod async_seek_ext; + pub use async_seek_ext::AsyncSeekExt; + + mod async_write_ext; + pub use async_write_ext::AsyncWriteExt; + + mod buf_reader; + pub use buf_reader::BufReader; + + mod buf_stream; + pub use buf_stream::BufStream; + + mod buf_writer; + pub use buf_writer::BufWriter; + + mod chain; + + mod copy; + pub use copy::copy; + + mod copy_bidirectional; + pub use copy_bidirectional::copy_bidirectional; + + mod copy_buf; + pub use copy_buf::copy_buf; + + mod empty; + pub use empty::{empty, Empty}; + + mod flush; + + mod lines; + pub use lines::Lines; + + mod mem; + pub use mem::{duplex, DuplexStream}; + + mod read; + mod read_buf; + mod read_exact; + mod read_int; + mod read_line; + mod fill_buf; + + mod read_to_end; + mod vec_with_initialized; + cfg_process! { + pub(crate) use read_to_end::read_to_end; + } + + mod read_to_string; + mod read_until; + + mod repeat; + pub use repeat::{repeat, Repeat}; + + mod shutdown; + + mod sink; + pub use sink::{sink, Sink}; + + mod split; + pub use split::Split; + + mod take; + pub use take::Take; + + mod write; + mod write_vectored; + mod write_all; + mod write_buf; + mod write_all_buf; + mod write_int; + + + // used by `BufReader` and `BufWriter` + // https://github.com/rust-lang/rust/blob/master/library/std/src/sys_common/io.rs#L1 + const DEFAULT_BUF_SIZE: usize = 8 * 1024; +} + +cfg_not_io_util! { + cfg_process! { + mod vec_with_initialized; + mod read_to_end; + // Used by process + pub(crate) use read_to_end::read_to_end; + } +} diff --git a/third_party/rust/tokio/src/io/util/read.rs b/third_party/rust/tokio/src/io/util/read.rs new file mode 100644 index 0000000000..edc9d5a9e6 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read.rs @@ -0,0 +1,55 @@ +use crate::io::{AsyncRead, ReadBuf}; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::marker::Unpin; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Tries to read some bytes directly into the given `buf` in asynchronous +/// manner, returning a future type. +/// +/// The returned future will resolve to both the I/O stream and the buffer +/// as well as the number of bytes read once the read operation is completed. +pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R> +where + R: AsyncRead + Unpin + ?Sized, +{ + Read { + reader, + buf, + _pin: PhantomPinned, + } +} + +pin_project! { + /// A future which can be used to easily read available number of bytes to fill + /// a buffer. + /// + /// Created by the [`read`] function. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Read<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +impl<R> Future for Read<'_, R> +where + R: AsyncRead + Unpin + ?Sized, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + let mut buf = ReadBuf::new(*me.buf); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?; + Poll::Ready(Ok(buf.filled().len())) + } +} diff --git a/third_party/rust/tokio/src/io/util/read_buf.rs b/third_party/rust/tokio/src/io/util/read_buf.rs new file mode 100644 index 0000000000..8ec57c0d6f --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_buf.rs @@ -0,0 +1,72 @@ +use crate::io::AsyncRead; + +use bytes::BufMut; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) fn read_buf<'a, R, B>(reader: &'a mut R, buf: &'a mut B) -> ReadBuf<'a, R, B> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + ReadBuf { + reader, + buf, + _pin: PhantomPinned, + } +} + +pin_project! { + /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadBuf<'a, R, B> { + reader: &'a mut R, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +impl<R, B> Future for ReadBuf<'_, R, B> +where + R: AsyncRead + Unpin, + B: BufMut, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + use crate::io::ReadBuf; + use std::mem::MaybeUninit; + + let me = self.project(); + + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = me.buf.chunk_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) + } +} diff --git a/third_party/rust/tokio/src/io/util/read_exact.rs b/third_party/rust/tokio/src/io/util/read_exact.rs new file mode 100644 index 0000000000..dbdd58bae9 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_exact.rs @@ -0,0 +1,69 @@ +use crate::io::{AsyncRead, ReadBuf}; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::marker::Unpin; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A future which can be used to easily read exactly enough bytes to fill +/// a buffer. +/// +/// Created by the [`AsyncReadExt::read_exact`][read_exact]. +/// [read_exact]: [crate::io::AsyncReadExt::read_exact] +pub(crate) fn read_exact<'a, A>(reader: &'a mut A, buf: &'a mut [u8]) -> ReadExact<'a, A> +where + A: AsyncRead + Unpin + ?Sized, +{ + ReadExact { + reader, + buf: ReadBuf::new(buf), + _pin: PhantomPinned, + } +} + +pin_project! { + /// Creates a future which will read exactly enough bytes to fill `buf`, + /// returning an error if EOF is hit sooner. + /// + /// On success the number of bytes is returned + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadExact<'a, A: ?Sized> { + reader: &'a mut A, + buf: ReadBuf<'a>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +fn eof() -> io::Error { + io::Error::new(io::ErrorKind::UnexpectedEof, "early eof") +} + +impl<A> Future for ReadExact<'_, A> +where + A: AsyncRead + Unpin + ?Sized, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + + loop { + // if our buffer is empty, then we need to read some data to continue. + let rem = me.buf.remaining(); + if rem != 0 { + ready!(Pin::new(&mut *me.reader).poll_read(cx, me.buf))?; + if me.buf.remaining() == rem { + return Err(eof()).into(); + } + } else { + return Poll::Ready(Ok(me.buf.capacity())); + } + } + } +} diff --git a/third_party/rust/tokio/src/io/util/read_int.rs b/third_party/rust/tokio/src/io/util/read_int.rs new file mode 100644 index 0000000000..164dcf5963 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_int.rs @@ -0,0 +1,159 @@ +use crate::io::{AsyncRead, ReadBuf}; + +use bytes::Buf; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::io::ErrorKind::UnexpectedEof; +use std::marker::PhantomPinned; +use std::mem::size_of; +use std::pin::Pin; +use std::task::{Context, Poll}; + +macro_rules! reader { + ($name:ident, $ty:ty, $reader:ident) => { + reader!($name, $ty, $reader, size_of::<$ty>()); + }; + ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => { + pin_project! { + #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct $name<R> { + #[pin] + src: R, + buf: [u8; $bytes], + read: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } + } + + impl<R> $name<R> { + pub(crate) fn new(src: R) -> Self { + $name { + src, + buf: [0; $bytes], + read: 0, + _pin: PhantomPinned, + } + } + } + + impl<R> Future for $name<R> + where + R: AsyncRead, + { + type Output = io::Result<$ty>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + + if *me.read == $bytes as u8 { + return Poll::Ready(Ok(Buf::$reader(&mut &me.buf[..]))); + } + + while *me.read < $bytes as u8 { + let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]); + + *me.read += match me.src.as_mut().poll_read(cx, &mut buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Ready(Ok(())) => { + let n = buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + n as u8 + } + }; + } + + let num = Buf::$reader(&mut &me.buf[..]); + + Poll::Ready(Ok(num)) + } + } + }; +} + +macro_rules! reader8 { + ($name:ident, $ty:ty) => { + pin_project! { + /// Future returned from `read_u8` + #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct $name<R> { + #[pin] + reader: R, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } + } + + impl<R> $name<R> { + pub(crate) fn new(reader: R) -> $name<R> { + $name { + reader, + _pin: PhantomPinned, + } + } + } + + impl<R> Future for $name<R> + where + R: AsyncRead, + { + type Output = io::Result<$ty>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + let mut buf = [0; 1]; + let mut buf = ReadBuf::new(&mut buf); + match me.reader.poll_read(cx, &mut buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(())) => { + if buf.filled().len() == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + Poll::Ready(Ok(buf.filled()[0] as $ty)) + } + } + } + } + }; +} + +reader8!(ReadU8, u8); +reader8!(ReadI8, i8); + +reader!(ReadU16, u16, get_u16); +reader!(ReadU32, u32, get_u32); +reader!(ReadU64, u64, get_u64); +reader!(ReadU128, u128, get_u128); + +reader!(ReadI16, i16, get_i16); +reader!(ReadI32, i32, get_i32); +reader!(ReadI64, i64, get_i64); +reader!(ReadI128, i128, get_i128); + +reader!(ReadF32, f32, get_f32); +reader!(ReadF64, f64, get_f64); + +reader!(ReadU16Le, u16, get_u16_le); +reader!(ReadU32Le, u32, get_u32_le); +reader!(ReadU64Le, u64, get_u64_le); +reader!(ReadU128Le, u128, get_u128_le); + +reader!(ReadI16Le, i16, get_i16_le); +reader!(ReadI32Le, i32, get_i32_le); +reader!(ReadI64Le, i64, get_i64_le); +reader!(ReadI128Le, i128, get_i128_le); + +reader!(ReadF32Le, f32, get_f32_le); +reader!(ReadF64Le, f64, get_f64_le); diff --git a/third_party/rust/tokio/src/io/util/read_line.rs b/third_party/rust/tokio/src/io/util/read_line.rs new file mode 100644 index 0000000000..e641f51532 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_line.rs @@ -0,0 +1,119 @@ +use crate::io::util::read_until::read_until_internal; +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem; +use std::pin::Pin; +use std::string::FromUtf8Error; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadLine<'a, R: ?Sized> { + reader: &'a mut R, + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + // The actual allocation of the string is moved into this vector instead. + buf: Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + ReadLine { + reader, + buf: mem::take(string).into_bytes(), + output: string, + read: 0, + _pin: PhantomPinned, + } +} + +fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_read: usize) { + let original_len = vector.len() - num_bytes_read; + vector.truncate(original_len); + *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); +} + +/// This handles the various failure cases and puts the string back into `output`. +/// +/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line` +/// disagree on what should happen when an IO error occurs. +pub(super) fn finish_string_read( + io_res: io::Result<usize>, + utf8_res: Result<String, FromUtf8Error>, + read: usize, + output: &mut String, + truncate_on_io_error: bool, +) -> Poll<io::Result<usize>> { + match (io_res, utf8_res) { + (Ok(num_bytes), Ok(string)) => { + debug_assert_eq!(read, 0); + *output = string; + Poll::Ready(Ok(num_bytes)) + } + (Err(io_err), Ok(string)) => { + *output = string; + if truncate_on_io_error { + let original_len = output.len() - read; + output.truncate(original_len); + } + Poll::Ready(Err(io_err)) + } + (Ok(num_bytes), Err(utf8_err)) => { + debug_assert_eq!(read, 0); + put_back_original_data(output, utf8_err.into_bytes(), num_bytes); + + Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "stream did not contain valid UTF-8", + ))) + } + (Err(io_err), Err(utf8_err)) => { + put_back_original_data(output, utf8_err.into_bytes(), read); + + Poll::Ready(Err(io_err)) + } + } +} + +pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( + reader: Pin<&mut R>, + cx: &mut Context<'_>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, +) -> Poll<io::Result<usize>> { + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::take(buf)); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, false) +} + +impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read) + } +} diff --git a/third_party/rust/tokio/src/io/util/read_to_end.rs b/third_party/rust/tokio/src/io/util/read_to_end.rs new file mode 100644 index 0000000000..f4a564d7dd --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_to_end.rs @@ -0,0 +1,112 @@ +use crate::io::util::vec_with_initialized::{into_read_buf_parts, VecU8, VecWithInitialized}; +use crate::io::AsyncRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadToEnd<'a, R: ?Sized> { + reader: &'a mut R, + buf: VecWithInitialized<&'a mut Vec<u8>>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R> +where + R: AsyncRead + Unpin + ?Sized, +{ + ReadToEnd { + reader, + buf: VecWithInitialized::new(buffer), + read: 0, + _pin: PhantomPinned, + } +} + +pub(super) fn read_to_end_internal<V: VecU8, R: AsyncRead + ?Sized>( + buf: &mut VecWithInitialized<V>, + mut reader: Pin<&mut R>, + num_read: &mut usize, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + loop { + let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx)); + match ret { + Err(err) => return Poll::Ready(Err(err)), + Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))), + Ok(num) => { + *num_read += num; + } + } + } +} + +/// Tries to read from the provided AsyncRead. +/// +/// The length of the buffer is increased by the number of bytes read. +fn poll_read_to_end<V: VecU8, R: AsyncRead + ?Sized>( + buf: &mut VecWithInitialized<V>, + read: Pin<&mut R>, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + // This uses an adaptive system to extend the vector when it fills. We want to + // avoid paying to allocate and zero a huge chunk of memory if the reader only + // has 4 bytes while still making large reads if the reader does have a ton + // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every + // time is 4,500 times (!) slower than this if the reader has a very small + // amount of data to return. + buf.reserve(32); + + // Get a ReadBuf into the vector. + let mut read_buf = buf.get_read_buf(); + + let filled_before = read_buf.filled().len(); + let poll_result = read.poll_read(cx, &mut read_buf); + let filled_after = read_buf.filled().len(); + let n = filled_after - filled_before; + + // Update the length of the vector using the result of poll_read. + let read_buf_parts = into_read_buf_parts(read_buf); + buf.apply_read_buf(read_buf_parts); + + match poll_result { + Poll::Pending => { + // In this case, nothing should have been read. However we still + // update the vector in case the poll_read call initialized parts of + // the vector's unused capacity. + debug_assert_eq!(filled_before, filled_after); + Poll::Pending + } + Poll::Ready(Err(err)) => { + debug_assert_eq!(filled_before, filled_after); + Poll::Ready(Err(err)) + } + Poll::Ready(Ok(())) => Poll::Ready(Ok(n)), + } +} + +impl<A> Future for ReadToEnd<'_, A> +where + A: AsyncRead + ?Sized + Unpin, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx) + } +} diff --git a/third_party/rust/tokio/src/io/util/read_to_string.rs b/third_party/rust/tokio/src/io/util/read_to_string.rs new file mode 100644 index 0000000000..b3d82a26ba --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_to_string.rs @@ -0,0 +1,78 @@ +use crate::io::util::read_line::finish_string_read; +use crate::io::util::read_to_end::read_to_end_internal; +use crate::io::util::vec_with_initialized::VecWithInitialized; +use crate::io::AsyncRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, mem}; + +pin_project! { + /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadToString<'a, R: ?Sized> { + reader: &'a mut R, + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + // The actual allocation of the string is moved into this vector instead. + buf: VecWithInitialized<Vec<u8>>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn read_to_string<'a, R>( + reader: &'a mut R, + string: &'a mut String, +) -> ReadToString<'a, R> +where + R: AsyncRead + ?Sized + Unpin, +{ + let buf = mem::take(string).into_bytes(); + ReadToString { + reader, + buf: VecWithInitialized::new(buf), + output: string, + read: 0, + _pin: PhantomPinned, + } +} + +fn read_to_string_internal<R: AsyncRead + ?Sized>( + reader: Pin<&mut R>, + output: &mut String, + buf: &mut VecWithInitialized<Vec<u8>>, + read: &mut usize, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + let io_res = ready!(read_to_end_internal(buf, reader, read, cx)); + let utf8_res = String::from_utf8(buf.take()); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, true) +} + +impl<A> Future for ReadToString<'_, A> +where + A: AsyncRead + ?Sized + Unpin, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) + } +} diff --git a/third_party/rust/tokio/src/io/util/read_until.rs b/third_party/rust/tokio/src/io/util/read_until.rs new file mode 100644 index 0000000000..90a0e8a18d --- /dev/null +++ b/third_party/rust/tokio/src/io/util/read_until.rs @@ -0,0 +1,79 @@ +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. + /// The delimiter is included in the resulting vector. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadUntil<'a, R: ?Sized> { + reader: &'a mut R, + delimiter: u8, + buf: &'a mut Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn read_until<'a, R>( + reader: &'a mut R, + delimiter: u8, + buf: &'a mut Vec<u8>, +) -> ReadUntil<'a, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + ReadUntil { + reader, + delimiter, + buf, + read: 0, + _pin: PhantomPinned, + } +} + +pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( + mut reader: Pin<&mut R>, + cx: &mut Context<'_>, + delimiter: u8, + buf: &mut Vec<u8>, + read: &mut usize, +) -> Poll<io::Result<usize>> { + loop { + let (done, used) = { + let available = ready!(reader.as_mut().poll_fill_buf(cx))?; + if let Some(i) = memchr::memchr(delimiter, available) { + buf.extend_from_slice(&available[..=i]); + (true, i + 1) + } else { + buf.extend_from_slice(available); + (false, available.len()) + } + }; + reader.as_mut().consume(used); + *read += used; + if done || used == 0 { + return Poll::Ready(Ok(mem::replace(read, 0))); + } + } +} + +impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + read_until_internal(Pin::new(*me.reader), cx, *me.delimiter, me.buf, me.read) + } +} diff --git a/third_party/rust/tokio/src/io/util/repeat.rs b/third_party/rust/tokio/src/io/util/repeat.rs new file mode 100644 index 0000000000..1142765df5 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/repeat.rs @@ -0,0 +1,72 @@ +use crate::io::{AsyncRead, ReadBuf}; + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// An async reader which yields one byte over and over and over and over and + /// over and... + /// + /// This struct is generally created by calling [`repeat`][repeat]. Please + /// see the documentation of `repeat()` for more details. + /// + /// This is an asynchronous version of [`std::io::Repeat`][std]. + /// + /// [repeat]: fn@repeat + /// [std]: std::io::Repeat + #[derive(Debug)] + pub struct Repeat { + byte: u8, + } + + /// Creates an instance of an async reader that infinitely repeats one byte. + /// + /// All reads from this reader will succeed by filling the specified buffer with + /// the given byte. + /// + /// This is an asynchronous version of [`std::io::repeat`][std]. + /// + /// [std]: std::io::repeat + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncReadExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut buffer = [0; 3]; + /// io::repeat(0b101).read_exact(&mut buffer).await.unwrap(); + /// assert_eq!(buffer, [0b101, 0b101, 0b101]); + /// } + /// ``` + pub fn repeat(byte: u8) -> Repeat { + Repeat { byte } + } +} + +impl AsyncRead for Repeat { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // TODO: could be faster, but should we unsafe it? + while buf.remaining() != 0 { + buf.put_slice(&[self.byte]); + } + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Repeat>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/shutdown.rs b/third_party/rust/tokio/src/io/util/shutdown.rs new file mode 100644 index 0000000000..6d30b004b1 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/shutdown.rs @@ -0,0 +1,46 @@ +use crate::io::AsyncWrite; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future used to shutdown an I/O object. + /// + /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. + /// [shutdown]: crate::io::AsyncWriteExt::shutdown + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[derive(Debug)] + pub struct Shutdown<'a, A: ?Sized> { + a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +/// Creates a future which will shutdown an I/O object. +pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A> +where + A: AsyncWrite + Unpin + ?Sized, +{ + Shutdown { + a, + _pin: PhantomPinned, + } +} + +impl<A> Future for Shutdown<'_, A> +where + A: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(me.a).poll_shutdown(cx) + } +} diff --git a/third_party/rust/tokio/src/io/util/sink.rs b/third_party/rust/tokio/src/io/util/sink.rs new file mode 100644 index 0000000000..05ee773fa3 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/sink.rs @@ -0,0 +1,87 @@ +use crate::io::AsyncWrite; + +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// An async writer which will move data into the void. + /// + /// This struct is generally created by calling [`sink`][sink]. Please + /// see the documentation of `sink()` for more details. + /// + /// This is an asynchronous version of [`std::io::Sink`][std]. + /// + /// [sink]: sink() + /// [std]: std::io::Sink + pub struct Sink { + _p: (), + } + + /// Creates an instance of an async writer which will successfully consume all + /// data. + /// + /// All calls to [`poll_write`] on the returned instance will return + /// `Poll::Ready(Ok(buf.len()))` and the contents of the buffer will not be + /// inspected. + /// + /// This is an asynchronous version of [`std::io::sink`][std]. + /// + /// [`poll_write`]: crate::io::AsyncWrite::poll_write() + /// [std]: std::io::sink + /// + /// # Examples + /// + /// ``` + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let buffer = vec![1, 2, 3, 5, 8]; + /// let num_bytes = io::sink().write(&buffer).await?; + /// assert_eq!(num_bytes, 5); + /// Ok(()) + /// } + /// ``` + pub fn sink() -> Sink { + Sink { _p: () } + } +} + +impl AsyncWrite for Sink { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + Poll::Ready(Ok(buf.len())) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } +} + +impl fmt::Debug for Sink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Sink { .. }") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Sink>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/split.rs b/third_party/rust/tokio/src/io/util/split.rs new file mode 100644 index 0000000000..7489c24281 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/split.rs @@ -0,0 +1,121 @@ +use crate::io::util::read_until::read_until_internal; +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::io; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Splitter for the [`split`](crate::io::AsyncBufReadExt::split) method. + /// + /// A `Split` can be turned into a `Stream` with [`SplitStream`]. + /// + /// [`SplitStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.SplitStream.html + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct Split<R> { + #[pin] + reader: R, + buf: Vec<u8>, + delim: u8, + read: usize, + } +} + +pub(crate) fn split<R>(reader: R, delim: u8) -> Split<R> +where + R: AsyncBufRead, +{ + Split { + reader, + buf: Vec::new(), + delim, + read: 0, + } +} + +impl<R> Split<R> +where + R: AsyncBufRead + Unpin, +{ + /// Returns the next segment in the stream. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncBufRead; + /// use tokio::io::AsyncBufReadExt; + /// + /// # async fn dox(my_buf_read: impl AsyncBufRead + Unpin) -> std::io::Result<()> { + /// let mut segments = my_buf_read.split(b'f'); + /// + /// while let Some(segment) = segments.next_segment().await? { + /// println!("length = {}", segment.len()) + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn next_segment(&mut self) -> io::Result<Option<Vec<u8>>> { + use crate::future::poll_fn; + + poll_fn(|cx| Pin::new(&mut *self).poll_next_segment(cx)).await + } +} + +impl<R> Split<R> +where + R: AsyncBufRead, +{ + /// Polls for the next segment in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next segment is not yet available. + /// * `Poll::Ready(Ok(Some(segment)))` if the next segment is available. + /// * `Poll::Ready(Ok(None))` if there are no more segments in this stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the + /// next segment. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when more bytes become + /// available on the underlying IO resource. + /// + /// Note that on multiple calls to `poll_next_segment`, only the `Waker` + /// from the `Context` passed to the most recent call is scheduled to + /// receive a wakeup. + pub fn poll_next_segment( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<Option<Vec<u8>>>> { + let me = self.project(); + + let n = ready!(read_until_internal( + me.reader, cx, *me.delim, me.buf, me.read, + ))?; + // read_until_internal resets me.read to zero once it finds the delimiter + debug_assert_eq!(*me.read, 0); + + if n == 0 && me.buf.is_empty() { + return Poll::Ready(Ok(None)); + } + + if me.buf.last() == Some(me.delim) { + me.buf.pop(); + } + + Poll::Ready(Ok(Some(mem::take(me.buf)))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Split<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/take.rs b/third_party/rust/tokio/src/io/util/take.rs new file mode 100644 index 0000000000..df2f61b9e6 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/take.rs @@ -0,0 +1,137 @@ +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{cmp, io}; + +pin_project! { + /// Stream for the [`take`](super::AsyncReadExt::take) method. + #[derive(Debug)] + #[must_use = "streams do nothing unless you `.await` or poll them"] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct Take<R> { + #[pin] + inner: R, + // Add '_' to avoid conflicts with `limit` method. + limit_: u64, + } +} + +pub(super) fn take<R: AsyncRead>(inner: R, limit: u64) -> Take<R> { + Take { + inner, + limit_: limit, + } +} + +impl<R: AsyncRead> Take<R> { + /// Returns the remaining number of bytes that can be + /// read before this instance will return EOF. + /// + /// # Note + /// + /// This instance may reach `EOF` after reading fewer bytes than indicated by + /// this method if the underlying [`AsyncRead`] instance reaches EOF. + pub fn limit(&self) -> u64 { + self.limit_ + } + + /// Sets the number of bytes that can be read before this instance will + /// return EOF. This is the same as constructing a new `Take` instance, so + /// the amount of bytes read and the previous limit value don't matter when + /// calling this method. + pub fn set_limit(&mut self, limit: u64) { + self.limit_ = limit + } + + /// Gets a reference to the underlying reader. + pub fn get_ref(&self) -> &R { + &self.inner + } + + /// Gets a mutable reference to the underlying reader. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying reader as doing so may corrupt the internal limit of this + /// `Take`. + pub fn get_mut(&mut self) -> &mut R { + &mut self.inner + } + + /// Gets a pinned mutable reference to the underlying reader. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying reader as doing so may corrupt the internal limit of this + /// `Take`. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().inner + } + + /// Consumes the `Take`, returning the wrapped reader. + pub fn into_inner(self) -> R { + self.inner + } +} + +impl<R: AsyncRead> AsyncRead for Take<R> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { + if self.limit_ == 0 { + return Poll::Ready(Ok(())); + } + + let me = self.project(); + let mut b = buf.take(*me.limit_ as usize); + + let buf_ptr = b.filled().as_ptr(); + ready!(me.inner.poll_read(cx, &mut b))?; + assert_eq!(b.filled().as_ptr(), buf_ptr); + + let n = b.filled().len(); + + // We need to update the original ReadBuf + unsafe { + buf.assume_init(n); + } + buf.advance(n); + *me.limit_ -= n as u64; + Poll::Ready(Ok(())) + } +} + +impl<R: AsyncBufRead> AsyncBufRead for Take<R> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + let me = self.project(); + + // Don't call into inner reader at all at EOF because it may still block + if *me.limit_ == 0 { + return Poll::Ready(Ok(&[])); + } + + let buf = ready!(me.inner.poll_fill_buf(cx)?); + let cap = cmp::min(buf.len() as u64, *me.limit_) as usize; + Poll::Ready(Ok(&buf[..cap])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + // Don't let callers reset the limit by passing an overlarge value + let amt = cmp::min(amt as u64, *me.limit_) as usize; + *me.limit_ -= amt as u64; + me.inner.consume(amt); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + crate::is_unpin::<Take<()>>(); + } +} diff --git a/third_party/rust/tokio/src/io/util/vec_with_initialized.rs b/third_party/rust/tokio/src/io/util/vec_with_initialized.rs new file mode 100644 index 0000000000..208cc939c1 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/vec_with_initialized.rs @@ -0,0 +1,132 @@ +use crate::io::ReadBuf; +use std::mem::MaybeUninit; + +mod private { + pub trait Sealed {} + + impl Sealed for Vec<u8> {} + impl Sealed for &mut Vec<u8> {} +} + +/// A sealed trait that constrains the generic type parameter in `VecWithInitialized<V>`. That struct's safety relies +/// on certain invariants upheld by `Vec<u8>`. +pub(crate) trait VecU8: AsMut<Vec<u8>> + private::Sealed {} + +impl VecU8 for Vec<u8> {} +impl VecU8 for &mut Vec<u8> {} +/// This struct wraps a `Vec<u8>` or `&mut Vec<u8>`, combining it with a +/// `num_initialized`, which keeps track of the number of initialized bytes +/// in the unused capacity. +/// +/// The purpose of this struct is to remember how many bytes were initialized +/// through a `ReadBuf` from call to call. +/// +/// This struct has the safety invariant that the first `num_initialized` of the +/// vector's allocation must be initialized at any time. +#[derive(Debug)] +pub(crate) struct VecWithInitialized<V> { + vec: V, + // The number of initialized bytes in the vector. + // Always between `vec.len()` and `vec.capacity()`. + num_initialized: usize, +} + +impl VecWithInitialized<Vec<u8>> { + #[cfg(feature = "io-util")] + pub(crate) fn take(&mut self) -> Vec<u8> { + self.num_initialized = 0; + std::mem::take(&mut self.vec) + } +} + +impl<V> VecWithInitialized<V> +where + V: VecU8, +{ + pub(crate) fn new(mut vec: V) -> Self { + // SAFETY: The safety invariants of vector guarantee that the bytes up + // to its length are initialized. + Self { + num_initialized: vec.as_mut().len(), + vec, + } + } + + pub(crate) fn reserve(&mut self, num_bytes: usize) { + let vec = self.vec.as_mut(); + if vec.capacity() - vec.len() >= num_bytes { + return; + } + // SAFETY: Setting num_initialized to `vec.len()` is correct as + // `reserve` does not change the length of the vector. + self.num_initialized = vec.len(); + vec.reserve(num_bytes); + } + + #[cfg(feature = "io-util")] + pub(crate) fn is_empty(&mut self) -> bool { + self.vec.as_mut().is_empty() + } + + pub(crate) fn get_read_buf<'a>(&'a mut self) -> ReadBuf<'a> { + let num_initialized = self.num_initialized; + + // SAFETY: Creating the slice is safe because of the safety invariants + // on Vec<u8>. The safety invariants of `ReadBuf` will further guarantee + // that no bytes in the slice are de-initialized. + let vec = self.vec.as_mut(); + let len = vec.len(); + let cap = vec.capacity(); + let ptr = vec.as_mut_ptr().cast::<MaybeUninit<u8>>(); + let slice = unsafe { std::slice::from_raw_parts_mut::<'a, MaybeUninit<u8>>(ptr, cap) }; + + // SAFETY: This is safe because the safety invariants of + // VecWithInitialized say that the first num_initialized bytes must be + // initialized. + let mut read_buf = ReadBuf::uninit(slice); + unsafe { + read_buf.assume_init(num_initialized); + } + read_buf.set_filled(len); + + read_buf + } + + pub(crate) fn apply_read_buf(&mut self, parts: ReadBufParts) { + let vec = self.vec.as_mut(); + assert_eq!(vec.as_ptr(), parts.ptr); + + // SAFETY: + // The ReadBufParts really does point inside `self.vec` due to the above + // check, and the safety invariants of `ReadBuf` guarantee that the + // first `parts.initialized` bytes of `self.vec` really have been + // initialized. Additionally, `ReadBuf` guarantees that `parts.len` is + // at most `parts.initialized`, so the first `parts.len` bytes are also + // initialized. + // + // Note that this relies on the fact that `V` is either `Vec<u8>` or + // `&mut Vec<u8>`, so the vector returned by `self.vec.as_mut()` cannot + // change from call to call. + unsafe { + self.num_initialized = parts.initialized; + vec.set_len(parts.len); + } + } +} + +pub(crate) struct ReadBufParts { + // Pointer is only used to check that the ReadBuf actually came from the + // right VecWithInitialized. + ptr: *const u8, + len: usize, + initialized: usize, +} + +// This is needed to release the borrow on `VecWithInitialized<V>`. +pub(crate) fn into_read_buf_parts(rb: ReadBuf<'_>) -> ReadBufParts { + ReadBufParts { + ptr: rb.filled().as_ptr(), + len: rb.filled().len(), + initialized: rb.initialized().len(), + } +} diff --git a/third_party/rust/tokio/src/io/util/write.rs b/third_party/rust/tokio/src/io/util/write.rs new file mode 100644 index 0000000000..92169ebc1d --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write.rs @@ -0,0 +1,46 @@ +use crate::io::AsyncWrite; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Write<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + Write { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl<W> Future for Write<'_, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + Pin::new(&mut *me.writer).poll_write(cx, me.buf) + } +} diff --git a/third_party/rust/tokio/src/io/util/write_all.rs b/third_party/rust/tokio/src/io/util/write_all.rs new file mode 100644 index 0000000000..abd3e39d31 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write_all.rs @@ -0,0 +1,55 @@ +use crate::io::AsyncWrite; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteAll<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + WriteAll { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl<W> Future for WriteAll<'_, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); + while !me.buf.is_empty() { + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; + { + let (_, rest) = mem::take(&mut *me.buf).split_at(n); + *me.buf = rest; + } + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio/src/io/util/write_all_buf.rs b/third_party/rust/tokio/src/io/util/write_all_buf.rs new file mode 100644 index 0000000000..05af7fe99b --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write_all_buf.rs @@ -0,0 +1,56 @@ +use crate::io::AsyncWrite; + +use bytes::Buf; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteAllBuf<'a, W, B> { + writer: &'a mut W, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +pub(crate) fn write_all_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteAllBuf<'a, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + WriteAllBuf { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl<W, B> Future for WriteAllBuf<'_, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); + while me.buf.has_remaining() { + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?); + me.buf.advance(n); + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + Poll::Ready(Ok(())) + } +} diff --git a/third_party/rust/tokio/src/io/util/write_buf.rs b/third_party/rust/tokio/src/io/util/write_buf.rs new file mode 100644 index 0000000000..82fd7a759f --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write_buf.rs @@ -0,0 +1,55 @@ +use crate::io::AsyncWrite; + +use bytes::Buf; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteBuf<'a, W, B> { + writer: &'a mut W, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +pub(crate) fn write_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteBuf<'a, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + WriteBuf { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl<W, B> Future for WriteBuf<'_, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + + if !me.buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.chunk()))?; + me.buf.advance(n); + Poll::Ready(Ok(n)) + } +} diff --git a/third_party/rust/tokio/src/io/util/write_int.rs b/third_party/rust/tokio/src/io/util/write_int.rs new file mode 100644 index 0000000000..63cd49126f --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write_int.rs @@ -0,0 +1,152 @@ +use crate::io::AsyncWrite; + +use bytes::BufMut; +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::mem::size_of; +use std::pin::Pin; +use std::task::{Context, Poll}; + +macro_rules! writer { + ($name:ident, $ty:ty, $writer:ident) => { + writer!($name, $ty, $writer, size_of::<$ty>()); + }; + ($name:ident, $ty:ty, $writer:ident, $bytes:expr) => { + pin_project! { + #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct $name<W> { + #[pin] + dst: W, + buf: [u8; $bytes], + written: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } + } + + impl<W> $name<W> { + pub(crate) fn new(w: W, value: $ty) -> Self { + let mut writer = Self { + buf: [0; $bytes], + written: 0, + dst: w, + _pin: PhantomPinned, + }; + BufMut::$writer(&mut &mut writer.buf[..], value); + writer + } + } + + impl<W> Future for $name<W> + where + W: AsyncWrite, + { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut me = self.project(); + + if *me.written == $bytes as u8 { + return Poll::Ready(Ok(())); + } + + while *me.written < $bytes as u8 { + *me.written += match me + .dst + .as_mut() + .poll_write(cx, &me.buf[*me.written as usize..]) + { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + Poll::Ready(Ok(n)) => n as u8, + }; + } + Poll::Ready(Ok(())) + } + } + }; +} + +macro_rules! writer8 { + ($name:ident, $ty:ty) => { + pin_project! { + #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct $name<W> { + #[pin] + dst: W, + byte: $ty, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } + } + + impl<W> $name<W> { + pub(crate) fn new(dst: W, byte: $ty) -> Self { + Self { + dst, + byte, + _pin: PhantomPinned, + } + } + } + + impl<W> Future for $name<W> + where + W: AsyncWrite, + { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + let buf = [*me.byte as u8]; + + match me.dst.poll_write(cx, &buf[..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Poll::Ready(Ok(1)) => Poll::Ready(Ok(())), + Poll::Ready(Ok(_)) => unreachable!(), + } + } + } + }; +} + +writer8!(WriteU8, u8); +writer8!(WriteI8, i8); + +writer!(WriteU16, u16, put_u16); +writer!(WriteU32, u32, put_u32); +writer!(WriteU64, u64, put_u64); +writer!(WriteU128, u128, put_u128); + +writer!(WriteI16, i16, put_i16); +writer!(WriteI32, i32, put_i32); +writer!(WriteI64, i64, put_i64); +writer!(WriteI128, i128, put_i128); + +writer!(WriteF32, f32, put_f32); +writer!(WriteF64, f64, put_f64); + +writer!(WriteU16Le, u16, put_u16_le); +writer!(WriteU32Le, u32, put_u32_le); +writer!(WriteU64Le, u64, put_u64_le); +writer!(WriteU128Le, u128, put_u128_le); + +writer!(WriteI16Le, i16, put_i16_le); +writer!(WriteI32Le, i32, put_i32_le); +writer!(WriteI64Le, i64, put_i64_le); +writer!(WriteI128Le, i128, put_i128_le); + +writer!(WriteF32Le, f32, put_f32_le); +writer!(WriteF64Le, f64, put_f64_le); diff --git a/third_party/rust/tokio/src/io/util/write_vectored.rs b/third_party/rust/tokio/src/io/util/write_vectored.rs new file mode 100644 index 0000000000..be40322943 --- /dev/null +++ b/third_party/rust/tokio/src/io/util/write_vectored.rs @@ -0,0 +1,47 @@ +use crate::io::AsyncWrite; + +use pin_project_lite::pin_project; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{future::Future, io::IoSlice}; + +pin_project! { + /// A future to write a slice of buffers to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteVectored<'a, 'b, W: ?Sized> { + writer: &'a mut W, + bufs: &'a [IoSlice<'b>], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn write_vectored<'a, 'b, W>( + writer: &'a mut W, + bufs: &'a [IoSlice<'b>], +) -> WriteVectored<'a, 'b, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + WriteVectored { + writer, + bufs, + _pin: PhantomPinned, + } +} + +impl<W> Future for WriteVectored<'_, '_, W> +where + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + Pin::new(&mut *me.writer).poll_write_vectored(cx, me.bufs) + } +} diff --git a/third_party/rust/tokio/src/lib.rs b/third_party/rust/tokio/src/lib.rs new file mode 100644 index 0000000000..35295d837a --- /dev/null +++ b/third_party/rust/tokio/src/lib.rs @@ -0,0 +1,537 @@ +#![allow( + clippy::cognitive_complexity, + clippy::large_enum_variant, + clippy::needless_doctest_main +)] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![deny(unused_must_use)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, allow(unused_attributes))] + +//! A runtime for writing reliable network applications without compromising speed. +//! +//! Tokio is an event-driven, non-blocking I/O platform for writing asynchronous +//! applications with the Rust programming language. At a high level, it +//! provides a few major components: +//! +//! * Tools for [working with asynchronous tasks][tasks], including +//! [synchronization primitives and channels][sync] and [timeouts, sleeps, and +//! intervals][time]. +//! * APIs for [performing asynchronous I/O][io], including [TCP and UDP][net] sockets, +//! [filesystem][fs] operations, and [process] and [signal] management. +//! * A [runtime] for executing asynchronous code, including a task scheduler, +//! an I/O driver backed by the operating system's event queue (epoll, kqueue, +//! IOCP, etc...), and a high performance timer. +//! +//! Guide level documentation is found on the [website]. +//! +//! [tasks]: #working-with-tasks +//! [sync]: crate::sync +//! [time]: crate::time +//! [io]: #asynchronous-io +//! [net]: crate::net +//! [fs]: crate::fs +//! [process]: crate::process +//! [signal]: crate::signal +//! [fs]: crate::fs +//! [runtime]: crate::runtime +//! [website]: https://tokio.rs/tokio/tutorial +//! +//! # A Tour of Tokio +//! +//! Tokio consists of a number of modules that provide a range of functionality +//! essential for implementing asynchronous applications in Rust. In this +//! section, we will take a brief tour of Tokio, summarizing the major APIs and +//! their uses. +//! +//! The easiest way to get started is to enable all features. Do this by +//! enabling the `full` feature flag: +//! +//! ```toml +//! tokio = { version = "1", features = ["full"] } +//! ``` +//! +//! ### Authoring applications +//! +//! Tokio is great for writing applications and most users in this case shouldn't +//! worry too much about what features they should pick. If you're unsure, we suggest +//! going with `full` to ensure that you don't run into any road blocks while you're +//! building your application. +//! +//! #### Example +//! +//! This example shows the quickest way to get started with Tokio. +//! +//! ```toml +//! tokio = { version = "1", features = ["full"] } +//! ``` +//! +//! ### Authoring libraries +//! +//! As a library author your goal should be to provide the lightest weight crate +//! that is based on Tokio. To achieve this you should ensure that you only enable +//! the features you need. This allows users to pick up your crate without having +//! to enable unnecessary features. +//! +//! #### Example +//! +//! This example shows how you may want to import features for a library that just +//! needs to `tokio::spawn` and use a `TcpStream`. +//! +//! ```toml +//! tokio = { version = "1", features = ["rt", "net"] } +//! ``` +//! +//! ## Working With Tasks +//! +//! Asynchronous programs in Rust are based around lightweight, non-blocking +//! units of execution called [_tasks_][tasks]. The [`tokio::task`] module provides +//! important tools for working with tasks: +//! +//! * The [`spawn`] function and [`JoinHandle`] type, for scheduling a new task +//! on the Tokio runtime and awaiting the output of a spawned task, respectively, +//! * Functions for [running blocking operations][blocking] in an asynchronous +//! task context. +//! +//! The [`tokio::task`] module is present only when the "rt" feature flag +//! is enabled. +//! +//! [tasks]: task/index.html#what-are-tasks +//! [`tokio::task`]: crate::task +//! [`spawn`]: crate::task::spawn() +//! [`JoinHandle`]: crate::task::JoinHandle +//! [blocking]: task/index.html#blocking-and-yielding +//! +//! The [`tokio::sync`] module contains synchronization primitives to use when +//! needing to communicate or share data. These include: +//! +//! * channels ([`oneshot`], [`mpsc`], and [`watch`]), for sending values +//! between tasks, +//! * a non-blocking [`Mutex`], for controlling access to a shared, mutable +//! value, +//! * an asynchronous [`Barrier`] type, for multiple tasks to synchronize before +//! beginning a computation. +//! +//! The `tokio::sync` module is present only when the "sync" feature flag is +//! enabled. +//! +//! [`tokio::sync`]: crate::sync +//! [`Mutex`]: crate::sync::Mutex +//! [`Barrier`]: crate::sync::Barrier +//! [`oneshot`]: crate::sync::oneshot +//! [`mpsc`]: crate::sync::mpsc +//! [`watch`]: crate::sync::watch +//! +//! The [`tokio::time`] module provides utilities for tracking time and +//! scheduling work. This includes functions for setting [timeouts][timeout] for +//! tasks, [sleeping][sleep] work to run in the future, or [repeating an operation at an +//! interval][interval]. +//! +//! In order to use `tokio::time`, the "time" feature flag must be enabled. +//! +//! [`tokio::time`]: crate::time +//! [sleep]: crate::time::sleep() +//! [interval]: crate::time::interval() +//! [timeout]: crate::time::timeout() +//! +//! Finally, Tokio provides a _runtime_ for executing asynchronous tasks. Most +//! applications can use the [`#[tokio::main]`][main] macro to run their code on the +//! Tokio runtime. However, this macro provides only basic configuration options. As +//! an alternative, the [`tokio::runtime`] module provides more powerful APIs for configuring +//! and managing runtimes. You should use that module if the `#[tokio::main]` macro doesn't +//! provide the functionality you need. +//! +//! Using the runtime requires the "rt" or "rt-multi-thread" feature flags, to +//! enable the basic [single-threaded scheduler][rt] and the [thread-pool +//! scheduler][rt-multi-thread], respectively. See the [`runtime` module +//! documentation][rt-features] for details. In addition, the "macros" feature +//! flag enables the `#[tokio::main]` and `#[tokio::test]` attributes. +//! +//! [main]: attr.main.html +//! [`tokio::runtime`]: crate::runtime +//! [`Builder`]: crate::runtime::Builder +//! [`Runtime`]: crate::runtime::Runtime +//! [rt]: runtime/index.html#current-thread-scheduler +//! [rt-multi-thread]: runtime/index.html#multi-thread-scheduler +//! [rt-features]: runtime/index.html#runtime-scheduler +//! +//! ## CPU-bound tasks and blocking code +//! +//! Tokio is able to concurrently run many tasks on a few threads by repeatedly +//! swapping the currently running task on each thread. However, this kind of +//! swapping can only happen at `.await` points, so code that spends a long time +//! without reaching an `.await` will prevent other tasks from running. To +//! combat this, Tokio provides two kinds of threads: Core threads and blocking +//! threads. The core threads are where all asynchronous code runs, and Tokio +//! will by default spawn one for each CPU core. The blocking threads are +//! spawned on demand, can be used to run blocking code that would otherwise +//! block other tasks from running and are kept alive when not used for a certain +//! amount of time which can be configured with [`thread_keep_alive`]. +//! Since it is not possible for Tokio to swap out blocking tasks, like it +//! can do with asynchronous code, the upper limit on the number of blocking +//! threads is very large. These limits can be configured on the [`Builder`]. +//! +//! To spawn a blocking task, you should use the [`spawn_blocking`] function. +//! +//! [`Builder`]: crate::runtime::Builder +//! [`spawn_blocking`]: crate::task::spawn_blocking() +//! [`thread_keep_alive`]: crate::runtime::Builder::thread_keep_alive() +//! +//! ``` +//! #[tokio::main] +//! async fn main() { +//! // This is running on a core thread. +//! +//! let blocking_task = tokio::task::spawn_blocking(|| { +//! // This is running on a blocking thread. +//! // Blocking here is ok. +//! }); +//! +//! // We can wait for the blocking task like this: +//! // If the blocking task panics, the unwrap below will propagate the +//! // panic. +//! blocking_task.await.unwrap(); +//! } +//! ``` +//! +//! If your code is CPU-bound and you wish to limit the number of threads used +//! to run it, you should use a separate thread pool dedicated to CPU bound tasks. +//! For example, you could consider using the [rayon] library for CPU-bound +//! tasks. It is also possible to create an extra Tokio runtime dedicated to +//! CPU-bound tasks, but if you do this, you should be careful that the extra +//! runtime runs _only_ CPU-bound tasks, as IO-bound tasks on that runtime +//! will behave poorly. +//! +//! Hint: If using rayon, you can use a [`oneshot`] channel to send the result back +//! to Tokio when the rayon task finishes. +//! +//! [rayon]: https://docs.rs/rayon +//! [`oneshot`]: crate::sync::oneshot +//! +//! ## Asynchronous IO +//! +//! As well as scheduling and running tasks, Tokio provides everything you need +//! to perform input and output asynchronously. +//! +//! The [`tokio::io`] module provides Tokio's asynchronous core I/O primitives, +//! the [`AsyncRead`], [`AsyncWrite`], and [`AsyncBufRead`] traits. In addition, +//! when the "io-util" feature flag is enabled, it also provides combinators and +//! functions for working with these traits, forming as an asynchronous +//! counterpart to [`std::io`]. +//! +//! Tokio also includes APIs for performing various kinds of I/O and interacting +//! with the operating system asynchronously. These include: +//! +//! * [`tokio::net`], which contains non-blocking versions of [TCP], [UDP], and +//! [Unix Domain Sockets][UDS] (enabled by the "net" feature flag), +//! * [`tokio::fs`], similar to [`std::fs`] but for performing filesystem I/O +//! asynchronously (enabled by the "fs" feature flag), +//! * [`tokio::signal`], for asynchronously handling Unix and Windows OS signals +//! (enabled by the "signal" feature flag), +//! * [`tokio::process`], for spawning and managing child processes (enabled by +//! the "process" feature flag). +//! +//! [`tokio::io`]: crate::io +//! [`AsyncRead`]: crate::io::AsyncRead +//! [`AsyncWrite`]: crate::io::AsyncWrite +//! [`AsyncBufRead`]: crate::io::AsyncBufRead +//! [`std::io`]: std::io +//! [`tokio::net`]: crate::net +//! [TCP]: crate::net::tcp +//! [UDP]: crate::net::UdpSocket +//! [UDS]: crate::net::unix +//! [`tokio::fs`]: crate::fs +//! [`std::fs`]: std::fs +//! [`tokio::signal`]: crate::signal +//! [`tokio::process`]: crate::process +//! +//! # Examples +//! +//! A simple TCP echo server: +//! +//! ```no_run +//! use tokio::net::TcpListener; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; +//! +//! loop { +//! let (mut socket, _) = listener.accept().await?; +//! +//! tokio::spawn(async move { +//! let mut buf = [0; 1024]; +//! +//! // In a loop, read data from the socket and write the data back. +//! loop { +//! let n = match socket.read(&mut buf).await { +//! // socket closed +//! Ok(n) if n == 0 => return, +//! Ok(n) => n, +//! Err(e) => { +//! eprintln!("failed to read from socket; err = {:?}", e); +//! return; +//! } +//! }; +//! +//! // Write the data back +//! if let Err(e) = socket.write_all(&buf[0..n]).await { +//! eprintln!("failed to write to socket; err = {:?}", e); +//! return; +//! } +//! } +//! }); +//! } +//! } +//! ``` +//! +//! ## Feature flags +//! +//! Tokio uses a set of [feature flags] to reduce the amount of compiled code. It +//! is possible to just enable certain features over others. By default, Tokio +//! does not enable any features but allows one to enable a subset for their use +//! case. Below is a list of the available feature flags. You may also notice +//! above each function, struct and trait there is listed one or more feature flags +//! that are required for that item to be used. If you are new to Tokio it is +//! recommended that you use the `full` feature flag which will enable all public APIs. +//! Beware though that this will pull in many extra dependencies that you may not +//! need. +//! +//! - `full`: Enables all features listed below except `test-util` and `tracing`. +//! - `rt`: Enables `tokio::spawn`, the basic (current thread) scheduler, +//! and non-scheduler utilities. +//! - `rt-multi-thread`: Enables the heavier, multi-threaded, work-stealing scheduler. +//! - `io-util`: Enables the IO based `Ext` traits. +//! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types. +//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and +//! `UdpSocket`, as well as (on Unix-like systems) `AsyncFd` and (on +//! FreeBSD) `PollAio`. +//! - `time`: Enables `tokio::time` types and allows the schedulers to enable +//! the built in timer. +//! - `process`: Enables `tokio::process` types. +//! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros. +//! - `sync`: Enables all `tokio::sync` types. +//! - `signal`: Enables all `tokio::signal` types. +//! - `fs`: Enables `tokio::fs` types. +//! - `test-util`: Enables testing based infrastructure for the Tokio runtime. +//! +//! _Note: `AsyncRead` and `AsyncWrite` traits do not require any features and are +//! always available._ +//! +//! ### Internal features +//! +//! These features do not expose any new API, but influence internal +//! implementation aspects of Tokio, and can pull in additional +//! dependencies. +//! +//! - `parking_lot`: As a potential optimization, use the _parking_lot_ crate's +//! synchronization primitives internally. MSRV may increase according to the +//! _parking_lot_ release in use. +//! +//! ### Unstable features +//! +//! These feature flags enable **unstable** features. The public API may break in 1.x +//! releases. To enable these features, the `--cfg tokio_unstable` must be passed to +//! `rustc` when compiling. This is easiest done using the `RUSTFLAGS` env variable: +//! `RUSTFLAGS="--cfg tokio_unstable"`. +//! +//! - `tracing`: Enables tracing events. +//! +//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section + +// Test that pointer width is compatible. This asserts that e.g. usize is at +// least 32 bits, which a lot of components in Tokio currently assumes. +// +// TODO: improve once we have MSRV access to const eval to make more flexible. +#[cfg(not(any( + target_pointer_width = "32", + target_pointer_width = "64", + target_pointer_width = "128" +)))] +compile_error! { + "Tokio requires the platform pointer width to be 32, 64, or 128 bits" +} + +// Includes re-exports used by macros. +// +// This module is not intended to be part of the public API. In general, any +// `doc(hidden)` code is not part of Tokio's public and stable API. +#[macro_use] +#[doc(hidden)] +pub mod macros; + +cfg_fs! { + pub mod fs; +} + +mod future; + +pub mod io; +pub mod net; + +mod loom; +mod park; + +cfg_process! { + pub mod process; +} + +#[cfg(any(feature = "net", feature = "fs", feature = "io-std"))] +mod blocking; + +cfg_rt! { + pub mod runtime; +} + +pub(crate) mod coop; + +cfg_signal! { + pub mod signal; +} + +cfg_signal_internal! { + #[cfg(not(feature = "signal"))] + #[allow(dead_code)] + #[allow(unreachable_pub)] + pub(crate) mod signal; +} + +cfg_sync! { + pub mod sync; +} +cfg_not_sync! { + mod sync; +} + +pub mod task; +cfg_rt! { + pub use task::spawn; +} + +cfg_time! { + pub mod time; +} + +mod util; + +/// Due to the `Stream` trait's inclusion in `std` landing later than Tokio's 1.0 +/// release, most of the Tokio stream utilities have been moved into the [`tokio-stream`] +/// crate. +/// +/// # Why was `Stream` not included in Tokio 1.0? +/// +/// Originally, we had planned to ship Tokio 1.0 with a stable `Stream` type +/// but unfortunately the [RFC] had not been merged in time for `Stream` to +/// reach `std` on a stable compiler in time for the 1.0 release of Tokio. For +/// this reason, the team has decided to move all `Stream` based utilities to +/// the [`tokio-stream`] crate. While this is not ideal, once `Stream` has made +/// it into the standard library and the MSRV period has passed, we will implement +/// stream for our different types. +/// +/// While this may seem unfortunate, not all is lost as you can get much of the +/// `Stream` support with `async/await` and `while let` loops. It is also possible +/// to create a `impl Stream` from `async fn` using the [`async-stream`] crate. +/// +/// [`tokio-stream`]: https://docs.rs/tokio-stream +/// [`async-stream`]: https://docs.rs/async-stream +/// [RFC]: https://github.com/rust-lang/rfcs/pull/2996 +/// +/// # Example +/// +/// Convert a [`sync::mpsc::Receiver`] to an `impl Stream`. +/// +/// ```rust,no_run +/// use tokio::sync::mpsc; +/// +/// let (tx, mut rx) = mpsc::channel::<usize>(16); +/// +/// let stream = async_stream::stream! { +/// while let Some(item) = rx.recv().await { +/// yield item; +/// } +/// }; +/// ``` +pub mod stream {} + +// local re-exports of platform specific things, allowing for decent +// documentation to be shimmed in on docs.rs + +#[cfg(docsrs)] +pub mod doc; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::os; + +#[cfg(not(docsrs))] +#[allow(unused)] +pub(crate) use std::os; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::winapi; + +#[cfg(all(not(docsrs), windows, feature = "net"))] +#[allow(unused)] +pub(crate) use ::winapi; + +cfg_macros! { + /// Implementation detail of the `select!` macro. This macro is **not** + /// intended to be used as part of the public API and is permitted to + /// change. + #[doc(hidden)] + pub use tokio_macros::select_priv_declare_output_enum; + + /// Implementation detail of the `select!` macro. This macro is **not** + /// intended to be used as part of the public API and is permitted to + /// change. + #[doc(hidden)] + pub use tokio_macros::select_priv_clean_pattern; + + cfg_rt! { + #[cfg(feature = "rt-multi-thread")] + #[cfg(not(test))] // Work around for rust-lang/rust#62127 + #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] + pub use tokio_macros::main; + + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] + pub use tokio_macros::test; + + cfg_not_rt_multi_thread! { + #[cfg(not(test))] // Work around for rust-lang/rust#62127 + #[doc(inline)] + pub use tokio_macros::main_rt as main; + + #[doc(inline)] + pub use tokio_macros::test_rt as test; + } + } + + // Always fail if rt is not enabled. + cfg_not_rt! { + #[cfg(not(test))] + #[doc(inline)] + pub use tokio_macros::main_fail as main; + + #[doc(inline)] + pub use tokio_macros::test_fail as test; + } +} + +// TODO: rm +#[cfg(feature = "io-util")] +#[cfg(test)] +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/loom/mocked.rs b/third_party/rust/tokio/src/loom/mocked.rs new file mode 100644 index 0000000000..367d59b43a --- /dev/null +++ b/third_party/rust/tokio/src/loom/mocked.rs @@ -0,0 +1,40 @@ +pub(crate) use loom::*; + +pub(crate) mod sync { + + pub(crate) use loom::sync::MutexGuard; + + #[derive(Debug)] + pub(crate) struct Mutex<T>(loom::sync::Mutex<T>); + + #[allow(dead_code)] + impl<T> Mutex<T> { + #[inline] + pub(crate) fn new(t: T) -> Mutex<T> { + Mutex(loom::sync::Mutex::new(t)) + } + + #[inline] + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + self.0.lock().unwrap() + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + self.0.try_lock().ok() + } + } + pub(crate) use loom::sync::*; +} + +pub(crate) mod rand { + pub(crate) fn seed() -> u64 { + 1 + } +} + +pub(crate) mod sys { + pub(crate) fn num_cpus() -> usize { + 2 + } +} diff --git a/third_party/rust/tokio/src/loom/mod.rs b/third_party/rust/tokio/src/loom/mod.rs new file mode 100644 index 0000000000..5957b5377d --- /dev/null +++ b/third_party/rust/tokio/src/loom/mod.rs @@ -0,0 +1,14 @@ +//! This module abstracts over `loom` and `std::sync` depending on whether we +//! are running tests or not. + +#![allow(unused)] + +#[cfg(not(all(test, loom)))] +mod std; +#[cfg(not(all(test, loom)))] +pub(crate) use self::std::*; + +#[cfg(all(test, loom))] +mod mocked; +#[cfg(all(test, loom))] +pub(crate) use self::mocked::*; diff --git a/third_party/rust/tokio/src/loom/std/atomic_ptr.rs b/third_party/rust/tokio/src/loom/std/atomic_ptr.rs new file mode 100644 index 0000000000..236645f037 --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_ptr.rs @@ -0,0 +1,34 @@ +use std::fmt; +use std::ops::{Deref, DerefMut}; + +/// `AtomicPtr` providing an additional `load_unsync` function. +pub(crate) struct AtomicPtr<T> { + inner: std::sync::atomic::AtomicPtr<T>, +} + +impl<T> AtomicPtr<T> { + pub(crate) fn new(ptr: *mut T) -> AtomicPtr<T> { + let inner = std::sync::atomic::AtomicPtr::new(ptr); + AtomicPtr { inner } + } +} + +impl<T> Deref for AtomicPtr<T> { + type Target = std::sync::atomic::AtomicPtr<T>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<T> DerefMut for AtomicPtr<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<T> fmt::Debug for AtomicPtr<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.deref().fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/loom/std/atomic_u16.rs b/third_party/rust/tokio/src/loom/std/atomic_u16.rs new file mode 100644 index 0000000000..c1c531208c --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_u16.rs @@ -0,0 +1,44 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::ops::Deref; + +/// `AtomicU16` providing an additional `load_unsync` function. +pub(crate) struct AtomicU16 { + inner: UnsafeCell<std::sync::atomic::AtomicU16>, +} + +unsafe impl Send for AtomicU16 {} +unsafe impl Sync for AtomicU16 {} + +impl AtomicU16 { + pub(crate) const fn new(val: u16) -> AtomicU16 { + let inner = UnsafeCell::new(std::sync::atomic::AtomicU16::new(val)); + AtomicU16 { inner } + } + + /// Performs an unsynchronized load. + /// + /// # Safety + /// + /// All mutations must have happened before the unsynchronized load. + /// Additionally, there must be no concurrent mutations. + pub(crate) unsafe fn unsync_load(&self) -> u16 { + *(*self.inner.get()).get_mut() + } +} + +impl Deref for AtomicU16 { + type Target = std::sync::atomic::AtomicU16; + + fn deref(&self) -> &Self::Target { + // safety: it is always safe to access `&self` fns on the inner value as + // we never perform unsafe mutations. + unsafe { &*self.inner.get() } + } +} + +impl fmt::Debug for AtomicU16 { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.deref().fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/loom/std/atomic_u32.rs b/third_party/rust/tokio/src/loom/std/atomic_u32.rs new file mode 100644 index 0000000000..61f95fb30c --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_u32.rs @@ -0,0 +1,34 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::ops::Deref; + +/// `AtomicU32` providing an additional `load_unsync` function. +pub(crate) struct AtomicU32 { + inner: UnsafeCell<std::sync::atomic::AtomicU32>, +} + +unsafe impl Send for AtomicU32 {} +unsafe impl Sync for AtomicU32 {} + +impl AtomicU32 { + pub(crate) const fn new(val: u32) -> AtomicU32 { + let inner = UnsafeCell::new(std::sync::atomic::AtomicU32::new(val)); + AtomicU32 { inner } + } +} + +impl Deref for AtomicU32 { + type Target = std::sync::atomic::AtomicU32; + + fn deref(&self) -> &Self::Target { + // safety: it is always safe to access `&self` fns on the inner value as + // we never perform unsafe mutations. + unsafe { &*self.inner.get() } + } +} + +impl fmt::Debug for AtomicU32 { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.deref().fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/loom/std/atomic_u64.rs b/third_party/rust/tokio/src/loom/std/atomic_u64.rs new file mode 100644 index 0000000000..113992d977 --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_u64.rs @@ -0,0 +1,78 @@ +//! Implementation of an atomic u64 cell. On 64 bit platforms, this is a +//! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a +//! `Mutex`. + +// `AtomicU64` can only be used on targets with `target_has_atomic` is 64 or greater. +// Once `cfg_target_has_atomic` feature is stable, we can replace it with +// `#[cfg(target_has_atomic = "64")]`. +// Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target +cfg_has_atomic_u64! { + pub(crate) use std::sync::atomic::AtomicU64; +} + +cfg_not_has_atomic_u64! { + use crate::loom::sync::Mutex; + use std::sync::atomic::Ordering; + + #[derive(Debug)] + pub(crate) struct AtomicU64 { + inner: Mutex<u64>, + } + + impl AtomicU64 { + pub(crate) fn new(val: u64) -> Self { + Self { + inner: Mutex::new(val), + } + } + + pub(crate) fn load(&self, _: Ordering) -> u64 { + *self.inner.lock() + } + + pub(crate) fn store(&self, val: u64, _: Ordering) { + *self.inner.lock() = val; + } + + pub(crate) fn fetch_add(&self, val: u64, _: Ordering) -> u64 { + let mut lock = self.inner.lock(); + let prev = *lock; + *lock = prev + val; + prev + } + + pub(crate) fn fetch_or(&self, val: u64, _: Ordering) -> u64 { + let mut lock = self.inner.lock(); + let prev = *lock; + *lock = prev | val; + prev + } + + pub(crate) fn compare_exchange( + &self, + current: u64, + new: u64, + _success: Ordering, + _failure: Ordering, + ) -> Result<u64, u64> { + let mut lock = self.inner.lock(); + + if *lock == current { + *lock = new; + Ok(current) + } else { + Err(*lock) + } + } + + pub(crate) fn compare_exchange_weak( + &self, + current: u64, + new: u64, + success: Ordering, + failure: Ordering, + ) -> Result<u64, u64> { + self.compare_exchange(current, new, success, failure) + } + } +} diff --git a/third_party/rust/tokio/src/loom/std/atomic_u8.rs b/third_party/rust/tokio/src/loom/std/atomic_u8.rs new file mode 100644 index 0000000000..408aea338c --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_u8.rs @@ -0,0 +1,34 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::ops::Deref; + +/// `AtomicU8` providing an additional `load_unsync` function. +pub(crate) struct AtomicU8 { + inner: UnsafeCell<std::sync::atomic::AtomicU8>, +} + +unsafe impl Send for AtomicU8 {} +unsafe impl Sync for AtomicU8 {} + +impl AtomicU8 { + pub(crate) const fn new(val: u8) -> AtomicU8 { + let inner = UnsafeCell::new(std::sync::atomic::AtomicU8::new(val)); + AtomicU8 { inner } + } +} + +impl Deref for AtomicU8 { + type Target = std::sync::atomic::AtomicU8; + + fn deref(&self) -> &Self::Target { + // safety: it is always safe to access `&self` fns on the inner value as + // we never perform unsafe mutations. + unsafe { &*self.inner.get() } + } +} + +impl fmt::Debug for AtomicU8 { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.deref().fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/loom/std/atomic_usize.rs b/third_party/rust/tokio/src/loom/std/atomic_usize.rs new file mode 100644 index 0000000000..0d5f36e431 --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/atomic_usize.rs @@ -0,0 +1,56 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::ops; + +/// `AtomicUsize` providing an additional `load_unsync` function. +pub(crate) struct AtomicUsize { + inner: UnsafeCell<std::sync::atomic::AtomicUsize>, +} + +unsafe impl Send for AtomicUsize {} +unsafe impl Sync for AtomicUsize {} + +impl AtomicUsize { + pub(crate) const fn new(val: usize) -> AtomicUsize { + let inner = UnsafeCell::new(std::sync::atomic::AtomicUsize::new(val)); + AtomicUsize { inner } + } + + /// Performs an unsynchronized load. + /// + /// # Safety + /// + /// All mutations must have happened before the unsynchronized load. + /// Additionally, there must be no concurrent mutations. + pub(crate) unsafe fn unsync_load(&self) -> usize { + *(*self.inner.get()).get_mut() + } + + pub(crate) fn with_mut<R>(&mut self, f: impl FnOnce(&mut usize) -> R) -> R { + // safety: we have mutable access + f(unsafe { (*self.inner.get()).get_mut() }) + } +} + +impl ops::Deref for AtomicUsize { + type Target = std::sync::atomic::AtomicUsize; + + fn deref(&self) -> &Self::Target { + // safety: it is always safe to access `&self` fns on the inner value as + // we never perform unsafe mutations. + unsafe { &*self.inner.get() } + } +} + +impl ops::DerefMut for AtomicUsize { + fn deref_mut(&mut self) -> &mut Self::Target { + // safety: we hold `&mut self` + unsafe { &mut *self.inner.get() } + } +} + +impl fmt::Debug for AtomicUsize { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/loom/std/mod.rs b/third_party/rust/tokio/src/loom/std/mod.rs new file mode 100644 index 0000000000..0c70bee74e --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/mod.rs @@ -0,0 +1,108 @@ +#![cfg_attr(any(not(feature = "full"), loom), allow(unused_imports, dead_code))] + +mod atomic_ptr; +mod atomic_u16; +mod atomic_u32; +mod atomic_u64; +mod atomic_u8; +mod atomic_usize; +mod mutex; +#[cfg(feature = "parking_lot")] +mod parking_lot; +mod unsafe_cell; + +pub(crate) mod cell { + pub(crate) use super::unsafe_cell::UnsafeCell; +} + +#[cfg(any( + feature = "net", + feature = "process", + feature = "signal", + feature = "sync", +))] +pub(crate) mod future { + pub(crate) use crate::sync::AtomicWaker; +} + +pub(crate) mod hint { + pub(crate) use std::hint::spin_loop; +} + +pub(crate) mod rand { + use std::collections::hash_map::RandomState; + use std::hash::{BuildHasher, Hash, Hasher}; + use std::sync::atomic::AtomicU32; + use std::sync::atomic::Ordering::Relaxed; + + static COUNTER: AtomicU32 = AtomicU32::new(1); + + pub(crate) fn seed() -> u64 { + let rand_state = RandomState::new(); + + let mut hasher = rand_state.build_hasher(); + + // Hash some unique-ish data to generate some new state + COUNTER.fetch_add(1, Relaxed).hash(&mut hasher); + + // Get the seed + hasher.finish() + } +} + +pub(crate) mod sync { + pub(crate) use std::sync::{Arc, Weak}; + + // Below, make sure all the feature-influenced types are exported for + // internal use. Note however that some are not _currently_ named by + // consuming code. + + #[cfg(feature = "parking_lot")] + #[allow(unused_imports)] + pub(crate) use crate::loom::std::parking_lot::{ + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + }; + + #[cfg(not(feature = "parking_lot"))] + #[allow(unused_imports)] + pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult}; + + #[cfg(not(feature = "parking_lot"))] + pub(crate) use crate::loom::std::mutex::Mutex; + + pub(crate) mod atomic { + pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr; + pub(crate) use crate::loom::std::atomic_u16::AtomicU16; + pub(crate) use crate::loom::std::atomic_u32::AtomicU32; + pub(crate) use crate::loom::std::atomic_u64::AtomicU64; + pub(crate) use crate::loom::std::atomic_u8::AtomicU8; + pub(crate) use crate::loom::std::atomic_usize::AtomicUsize; + + pub(crate) use std::sync::atomic::{fence, AtomicBool, Ordering}; + } +} + +pub(crate) mod sys { + #[cfg(feature = "rt-multi-thread")] + pub(crate) fn num_cpus() -> usize { + usize::max(1, num_cpus::get()) + } + + #[cfg(not(feature = "rt-multi-thread"))] + pub(crate) fn num_cpus() -> usize { + 1 + } +} + +pub(crate) mod thread { + #[inline] + pub(crate) fn yield_now() { + std::hint::spin_loop(); + } + + #[allow(unused_imports)] + pub(crate) use std::thread::{ + current, panicking, park, park_timeout, sleep, spawn, Builder, JoinHandle, LocalKey, + Result, Thread, ThreadId, + }; +} diff --git a/third_party/rust/tokio/src/loom/std/mutex.rs b/third_party/rust/tokio/src/loom/std/mutex.rs new file mode 100644 index 0000000000..3f686e0a78 --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/mutex.rs @@ -0,0 +1,31 @@ +use std::sync::{self, MutexGuard, TryLockError}; + +/// Adapter for `std::Mutex` that removes the poisoning aspects +/// from its api. +#[derive(Debug)] +pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>); + +#[allow(dead_code)] +impl<T> Mutex<T> { + #[inline] + pub(crate) fn new(t: T) -> Mutex<T> { + Mutex(sync::Mutex::new(t)) + } + + #[inline] + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + match self.0.lock() { + Ok(guard) => guard, + Err(p_err) => p_err.into_inner(), + } + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + match self.0.try_lock() { + Ok(guard) => Some(guard), + Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()), + Err(TryLockError::WouldBlock) => None, + } + } +} diff --git a/third_party/rust/tokio/src/loom/std/parking_lot.rs b/third_party/rust/tokio/src/loom/std/parking_lot.rs new file mode 100644 index 0000000000..034a0ce69a --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/parking_lot.rs @@ -0,0 +1,184 @@ +//! A minimal adaption of the `parking_lot` synchronization primitives to the +//! equivalent `std::sync` types. +//! +//! This can be extended to additional types/methods as required. + +use std::fmt; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::sync::LockResult; +use std::time::Duration; + +// All types in this file are marked with PhantomData to ensure that +// parking_lot's send_guard feature does not leak through and affect when Tokio +// types are Send. +// +// See <https://github.com/tokio-rs/tokio/pull/4359> for more info. + +// Types that do not need wrapping +pub(crate) use parking_lot::WaitTimeoutResult; + +#[derive(Debug)] +pub(crate) struct Mutex<T: ?Sized>(PhantomData<std::sync::Mutex<T>>, parking_lot::Mutex<T>); + +#[derive(Debug)] +pub(crate) struct RwLock<T>(PhantomData<std::sync::RwLock<T>>, parking_lot::RwLock<T>); + +#[derive(Debug)] +pub(crate) struct Condvar(PhantomData<std::sync::Condvar>, parking_lot::Condvar); + +#[derive(Debug)] +pub(crate) struct MutexGuard<'a, T: ?Sized>( + PhantomData<std::sync::MutexGuard<'a, T>>, + parking_lot::MutexGuard<'a, T>, +); + +#[derive(Debug)] +pub(crate) struct RwLockReadGuard<'a, T: ?Sized>( + PhantomData<std::sync::RwLockReadGuard<'a, T>>, + parking_lot::RwLockReadGuard<'a, T>, +); + +#[derive(Debug)] +pub(crate) struct RwLockWriteGuard<'a, T: ?Sized>( + PhantomData<std::sync::RwLockWriteGuard<'a, T>>, + parking_lot::RwLockWriteGuard<'a, T>, +); + +impl<T> Mutex<T> { + #[inline] + pub(crate) fn new(t: T) -> Mutex<T> { + Mutex(PhantomData, parking_lot::Mutex::new(t)) + } + + #[inline] + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "parking_lot",))))] + pub(crate) const fn const_new(t: T) -> Mutex<T> { + Mutex(PhantomData, parking_lot::const_mutex(t)) + } + + #[inline] + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + MutexGuard(PhantomData, self.1.lock()) + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + self.1 + .try_lock() + .map(|guard| MutexGuard(PhantomData, guard)) + } + + #[inline] + pub(crate) fn get_mut(&mut self) -> &mut T { + self.1.get_mut() + } + + // Note: Additional methods `is_poisoned` and `into_inner`, can be + // provided here as needed. +} + +impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.1.deref() + } +} + +impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.1.deref_mut() + } +} + +impl<T> RwLock<T> { + pub(crate) fn new(t: T) -> RwLock<T> { + RwLock(PhantomData, parking_lot::RwLock::new(t)) + } + + pub(crate) fn read(&self) -> LockResult<RwLockReadGuard<'_, T>> { + Ok(RwLockReadGuard(PhantomData, self.1.read())) + } + + pub(crate) fn write(&self) -> LockResult<RwLockWriteGuard<'_, T>> { + Ok(RwLockWriteGuard(PhantomData, self.1.write())) + } +} + +impl<'a, T: ?Sized> Deref for RwLockReadGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.1.deref() + } +} + +impl<'a, T: ?Sized> Deref for RwLockWriteGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.1.deref() + } +} + +impl<'a, T: ?Sized> DerefMut for RwLockWriteGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.1.deref_mut() + } +} + +impl Condvar { + #[inline] + pub(crate) fn new() -> Condvar { + Condvar(PhantomData, parking_lot::Condvar::new()) + } + + #[inline] + pub(crate) fn notify_one(&self) { + self.1.notify_one(); + } + + #[inline] + pub(crate) fn notify_all(&self) { + self.1.notify_all(); + } + + #[inline] + pub(crate) fn wait<'a, T>( + &self, + mut guard: MutexGuard<'a, T>, + ) -> LockResult<MutexGuard<'a, T>> { + self.1.wait(&mut guard.1); + Ok(guard) + } + + #[inline] + pub(crate) fn wait_timeout<'a, T>( + &self, + mut guard: MutexGuard<'a, T>, + timeout: Duration, + ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> { + let wtr = self.1.wait_for(&mut guard.1, timeout); + Ok((guard, wtr)) + } + + // Note: Additional methods `wait_timeout_ms`, `wait_timeout_until`, + // `wait_until` can be provided here as needed. +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.1, f) + } +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.1, f) + } +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.1, f) + } +} diff --git a/third_party/rust/tokio/src/loom/std/unsafe_cell.rs b/third_party/rust/tokio/src/loom/std/unsafe_cell.rs new file mode 100644 index 0000000000..66c1d7943e --- /dev/null +++ b/third_party/rust/tokio/src/loom/std/unsafe_cell.rs @@ -0,0 +1,16 @@ +#[derive(Debug)] +pub(crate) struct UnsafeCell<T>(std::cell::UnsafeCell<T>); + +impl<T> UnsafeCell<T> { + pub(crate) const fn new(data: T) -> UnsafeCell<T> { + UnsafeCell(std::cell::UnsafeCell::new(data)) + } + + pub(crate) fn with<R>(&self, f: impl FnOnce(*const T) -> R) -> R { + f(self.0.get()) + } + + pub(crate) fn with_mut<R>(&self, f: impl FnOnce(*mut T) -> R) -> R { + f(self.0.get()) + } +} diff --git a/third_party/rust/tokio/src/macros/cfg.rs b/third_party/rust/tokio/src/macros/cfg.rs new file mode 100644 index 0000000000..b6beb3d695 --- /dev/null +++ b/third_party/rust/tokio/src/macros/cfg.rs @@ -0,0 +1,457 @@ +#![allow(unused_macros)] + +macro_rules! feature { + ( + #![$meta:meta] + $($item:item)* + ) => { + $( + #[cfg($meta)] + #[cfg_attr(docsrs, doc(cfg($meta)))] + $item + )* + } +} + +/// Enables enter::block_on. +macro_rules! cfg_block_on { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "fs", + feature = "net", + feature = "io-std", + feature = "rt", + ))] + $item + )* + } +} + +/// Enables internal `AtomicWaker` impl. +macro_rules! cfg_atomic_waker_impl { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "time", + ))] + #[cfg(not(loom))] + $item + )* + } +} + +macro_rules! cfg_aio { + ($($item:item)*) => { + $( + #[cfg(all(any(docsrs, target_os = "freebsd"), feature = "net"))] + #[cfg_attr(docsrs, + doc(cfg(all(target_os = "freebsd", feature = "net"))) + )] + $item + )* + } +} + +macro_rules! cfg_fs { + ($($item:item)*) => { + $( + #[cfg(feature = "fs")] + #[cfg_attr(docsrs, doc(cfg(feature = "fs")))] + $item + )* + } +} + +macro_rules! cfg_io_blocking { + ($($item:item)*) => { + $( #[cfg(any(feature = "io-std", feature = "fs"))] $item )* + } +} + +macro_rules! cfg_io_driver { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + ))] + #[cfg_attr(docsrs, doc(cfg(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + ))))] + $item + )* + } +} + +macro_rules! cfg_io_driver_impl { + ( $( $item:item )* ) => { + $( + #[cfg(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + ))] + $item + )* + } +} + +macro_rules! cfg_not_io_driver { + ($($item:item)*) => { + $( + #[cfg(not(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + )))] + $item + )* + } +} + +macro_rules! cfg_io_readiness { + ($($item:item)*) => { + $( + #[cfg(feature = "net")] + $item + )* + } +} + +macro_rules! cfg_io_std { + ($($item:item)*) => { + $( + #[cfg(feature = "io-std")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-std")))] + $item + )* + } +} + +macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } +} + +macro_rules! cfg_not_io_util { + ($($item:item)*) => { + $( #[cfg(not(feature = "io-util"))] $item )* + } +} + +macro_rules! cfg_loom { + ($($item:item)*) => { + $( #[cfg(loom)] $item )* + } +} + +macro_rules! cfg_not_loom { + ($($item:item)*) => { + $( #[cfg(not(loom))] $item )* + } +} + +macro_rules! cfg_macros { + ($($item:item)*) => { + $( + #[cfg(feature = "macros")] + #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + $item + )* + } +} + +macro_rules! cfg_metrics { + ($($item:item)*) => { + $( + // For now, metrics is only disabled in loom tests. + // When stabilized, it might have a dedicated feature flag. + #[cfg(all(tokio_unstable, not(loom)))] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + $item + )* + } +} + +macro_rules! cfg_not_metrics { + ($($item:item)*) => { + $( + #[cfg(not(all(tokio_unstable, not(loom))))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(feature = "net")] + #[cfg_attr(docsrs, doc(cfg(feature = "net")))] + $item + )* + } +} + +macro_rules! cfg_net_unix { + ($($item:item)*) => { + $( + #[cfg(all(unix, feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(all(unix, feature = "net"))))] + $item + )* + } +} + +macro_rules! cfg_net_windows { + ($($item:item)*) => { + $( + #[cfg(all(any(all(doc, docsrs), windows), feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "net"))))] + $item + )* + } +} + +macro_rules! cfg_process { + ($($item:item)*) => { + $( + #[cfg(feature = "process")] + #[cfg_attr(docsrs, doc(cfg(feature = "process")))] + #[cfg(not(loom))] + $item + )* + } +} + +macro_rules! cfg_process_driver { + ($($item:item)*) => { + #[cfg(unix)] + #[cfg(not(loom))] + cfg_process! { $($item)* } + } +} + +macro_rules! cfg_not_process_driver { + ($($item:item)*) => { + $( + #[cfg(not(all(unix, not(loom), feature = "process")))] + $item + )* + } +} + +macro_rules! cfg_signal { + ($($item:item)*) => { + $( + #[cfg(feature = "signal")] + #[cfg_attr(docsrs, doc(cfg(feature = "signal")))] + #[cfg(not(loom))] + $item + )* + } +} + +macro_rules! cfg_signal_internal { + ($($item:item)*) => { + $( + #[cfg(any(feature = "signal", all(unix, feature = "process")))] + #[cfg(not(loom))] + $item + )* + } +} + +macro_rules! cfg_not_signal_internal { + ($($item:item)*) => { + $( + #[cfg(any(loom, not(unix), not(any(feature = "signal", all(unix, feature = "process")))))] + $item + )* + } +} + +macro_rules! cfg_sync { + ($($item:item)*) => { + $( + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] + $item + )* + } +} + +macro_rules! cfg_not_sync { + ($($item:item)*) => { + $( #[cfg(not(feature = "sync"))] $item )* + } +} + +macro_rules! cfg_rt { + ($($item:item)*) => { + $( + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + $item + )* + } +} + +macro_rules! cfg_not_rt { + ($($item:item)*) => { + $( #[cfg(not(feature = "rt"))] $item )* + } +} + +macro_rules! cfg_rt_multi_thread { + ($($item:item)*) => { + $( + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + $item + )* + } +} + +macro_rules! cfg_not_rt_multi_thread { + ($($item:item)*) => { + $( #[cfg(not(feature = "rt-multi-thread"))] $item )* + } +} + +macro_rules! cfg_test_util { + ($($item:item)*) => { + $( + #[cfg(feature = "test-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] + $item + )* + } +} + +macro_rules! cfg_not_test_util { + ($($item:item)*) => { + $( #[cfg(not(feature = "test-util"))] $item )* + } +} + +macro_rules! cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} + +macro_rules! cfg_not_time { + ($($item:item)*) => { + $( #[cfg(not(feature = "time"))] $item )* + } +} + +macro_rules! cfg_trace { + ($($item:item)*) => { + $( + #[cfg(all(tokio_unstable, feature = "tracing"))] + #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] + $item + )* + }; +} + +macro_rules! cfg_unstable { + ($($item:item)*) => { + $( + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + $item + )* + }; +} + +macro_rules! cfg_not_trace { + ($($item:item)*) => { + $( + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + $item + )* + } +} + +macro_rules! cfg_coop { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + $item + )* + } +} + +macro_rules! cfg_not_coop { + ($($item:item)*) => { + $( + #[cfg(not(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + )))] + $item + )* + } +} + +macro_rules! cfg_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(not(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc", + target_arch = "riscv32" + )))] + $item + )* + } +} + +macro_rules! cfg_not_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc", + target_arch = "riscv32" + ))] + $item + )* + } +} diff --git a/third_party/rust/tokio/src/macros/join.rs b/third_party/rust/tokio/src/macros/join.rs new file mode 100644 index 0000000000..f91b5f1914 --- /dev/null +++ b/third_party/rust/tokio/src/macros/join.rs @@ -0,0 +1,119 @@ +/// Waits on multiple concurrent branches, returning when **all** branches +/// complete. +/// +/// The `join!` macro must be used inside of async functions, closures, and +/// blocks. +/// +/// The `join!` macro takes a list of async expressions and evaluates them +/// concurrently on the same task. Each async expression evaluates to a future +/// and the futures from each expression are multiplexed on the current task. +/// +/// When working with async expressions returning `Result`, `join!` will wait +/// for **all** branches complete regardless if any complete with `Err`. Use +/// [`try_join!`] to return early when `Err` is encountered. +/// +/// [`try_join!`]: macro@try_join +/// +/// # Notes +/// +/// The supplied futures are stored inline and does not require allocating a +/// `Vec`. +/// +/// ### Runtime characteristics +/// +/// By running all async expressions on the current task, the expressions are +/// able to run **concurrently** but not in **parallel**. This means all +/// expressions are run on the same thread and if one branch blocks the thread, +/// all other expressions will be unable to continue. If parallelism is +/// required, spawn each async expression using [`tokio::spawn`] and pass the +/// join handle to `join!`. +/// +/// [`tokio::spawn`]: crate::spawn +/// +/// # Examples +/// +/// Basic join with two branches +/// +/// ``` +/// async fn do_stuff_async() { +/// // async work +/// } +/// +/// async fn more_async_work() { +/// // more here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let (first, second) = tokio::join!( +/// do_stuff_async(), +/// more_async_work()); +/// +/// // do something with the values +/// } +/// ``` +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! join { + (@ { + // One `_` for each branch in the `join!` macro. This is not used once + // normalization is complete. + ( $($count:tt)* ) + + // Normalized join! branches + $( ( $($skip:tt)* ) $e:expr, )* + + }) => {{ + use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; + use $crate::macros::support::Poll::{Ready, Pending}; + + // Safety: nothing must be moved out of `futures`. This is to satisfy + // the requirement of `Pin::new_unchecked` called below. + let mut futures = ( $( maybe_done($e), )* ); + + poll_fn(move |cx| { + let mut is_pending = false; + + $( + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling + if fut.poll(cx).is_pending() { + is_pending = true; + } + )* + + if is_pending { + Pending + } else { + Ready(($({ + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + fut.take_output().expect("expected completed future") + },)*)) + } + }).await + }}; + + // ===== Normalize ===== + + (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + }; + + // ===== Entry point ===== + + ( $($e:expr),* $(,)?) => { + $crate::join!(@{ () } $($e,)*) + }; +} diff --git a/third_party/rust/tokio/src/macros/loom.rs b/third_party/rust/tokio/src/macros/loom.rs new file mode 100644 index 0000000000..d57d9fb0f7 --- /dev/null +++ b/third_party/rust/tokio/src/macros/loom.rs @@ -0,0 +1,12 @@ +macro_rules! if_loom { + ($($t:tt)*) => {{ + #[cfg(loom)] + const LOOM: bool = true; + #[cfg(not(loom))] + const LOOM: bool = false; + + if LOOM { + $($t)* + } + }} +} diff --git a/third_party/rust/tokio/src/macros/mod.rs b/third_party/rust/tokio/src/macros/mod.rs new file mode 100644 index 0000000000..a1839c8305 --- /dev/null +++ b/third_party/rust/tokio/src/macros/mod.rs @@ -0,0 +1,40 @@ +#![cfg_attr(not(feature = "full"), allow(unused_macros))] + +#[macro_use] +mod cfg; + +#[macro_use] +mod loom; + +#[macro_use] +mod pin; + +#[macro_use] +mod ready; + +#[macro_use] +mod thread_local; + +cfg_trace! { + #[macro_use] + mod trace; +} + +#[macro_use] +#[cfg(feature = "rt")] +pub(crate) mod scoped_tls; + +cfg_macros! { + #[macro_use] + mod select; + + #[macro_use] + mod join; + + #[macro_use] + mod try_join; +} + +// Includes re-exports needed to implement macros +#[doc(hidden)] +pub mod support; diff --git a/third_party/rust/tokio/src/macros/pin.rs b/third_party/rust/tokio/src/macros/pin.rs new file mode 100644 index 0000000000..7af9ce7d1f --- /dev/null +++ b/third_party/rust/tokio/src/macros/pin.rs @@ -0,0 +1,144 @@ +/// Pins a value on the stack. +/// +/// Calls to `async fn` return anonymous [`Future`] values that are `!Unpin`. +/// These values must be pinned before they can be polled. Calling `.await` will +/// handle this, but consumes the future. If it is required to call `.await` on +/// a `&mut _` reference, the caller is responsible for pinning the future. +/// +/// Pinning may be done by allocating with [`Box::pin`] or by using the stack +/// with the `pin!` macro. +/// +/// The following will **fail to compile**: +/// +/// ```compile_fail +/// async fn my_async_fn() { +/// // async logic here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut future = my_async_fn(); +/// (&mut future).await; +/// } +/// ``` +/// +/// To make this work requires pinning: +/// +/// ``` +/// use tokio::pin; +/// +/// async fn my_async_fn() { +/// // async logic here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let future = my_async_fn(); +/// pin!(future); +/// +/// (&mut future).await; +/// } +/// ``` +/// +/// Pinning is useful when using `select!` and stream operators that require `T: +/// Stream + Unpin`. +/// +/// [`Future`]: trait@std::future::Future +/// [`Box::pin`]: std::boxed::Box::pin +/// +/// # Usage +/// +/// The `pin!` macro takes **identifiers** as arguments. It does **not** work +/// with expressions. +/// +/// The following does not compile as an expression is passed to `pin!`. +/// +/// ```compile_fail +/// async fn my_async_fn() { +/// // async logic here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut future = pin!(my_async_fn()); +/// (&mut future).await; +/// } +/// ``` +/// +/// # Examples +/// +/// Using with select: +/// +/// ``` +/// use tokio::{pin, select}; +/// use tokio_stream::{self as stream, StreamExt}; +/// +/// async fn my_async_fn() { +/// // async logic here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut stream = stream::iter(vec![1, 2, 3, 4]); +/// +/// let future = my_async_fn(); +/// pin!(future); +/// +/// loop { +/// select! { +/// _ = &mut future => { +/// // Stop looping `future` will be polled after completion +/// break; +/// } +/// Some(val) = stream.next() => { +/// println!("got value = {}", val); +/// } +/// } +/// } +/// } +/// ``` +/// +/// Because assigning to a variable followed by pinning is common, there is also +/// a variant of the macro that supports doing both in one go. +/// +/// ``` +/// use tokio::{pin, select}; +/// +/// async fn my_async_fn() { +/// // async logic here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// pin! { +/// let future1 = my_async_fn(); +/// let future2 = my_async_fn(); +/// } +/// +/// select! { +/// _ = &mut future1 => {} +/// _ = &mut future2 => {} +/// } +/// } +/// ``` +#[macro_export] +macro_rules! pin { + ($($x:ident),*) => { $( + // Move the value to ensure that it is owned + let mut $x = $x; + // Shadow the original binding so that it can't be directly accessed + // ever again. + #[allow(unused_mut)] + let mut $x = unsafe { + $crate::macros::support::Pin::new_unchecked(&mut $x) + }; + )* }; + ($( + let $x:ident = $init:expr; + )*) => { + $( + let $x = $init; + $crate::pin!($x); + )* + }; +} diff --git a/third_party/rust/tokio/src/macros/ready.rs b/third_party/rust/tokio/src/macros/ready.rs new file mode 100644 index 0000000000..1f48623b80 --- /dev/null +++ b/third_party/rust/tokio/src/macros/ready.rs @@ -0,0 +1,8 @@ +macro_rules! ready { + ($e:expr $(,)?) => { + match $e { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} diff --git a/third_party/rust/tokio/src/macros/scoped_tls.rs b/third_party/rust/tokio/src/macros/scoped_tls.rs new file mode 100644 index 0000000000..f2504cbadf --- /dev/null +++ b/third_party/rust/tokio/src/macros/scoped_tls.rs @@ -0,0 +1,77 @@ +use crate::loom::thread::LocalKey; + +use std::cell::Cell; +use std::marker; + +/// Sets a reference as a thread-local. +macro_rules! scoped_thread_local { + ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty) => ( + $(#[$attrs])* + $vis static $name: $crate::macros::scoped_tls::ScopedKey<$ty> + = $crate::macros::scoped_tls::ScopedKey { + inner: { + thread_local!(static FOO: ::std::cell::Cell<*const ()> = { + std::cell::Cell::new(::std::ptr::null()) + }); + &FOO + }, + _marker: ::std::marker::PhantomData, + }; + ) +} + +/// Type representing a thread local storage key corresponding to a reference +/// to the type parameter `T`. +pub(crate) struct ScopedKey<T> { + pub(crate) inner: &'static LocalKey<Cell<*const ()>>, + pub(crate) _marker: marker::PhantomData<T>, +} + +unsafe impl<T> Sync for ScopedKey<T> {} + +impl<T> ScopedKey<T> { + /// Inserts a value into this scoped thread local storage slot for a + /// duration of a closure. + pub(crate) fn set<F, R>(&'static self, t: &T, f: F) -> R + where + F: FnOnce() -> R, + { + struct Reset { + key: &'static LocalKey<Cell<*const ()>>, + val: *const (), + } + + impl Drop for Reset { + fn drop(&mut self) { + self.key.with(|c| c.set(self.val)); + } + } + + let prev = self.inner.with(|c| { + let prev = c.get(); + c.set(t as *const _ as *const ()); + prev + }); + + let _reset = Reset { + key: self.inner, + val: prev, + }; + + f() + } + + /// Gets a value out of this scoped variable. + pub(crate) fn with<F, R>(&'static self, f: F) -> R + where + F: FnOnce(Option<&T>) -> R, + { + let val = self.inner.with(|c| c.get()); + + if val.is_null() { + f(None) + } else { + unsafe { f(Some(&*(val as *const T))) } + } + } +} diff --git a/third_party/rust/tokio/src/macros/select.rs b/third_party/rust/tokio/src/macros/select.rs new file mode 100644 index 0000000000..051f8cb72a --- /dev/null +++ b/third_party/rust/tokio/src/macros/select.rs @@ -0,0 +1,1001 @@ +/// Waits on multiple concurrent branches, returning when the **first** branch +/// completes, cancelling the remaining branches. +/// +/// The `select!` macro must be used inside of async functions, closures, and +/// blocks. +/// +/// The `select!` macro accepts one or more branches with the following pattern: +/// +/// ```text +/// <pattern> = <async expression> (, if <precondition>)? => <handler>, +/// ``` +/// +/// Additionally, the `select!` macro may include a single, optional `else` +/// branch, which evaluates if none of the other branches match their patterns: +/// +/// ```text +/// else => <expression> +/// ``` +/// +/// The macro aggregates all `<async expression>` expressions and runs them +/// concurrently on the **current** task. Once the **first** expression +/// completes with a value that matches its `<pattern>`, the `select!` macro +/// returns the result of evaluating the completed branch's `<handler>` +/// expression. +/// +/// Additionally, each branch may include an optional `if` precondition. If the +/// precondition returns `false`, then the branch is disabled. The provided +/// `<async expression>` is still evaluated but the resulting future is never +/// polled. This capability is useful when using `select!` within a loop. +/// +/// The complete lifecycle of a `select!` expression is as follows: +/// +/// 1. Evaluate all provided `<precondition>` expressions. If the precondition +/// returns `false`, disable the branch for the remainder of the current call +/// to `select!`. Re-entering `select!` due to a loop clears the "disabled" +/// state. +/// 2. Aggregate the `<async expression>`s from each branch, including the +/// disabled ones. If the branch is disabled, `<async expression>` is still +/// evaluated, but the resulting future is not polled. +/// 3. Concurrently await on the results for all remaining `<async expression>`s. +/// 4. Once an `<async expression>` returns a value, attempt to apply the value +/// to the provided `<pattern>`, if the pattern matches, evaluate `<handler>` +/// and return. If the pattern **does not** match, disable the current branch +/// and for the remainder of the current call to `select!`. Continue from step 3. +/// 5. If **all** branches are disabled, evaluate the `else` expression. If no +/// else branch is provided, panic. +/// +/// # Runtime characteristics +/// +/// By running all async expressions on the current task, the expressions are +/// able to run **concurrently** but not in **parallel**. This means all +/// expressions are run on the same thread and if one branch blocks the thread, +/// all other expressions will be unable to continue. If parallelism is +/// required, spawn each async expression using [`tokio::spawn`] and pass the +/// join handle to `select!`. +/// +/// [`tokio::spawn`]: crate::spawn +/// +/// # Fairness +/// +/// By default, `select!` randomly picks a branch to check first. This provides +/// some level of fairness when calling `select!` in a loop with branches that +/// are always ready. +/// +/// This behavior can be overridden by adding `biased;` to the beginning of the +/// macro usage. See the examples for details. This will cause `select` to poll +/// the futures in the order they appear from top to bottom. There are a few +/// reasons you may want this: +/// +/// - The random number generation of `tokio::select!` has a non-zero CPU cost +/// - Your futures may interact in a way where known polling order is significant +/// +/// But there is an important caveat to this mode. It becomes your responsibility +/// to ensure that the polling order of your futures is fair. If for example you +/// are selecting between a stream and a shutdown future, and the stream has a +/// huge volume of messages and zero or nearly zero time between them, you should +/// place the shutdown future earlier in the `select!` list to ensure that it is +/// always polled, and will not be ignored due to the stream being constantly +/// ready. +/// +/// # Panics +/// +/// The `select!` macro panics if all branches are disabled **and** there is no +/// provided `else` branch. A branch is disabled when the provided `if` +/// precondition returns `false` **or** when the pattern does not match the +/// result of `<async expression>`. +/// +/// # Cancellation safety +/// +/// When using `select!` in a loop to receive messages from multiple sources, +/// you should make sure that the receive call is cancellation safe to avoid +/// losing messages. This section goes through various common methods and +/// describes whether they are cancel safe. The lists in this section are not +/// exhaustive. +/// +/// The following methods are cancellation safe: +/// +/// * [`tokio::sync::mpsc::Receiver::recv`](crate::sync::mpsc::Receiver::recv) +/// * [`tokio::sync::mpsc::UnboundedReceiver::recv`](crate::sync::mpsc::UnboundedReceiver::recv) +/// * [`tokio::sync::broadcast::Receiver::recv`](crate::sync::broadcast::Receiver::recv) +/// * [`tokio::sync::watch::Receiver::changed`](crate::sync::watch::Receiver::changed) +/// * [`tokio::net::TcpListener::accept`](crate::net::TcpListener::accept) +/// * [`tokio::net::UnixListener::accept`](crate::net::UnixListener::accept) +/// * [`tokio::io::AsyncReadExt::read`](crate::io::AsyncReadExt::read) on any `AsyncRead` +/// * [`tokio::io::AsyncReadExt::read_buf`](crate::io::AsyncReadExt::read_buf) on any `AsyncRead` +/// * [`tokio::io::AsyncWriteExt::write`](crate::io::AsyncWriteExt::write) on any `AsyncWrite` +/// * [`tokio::io::AsyncWriteExt::write_buf`](crate::io::AsyncWriteExt::write_buf) on any `AsyncWrite` +/// * [`tokio_stream::StreamExt::next`](https://docs.rs/tokio-stream/0.1/tokio_stream/trait.StreamExt.html#method.next) on any `Stream` +/// * [`futures::stream::StreamExt::next`](https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.next) on any `Stream` +/// +/// The following methods are not cancellation safe and can lead to loss of data: +/// +/// * [`tokio::io::AsyncReadExt::read_exact`](crate::io::AsyncReadExt::read_exact) +/// * [`tokio::io::AsyncReadExt::read_to_end`](crate::io::AsyncReadExt::read_to_end) +/// * [`tokio::io::AsyncReadExt::read_to_string`](crate::io::AsyncReadExt::read_to_string) +/// * [`tokio::io::AsyncWriteExt::write_all`](crate::io::AsyncWriteExt::write_all) +/// +/// The following methods are not cancellation safe because they use a queue for +/// fairness and cancellation makes you lose your place in the queue: +/// +/// * [`tokio::sync::Mutex::lock`](crate::sync::Mutex::lock) +/// * [`tokio::sync::RwLock::read`](crate::sync::RwLock::read) +/// * [`tokio::sync::RwLock::write`](crate::sync::RwLock::write) +/// * [`tokio::sync::Semaphore::acquire`](crate::sync::Semaphore::acquire) +/// * [`tokio::sync::Notify::notified`](crate::sync::Notify::notified) +/// +/// To determine whether your own methods are cancellation safe, look for the +/// location of uses of `.await`. This is because when an asynchronous method is +/// cancelled, that always happens at an `.await`. If your function behaves +/// correctly even if it is restarted while waiting at an `.await`, then it is +/// cancellation safe. +/// +/// Be aware that cancelling something that is not cancellation safe is not +/// necessarily wrong. For example, if you are cancelling a task because the +/// application is shutting down, then you probably don't care that partially +/// read data is lost. +/// +/// # Examples +/// +/// Basic select with two branches. +/// +/// ``` +/// async fn do_stuff_async() { +/// // async work +/// } +/// +/// async fn more_async_work() { +/// // more here +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// tokio::select! { +/// _ = do_stuff_async() => { +/// println!("do_stuff_async() completed first") +/// } +/// _ = more_async_work() => { +/// println!("more_async_work() completed first") +/// } +/// }; +/// } +/// ``` +/// +/// Basic stream selecting. +/// +/// ``` +/// use tokio_stream::{self as stream, StreamExt}; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut stream1 = stream::iter(vec![1, 2, 3]); +/// let mut stream2 = stream::iter(vec![4, 5, 6]); +/// +/// let next = tokio::select! { +/// v = stream1.next() => v.unwrap(), +/// v = stream2.next() => v.unwrap(), +/// }; +/// +/// assert!(next == 1 || next == 4); +/// } +/// ``` +/// +/// Collect the contents of two streams. In this example, we rely on pattern +/// matching and the fact that `stream::iter` is "fused", i.e. once the stream +/// is complete, all calls to `next()` return `None`. +/// +/// ``` +/// use tokio_stream::{self as stream, StreamExt}; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut stream1 = stream::iter(vec![1, 2, 3]); +/// let mut stream2 = stream::iter(vec![4, 5, 6]); +/// +/// let mut values = vec![]; +/// +/// loop { +/// tokio::select! { +/// Some(v) = stream1.next() => values.push(v), +/// Some(v) = stream2.next() => values.push(v), +/// else => break, +/// } +/// } +/// +/// values.sort(); +/// assert_eq!(&[1, 2, 3, 4, 5, 6], &values[..]); +/// } +/// ``` +/// +/// Using the same future in multiple `select!` expressions can be done by passing +/// a reference to the future. Doing so requires the future to be [`Unpin`]. A +/// future can be made [`Unpin`] by either using [`Box::pin`] or stack pinning. +/// +/// [`Unpin`]: std::marker::Unpin +/// [`Box::pin`]: std::boxed::Box::pin +/// +/// Here, a stream is consumed for at most 1 second. +/// +/// ``` +/// use tokio_stream::{self as stream, StreamExt}; +/// use tokio::time::{self, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut stream = stream::iter(vec![1, 2, 3]); +/// let sleep = time::sleep(Duration::from_secs(1)); +/// tokio::pin!(sleep); +/// +/// loop { +/// tokio::select! { +/// maybe_v = stream.next() => { +/// if let Some(v) = maybe_v { +/// println!("got = {}", v); +/// } else { +/// break; +/// } +/// } +/// _ = &mut sleep => { +/// println!("timeout"); +/// break; +/// } +/// } +/// } +/// } +/// ``` +/// +/// Joining two values using `select!`. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx1, mut rx1) = oneshot::channel(); +/// let (tx2, mut rx2) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// tx1.send("first").unwrap(); +/// }); +/// +/// tokio::spawn(async move { +/// tx2.send("second").unwrap(); +/// }); +/// +/// let mut a = None; +/// let mut b = None; +/// +/// while a.is_none() || b.is_none() { +/// tokio::select! { +/// v1 = (&mut rx1), if a.is_none() => a = Some(v1.unwrap()), +/// v2 = (&mut rx2), if b.is_none() => b = Some(v2.unwrap()), +/// } +/// } +/// +/// let res = (a.unwrap(), b.unwrap()); +/// +/// assert_eq!(res.0, "first"); +/// assert_eq!(res.1, "second"); +/// } +/// ``` +/// +/// Using the `biased;` mode to control polling order. +/// +/// ``` +/// #[tokio::main] +/// async fn main() { +/// let mut count = 0u8; +/// +/// loop { +/// tokio::select! { +/// // If you run this example without `biased;`, the polling order is +/// // pseudo-random, and the assertions on the value of count will +/// // (probably) fail. +/// biased; +/// +/// _ = async {}, if count < 1 => { +/// count += 1; +/// assert_eq!(count, 1); +/// } +/// _ = async {}, if count < 2 => { +/// count += 1; +/// assert_eq!(count, 2); +/// } +/// _ = async {}, if count < 3 => { +/// count += 1; +/// assert_eq!(count, 3); +/// } +/// _ = async {}, if count < 4 => { +/// count += 1; +/// assert_eq!(count, 4); +/// } +/// +/// else => { +/// break; +/// } +/// }; +/// } +/// } +/// ``` +/// +/// ## Avoid racy `if` preconditions +/// +/// Given that `if` preconditions are used to disable `select!` branches, some +/// caution must be used to avoid missing values. +/// +/// For example, here is **incorrect** usage of `sleep` with `if`. The objective +/// is to repeatedly run an asynchronous task for up to 50 milliseconds. +/// However, there is a potential for the `sleep` completion to be missed. +/// +/// ```no_run,should_panic +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// while !sleep.is_elapsed() { +/// tokio::select! { +/// _ = &mut sleep, if !sleep.is_elapsed() => { +/// println!("operation timed out"); +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// +/// panic!("This example shows how not to do it!"); +/// } +/// ``` +/// +/// In the above example, `sleep.is_elapsed()` may return `true` even if +/// `sleep.poll()` never returned `Ready`. This opens up a potential race +/// condition where `sleep` expires between the `while !sleep.is_elapsed()` +/// check and the call to `select!` resulting in the `some_async_work()` call to +/// run uninterrupted despite the sleep having elapsed. +/// +/// One way to write the above example without the race would be: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// # time::sleep(Duration::from_millis(10)).await; +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// loop { +/// tokio::select! { +/// _ = &mut sleep => { +/// println!("operation timed out"); +/// break; +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// } +/// ``` +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! select { + // Uses a declarative macro to do **most** of the work. While it is possible + // to implement fully with a declarative macro, a procedural macro is used + // to enable improved error messages. + // + // The macro is structured as a tt-muncher. All branches are processed and + // normalized. Once the input is normalized, it is passed to the top-most + // rule. When entering the macro, `@{ }` is inserted at the front. This is + // used to collect the normalized input. + // + // The macro only recurses once per branch. This allows using `select!` + // without requiring the user to increase the recursion limit. + + // All input is normalized, now transform. + (@ { + // The index of the future to poll first (in bias mode), or the RNG + // expression to use to pick a future to poll first. + start=$start:expr; + + // One `_` for each branch in the `select!` macro. Passing this to + // `count!` converts $skip to an integer. + ( $($count:tt)* ) + + // Normalized select branches. `( $skip )` is a set of `_` characters. + // There is one `_` for each select branch **before** this one. Given + // that all input futures are stored in a tuple, $skip is useful for + // generating a pattern to reference the future for the current branch. + // $skip is also used as an argument to `count!`, returning the index of + // the current select branch. + $( ( $($skip:tt)* ) $bind:pat = $fut:expr, if $c:expr => $handle:expr, )+ + + // Fallback expression used when all select branches have been disabled. + ; $else:expr + + }) => {{ + // Enter a context where stable "function-like" proc macros can be used. + // + // This module is defined within a scope and should not leak out of this + // macro. + mod util { + // Generate an enum with one variant per select branch + $crate::select_priv_declare_output_enum!( ( $($count)* ) ); + } + + // `tokio::macros::support` is a public, but doc(hidden) module + // including a re-export of all types needed by this macro. + use $crate::macros::support::Future; + use $crate::macros::support::Pin; + use $crate::macros::support::Poll::{Ready, Pending}; + + const BRANCHES: u32 = $crate::count!( $($count)* ); + + let mut disabled: util::Mask = Default::default(); + + // First, invoke all the pre-conditions. For any that return true, + // set the appropriate bit in `disabled`. + $( + if !$c { + let mask: util::Mask = 1 << $crate::count!( $($skip)* ); + disabled |= mask; + } + )* + + // Create a scope to separate polling from handling the output. This + // adds borrow checker flexibility when using the macro. + let mut output = { + // Safety: Nothing must be moved out of `futures`. This is to + // satisfy the requirement of `Pin::new_unchecked` called below. + let mut futures = ( $( $fut , )+ ); + + $crate::macros::support::poll_fn(|cx| { + // Track if any branch returns pending. If no branch completes + // **or** returns pending, this implies that all branches are + // disabled. + let mut is_pending = false; + + // Choose a starting index to begin polling the futures at. In + // practice, this will either be a pseudo-randomly generated + // number by default, or the constant 0 if `biased;` is + // supplied. + let start = $start; + + for i in 0..BRANCHES { + let branch; + #[allow(clippy::modulo_one)] + { + branch = (start + i) % BRANCHES; + } + match branch { + $( + #[allow(unreachable_code)] + $crate::count!( $($skip)* ) => { + // First, if the future has previously been + // disabled, do not poll it again. This is done + // by checking the associated bit in the + // `disabled` bit field. + let mask = 1 << branch; + + if disabled & mask == mask { + // The future has been disabled. + continue; + } + + // Extract the future for this branch from the + // tuple + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling it + let out = match Future::poll(fut, cx) { + Ready(out) => out, + Pending => { + // Track that at least one future is + // still pending and continue polling. + is_pending = true; + continue; + } + }; + + // Disable the future from future polling. + disabled |= mask; + + // The future returned a value, check if matches + // the specified pattern. + #[allow(unused_variables)] + #[allow(unused_mut)] + match &out { + $crate::select_priv_clean_pattern!($bind) => {} + _ => continue, + } + + // The select is complete, return the value + return Ready($crate::select_variant!(util::Out, ($($skip)*))(out)); + } + )* + _ => unreachable!("reaching this means there probably is an off by one bug"), + } + } + + if is_pending { + Pending + } else { + // All branches have been disabled. + Ready(util::Out::Disabled) + } + }).await + }; + + match output { + $( + $crate::select_variant!(util::Out, ($($skip)*) ($bind)) => $handle, + )* + util::Out::Disabled => $else, + _ => unreachable!("failed to match bind"), + } + }}; + + // ==== Normalize ===== + + // These rules match a single `select!` branch and normalize it for + // processing by the first rule. + + (@ { start=$start:expr; $($t:tt)* } ) => { + // No `else` branch + $crate::select!(@{ start=$start; $($t)*; panic!("all branches are disabled and there is no else branch") }) + }; + (@ { start=$start:expr; $($t:tt)* } else => $else:expr $(,)?) => { + $crate::select!(@{ start=$start; $($t)*; $else }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:block, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:block $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:expr ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, }) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr, if $c:expr => $h:expr, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if $c => $h, } $($r)*) + }; + (@ { start=$start:expr; ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr, $($r:tt)* ) => { + $crate::select!(@{ start=$start; ($($s)* _) $($t)* ($($s)*) $p = $f, if true => $h, } $($r)*) + }; + + // ===== Entry point ===== + + (biased; $p:pat = $($t:tt)* ) => { + $crate::select!(@{ start=0; () } $p = $($t)*) + }; + + ( $p:pat = $($t:tt)* ) => { + // Randomly generate a starting point. This makes `select!` a bit more + // fair and avoids always polling the first future. + $crate::select!(@{ start={ $crate::macros::support::thread_rng_n(BRANCHES) }; () } $p = $($t)*) + }; + () => { + compile_error!("select! requires at least one branch.") + }; +} + +// And here... we manually list out matches for up to 64 branches... I'm not +// happy about it either, but this is how we manage to use a declarative macro! + +#[macro_export] +#[doc(hidden)] +macro_rules! count { + () => { + 0 + }; + (_) => { + 1 + }; + (_ _) => { + 2 + }; + (_ _ _) => { + 3 + }; + (_ _ _ _) => { + 4 + }; + (_ _ _ _ _) => { + 5 + }; + (_ _ _ _ _ _) => { + 6 + }; + (_ _ _ _ _ _ _) => { + 7 + }; + (_ _ _ _ _ _ _ _) => { + 8 + }; + (_ _ _ _ _ _ _ _ _) => { + 9 + }; + (_ _ _ _ _ _ _ _ _ _) => { + 10 + }; + (_ _ _ _ _ _ _ _ _ _ _) => { + 11 + }; + (_ _ _ _ _ _ _ _ _ _ _ _) => { + 12 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _) => { + 13 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 14 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 15 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 16 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 17 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 18 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 19 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 20 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 21 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 22 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 23 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 24 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 25 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 26 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 27 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 28 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 29 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 30 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 31 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 32 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 33 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 34 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 35 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 36 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 37 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 38 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 39 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 40 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 41 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 42 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 43 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 44 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 45 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 46 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 47 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 48 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 49 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 50 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 51 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 52 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 53 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 54 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 55 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 56 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 57 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 58 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 59 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 60 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 61 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 62 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 63 + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! select_variant { + ($($p:ident)::*, () $($t:tt)*) => { + $($p)::*::_0 $($t)* + }; + ($($p:ident)::*, (_) $($t:tt)*) => { + $($p)::*::_1 $($t)* + }; + ($($p:ident)::*, (_ _) $($t:tt)*) => { + $($p)::*::_2 $($t)* + }; + ($($p:ident)::*, (_ _ _) $($t:tt)*) => { + $($p)::*::_3 $($t)* + }; + ($($p:ident)::*, (_ _ _ _) $($t:tt)*) => { + $($p)::*::_4 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _) $($t:tt)*) => { + $($p)::*::_5 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_6 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_7 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_8 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_9 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_10 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_11 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_12 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_13 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_14 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_15 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_16 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_17 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_18 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_19 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_20 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_21 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_22 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_23 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_24 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_25 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_26 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_27 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_28 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_29 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_30 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_31 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_32 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_33 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_34 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_35 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_36 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_37 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_38 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_39 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_40 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_41 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_42 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_43 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_44 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_45 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_46 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_47 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_48 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_49 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_50 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_51 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_52 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_53 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_54 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_55 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_56 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_57 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_58 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_59 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_60 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_61 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_62 $($t)* + }; + ($($p:ident)::*, (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) $($t:tt)*) => { + $($p)::*::_63 $($t)* + }; +} diff --git a/third_party/rust/tokio/src/macros/support.rs b/third_party/rust/tokio/src/macros/support.rs new file mode 100644 index 0000000000..7f11bc6800 --- /dev/null +++ b/third_party/rust/tokio/src/macros/support.rs @@ -0,0 +1,9 @@ +cfg_macros! { + pub use crate::future::poll_fn; + pub use crate::future::maybe_done::maybe_done; + pub use crate::util::thread_rng_n; +} + +pub use std::future::Future; +pub use std::pin::Pin; +pub use std::task::Poll; diff --git a/third_party/rust/tokio/src/macros/thread_local.rs b/third_party/rust/tokio/src/macros/thread_local.rs new file mode 100644 index 0000000000..d848947350 --- /dev/null +++ b/third_party/rust/tokio/src/macros/thread_local.rs @@ -0,0 +1,4 @@ +#[cfg(all(loom, test))] +macro_rules! thread_local { + ($($tts:tt)+) => { loom::thread_local!{ $($tts)+ } } +} diff --git a/third_party/rust/tokio/src/macros/trace.rs b/third_party/rust/tokio/src/macros/trace.rs new file mode 100644 index 0000000000..80a257e189 --- /dev/null +++ b/third_party/rust/tokio/src/macros/trace.rs @@ -0,0 +1,26 @@ +cfg_trace! { + macro_rules! trace_op { + ($name:expr, $readiness:literal) => { + tracing::trace!( + target: "runtime::resource::poll_op", + op_name = $name, + is_ready = $readiness + ); + } + } + + macro_rules! trace_poll_op { + ($name:expr, $poll:expr $(,)*) => { + match $poll { + std::task::Poll::Ready(t) => { + trace_op!($name, true); + std::task::Poll::Ready(t) + } + std::task::Poll::Pending => { + trace_op!($name, false); + return std::task::Poll::Pending; + } + } + }; + } +} diff --git a/third_party/rust/tokio/src/macros/try_join.rs b/third_party/rust/tokio/src/macros/try_join.rs new file mode 100644 index 0000000000..6d3a893b7e --- /dev/null +++ b/third_party/rust/tokio/src/macros/try_join.rs @@ -0,0 +1,171 @@ +/// Waits on multiple concurrent branches, returning when **all** branches +/// complete with `Ok(_)` or on the first `Err(_)`. +/// +/// The `try_join!` macro must be used inside of async functions, closures, and +/// blocks. +/// +/// Similar to [`join!`], the `try_join!` macro takes a list of async +/// expressions and evaluates them concurrently on the same task. Each async +/// expression evaluates to a future and the futures from each expression are +/// multiplexed on the current task. The `try_join!` macro returns when **all** +/// branches return with `Ok` or when the **first** branch returns with `Err`. +/// +/// [`join!`]: macro@join +/// +/// # Notes +/// +/// The supplied futures are stored inline and does not require allocating a +/// `Vec`. +/// +/// ### Runtime characteristics +/// +/// By running all async expressions on the current task, the expressions are +/// able to run **concurrently** but not in **parallel**. This means all +/// expressions are run on the same thread and if one branch blocks the thread, +/// all other expressions will be unable to continue. If parallelism is +/// required, spawn each async expression using [`tokio::spawn`] and pass the +/// join handle to `try_join!`. +/// +/// [`tokio::spawn`]: crate::spawn +/// +/// # Examples +/// +/// Basic try_join with two branches. +/// +/// ``` +/// async fn do_stuff_async() -> Result<(), &'static str> { +/// // async work +/// # Ok(()) +/// } +/// +/// async fn more_async_work() -> Result<(), &'static str> { +/// // more here +/// # Ok(()) +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let res = tokio::try_join!( +/// do_stuff_async(), +/// more_async_work()); +/// +/// match res { +/// Ok((first, second)) => { +/// // do something with the values +/// } +/// Err(err) => { +/// println!("processing failed; error = {}", err); +/// } +/// } +/// } +/// ``` +/// +/// Using `try_join!` with spawned tasks. +/// +/// ``` +/// use tokio::task::JoinHandle; +/// +/// async fn do_stuff_async() -> Result<(), &'static str> { +/// // async work +/// # Err("failed") +/// } +/// +/// async fn more_async_work() -> Result<(), &'static str> { +/// // more here +/// # Ok(()) +/// } +/// +/// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> { +/// match handle.await { +/// Ok(Ok(result)) => Ok(result), +/// Ok(Err(err)) => Err(err), +/// Err(err) => Err("handling failed"), +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let handle1 = tokio::spawn(do_stuff_async()); +/// let handle2 = tokio::spawn(more_async_work()); +/// match tokio::try_join!(flatten(handle1), flatten(handle2)) { +/// Ok(val) => { +/// // do something with the values +/// } +/// Err(err) => { +/// println!("Failed with {}.", err); +/// # assert_eq!(err, "failed"); +/// } +/// } +/// } +/// ``` +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! try_join { + (@ { + // One `_` for each branch in the `try_join!` macro. This is not used once + // normalization is complete. + ( $($count:tt)* ) + + // Normalized try_join! branches + $( ( $($skip:tt)* ) $e:expr, )* + + }) => {{ + use $crate::macros::support::{maybe_done, poll_fn, Future, Pin}; + use $crate::macros::support::Poll::{Ready, Pending}; + + // Safety: nothing must be moved out of `futures`. This is to satisfy + // the requirement of `Pin::new_unchecked` called below. + let mut futures = ( $( maybe_done($e), )* ); + + poll_fn(move |cx| { + let mut is_pending = false; + + $( + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling + if fut.as_mut().poll(cx).is_pending() { + is_pending = true; + } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { + return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) + } + )* + + if is_pending { + Pending + } else { + Ready(Ok(($({ + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + fut + .take_output() + .expect("expected completed future") + .ok() + .expect("expected Ok(_)") + },)*))) + } + }).await + }}; + + // ===== Normalize ===== + + (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::try_join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + }; + + // ===== Entry point ===== + + ( $($e:expr),* $(,)?) => { + $crate::try_join!(@{ () } $($e,)*) + }; +} diff --git a/third_party/rust/tokio/src/net/addr.rs b/third_party/rust/tokio/src/net/addr.rs new file mode 100644 index 0000000000..13f743c962 --- /dev/null +++ b/third_party/rust/tokio/src/net/addr.rs @@ -0,0 +1,318 @@ +use std::future; +use std::io; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +/// Converts or resolves without blocking to one or more `SocketAddr` values. +/// +/// # DNS +/// +/// Implementations of `ToSocketAddrs` for string types require a DNS lookup. +/// +/// # Calling +/// +/// Currently, this trait is only used as an argument to Tokio functions that +/// need to reference a target socket address. To perform a `SocketAddr` +/// conversion directly, use [`lookup_host()`](super::lookup_host()). +/// +/// This trait is sealed and is intended to be opaque. The details of the trait +/// will change. Stabilization is pending enhancements to the Rust language. +pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {} + +type ReadyFuture<T> = future::Ready<io::Result<T>>; + +cfg_net! { + pub(crate) fn to_socket_addrs<T>(arg: T) -> T::Future + where + T: ToSocketAddrs, + { + arg.to_socket_addrs(sealed::Internal) + } +} + +// ===== impl &impl ToSocketAddrs ===== + +impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {} + +impl<T> sealed::ToSocketAddrsPriv for &T +where + T: sealed::ToSocketAddrsPriv + ?Sized, +{ + type Iter = T::Iter; + type Future = T::Future; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (**self).to_socket_addrs(sealed::Internal) + } +} + +// ===== impl SocketAddr ===== + +impl ToSocketAddrs for SocketAddr {} + +impl sealed::ToSocketAddrsPriv for SocketAddr { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + let iter = Some(*self).into_iter(); + future::ready(Ok(iter)) + } +} + +// ===== impl SocketAddrV4 ===== + +impl ToSocketAddrs for SocketAddrV4 {} + +impl sealed::ToSocketAddrsPriv for SocketAddrV4 { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + SocketAddr::V4(*self).to_socket_addrs(sealed::Internal) + } +} + +// ===== impl SocketAddrV6 ===== + +impl ToSocketAddrs for SocketAddrV6 {} + +impl sealed::ToSocketAddrsPriv for SocketAddrV6 { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + SocketAddr::V6(*self).to_socket_addrs(sealed::Internal) + } +} + +// ===== impl (IpAddr, u16) ===== + +impl ToSocketAddrs for (IpAddr, u16) {} + +impl sealed::ToSocketAddrsPriv for (IpAddr, u16) { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + let iter = Some(SocketAddr::from(*self)).into_iter(); + future::ready(Ok(iter)) + } +} + +// ===== impl (Ipv4Addr, u16) ===== + +impl ToSocketAddrs for (Ipv4Addr, u16) {} + +impl sealed::ToSocketAddrsPriv for (Ipv4Addr, u16) { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + let (ip, port) = *self; + SocketAddrV4::new(ip, port).to_socket_addrs(sealed::Internal) + } +} + +// ===== impl (Ipv6Addr, u16) ===== + +impl ToSocketAddrs for (Ipv6Addr, u16) {} + +impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) { + type Iter = std::option::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + let (ip, port) = *self; + SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs(sealed::Internal) + } +} + +// ===== impl &[SocketAddr] ===== + +impl ToSocketAddrs for &[SocketAddr] {} + +impl sealed::ToSocketAddrsPriv for &[SocketAddr] { + type Iter = std::vec::IntoIter<SocketAddr>; + type Future = ReadyFuture<Self::Iter>; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + let iter = self.to_vec().into_iter(); + future::ready(Ok(iter)) + } +} + +cfg_net! { + // ===== impl str ===== + + impl ToSocketAddrs for str {} + + impl sealed::ToSocketAddrsPriv for str { + type Iter = sealed::OneOrMore; + type Future = sealed::MaybeReady; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + use crate::blocking::spawn_blocking; + use sealed::MaybeReady; + + // First check if the input parses as a socket address + let res: Result<SocketAddr, _> = self.parse(); + + if let Ok(addr) = res { + return MaybeReady(sealed::State::Ready(Some(addr))); + } + + // Run DNS lookup on the blocking pool + let s = self.to_owned(); + + MaybeReady(sealed::State::Blocking(spawn_blocking(move || { + std::net::ToSocketAddrs::to_socket_addrs(&s) + }))) + } + } + + // ===== impl (&str, u16) ===== + + impl ToSocketAddrs for (&str, u16) {} + + impl sealed::ToSocketAddrsPriv for (&str, u16) { + type Iter = sealed::OneOrMore; + type Future = sealed::MaybeReady; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + use crate::blocking::spawn_blocking; + use sealed::MaybeReady; + + let (host, port) = *self; + + // try to parse the host as a regular IP address first + if let Ok(addr) = host.parse::<Ipv4Addr>() { + let addr = SocketAddrV4::new(addr, port); + let addr = SocketAddr::V4(addr); + + return MaybeReady(sealed::State::Ready(Some(addr))); + } + + if let Ok(addr) = host.parse::<Ipv6Addr>() { + let addr = SocketAddrV6::new(addr, port, 0, 0); + let addr = SocketAddr::V6(addr); + + return MaybeReady(sealed::State::Ready(Some(addr))); + } + + let host = host.to_owned(); + + MaybeReady(sealed::State::Blocking(spawn_blocking(move || { + std::net::ToSocketAddrs::to_socket_addrs(&(&host[..], port)) + }))) + } + } + + // ===== impl (String, u16) ===== + + impl ToSocketAddrs for (String, u16) {} + + impl sealed::ToSocketAddrsPriv for (String, u16) { + type Iter = sealed::OneOrMore; + type Future = sealed::MaybeReady; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (self.0.as_str(), self.1).to_socket_addrs(sealed::Internal) + } + } + + // ===== impl String ===== + + impl ToSocketAddrs for String {} + + impl sealed::ToSocketAddrsPriv for String { + type Iter = <str as sealed::ToSocketAddrsPriv>::Iter; + type Future = <str as sealed::ToSocketAddrsPriv>::Future; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (&self[..]).to_socket_addrs(sealed::Internal) + } + } +} + +pub(crate) mod sealed { + //! The contents of this trait are intended to remain private and __not__ + //! part of the `ToSocketAddrs` public API. The details will change over + //! time. + + use std::future::Future; + use std::io; + use std::net::SocketAddr; + + #[doc(hidden)] + pub trait ToSocketAddrsPriv { + type Iter: Iterator<Item = SocketAddr> + Send + 'static; + type Future: Future<Output = io::Result<Self::Iter>> + Send + 'static; + + fn to_socket_addrs(&self, internal: Internal) -> Self::Future; + } + + #[allow(missing_debug_implementations)] + pub struct Internal; + + cfg_net! { + use crate::blocking::JoinHandle; + + use std::option; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::vec; + + #[doc(hidden)] + #[derive(Debug)] + pub struct MaybeReady(pub(super) State); + + #[derive(Debug)] + pub(super) enum State { + Ready(Option<SocketAddr>), + Blocking(JoinHandle<io::Result<vec::IntoIter<SocketAddr>>>), + } + + #[doc(hidden)] + #[derive(Debug)] + pub enum OneOrMore { + One(option::IntoIter<SocketAddr>), + More(vec::IntoIter<SocketAddr>), + } + + impl Future for MaybeReady { + type Output = io::Result<OneOrMore>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match self.0 { + State::Ready(ref mut i) => { + let iter = OneOrMore::One(i.take().into_iter()); + Poll::Ready(Ok(iter)) + } + State::Blocking(ref mut rx) => { + let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More); + + Poll::Ready(res) + } + } + } + } + + impl Iterator for OneOrMore { + type Item = SocketAddr; + + fn next(&mut self) -> Option<Self::Item> { + match self { + OneOrMore::One(i) => i.next(), + OneOrMore::More(i) => i.next(), + } + } + + fn size_hint(&self) -> (usize, Option<usize>) { + match self { + OneOrMore::One(i) => i.size_hint(), + OneOrMore::More(i) => i.size_hint(), + } + } + } + } +} diff --git a/third_party/rust/tokio/src/net/lookup_host.rs b/third_party/rust/tokio/src/net/lookup_host.rs new file mode 100644 index 0000000000..28861849e4 --- /dev/null +++ b/third_party/rust/tokio/src/net/lookup_host.rs @@ -0,0 +1,38 @@ +cfg_net! { + use crate::net::addr::{self, ToSocketAddrs}; + + use std::io; + use std::net::SocketAddr; + + /// Performs a DNS resolution. + /// + /// The returned iterator may not actually yield any values depending on the + /// outcome of any resolution performed. + /// + /// This API is not intended to cover all DNS use cases. Anything beyond the + /// basic use case should be done with a specialized library. + /// + /// # Examples + /// + /// To resolve a DNS entry: + /// + /// ```no_run + /// use tokio::net; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// for addr in net::lookup_host("localhost:3000").await? { + /// println!("socket address is {}", addr); + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn lookup_host<T>(host: T) -> io::Result<impl Iterator<Item = SocketAddr>> + where + T: ToSocketAddrs + { + addr::to_socket_addrs(host).await + } +} diff --git a/third_party/rust/tokio/src/net/mod.rs b/third_party/rust/tokio/src/net/mod.rs new file mode 100644 index 0000000000..0b8c1ecd19 --- /dev/null +++ b/third_party/rust/tokio/src/net/mod.rs @@ -0,0 +1,52 @@ +#![cfg(not(loom))] + +//! TCP/UDP/Unix bindings for `tokio`. +//! +//! This module contains the TCP/UDP/Unix networking types, similar to the standard +//! library, which can be used to implement networking protocols. +//! +//! # Organization +//! +//! * [`TcpListener`] and [`TcpStream`] provide functionality for communication over TCP +//! * [`UdpSocket`] provides functionality for communication over UDP +//! * [`UnixListener`] and [`UnixStream`] provide functionality for communication over a +//! Unix Domain Stream Socket **(available on Unix only)** +//! * [`UnixDatagram`] provides functionality for communication +//! over Unix Domain Datagram Socket **(available on Unix only)** + +//! +//! [`TcpListener`]: TcpListener +//! [`TcpStream`]: TcpStream +//! [`UdpSocket`]: UdpSocket +//! [`UnixListener`]: UnixListener +//! [`UnixStream`]: UnixStream +//! [`UnixDatagram`]: UnixDatagram + +mod addr; +#[cfg(feature = "net")] +pub(crate) use addr::to_socket_addrs; +pub use addr::ToSocketAddrs; + +cfg_net! { + mod lookup_host; + pub use lookup_host::lookup_host; + + pub mod tcp; + pub use tcp::listener::TcpListener; + pub use tcp::socket::TcpSocket; + pub use tcp::stream::TcpStream; + + mod udp; + pub use udp::UdpSocket; +} + +cfg_net_unix! { + pub mod unix; + pub use unix::datagram::socket::UnixDatagram; + pub use unix::listener::UnixListener; + pub use unix::stream::UnixStream; +} + +cfg_net_windows! { + pub mod windows; +} diff --git a/third_party/rust/tokio/src/net/tcp/listener.rs b/third_party/rust/tokio/src/net/tcp/listener.rs new file mode 100644 index 0000000000..8aecb21aaa --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/listener.rs @@ -0,0 +1,397 @@ +use crate::io::{Interest, PollEvented}; +use crate::net::tcp::TcpStream; +use crate::net::{to_socket_addrs, ToSocketAddrs}; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::{self, SocketAddr}; +use std::task::{Context, Poll}; + +cfg_net! { + /// A TCP socket server, listening for connections. + /// + /// You can accept a new connection by using the [`accept`](`TcpListener::accept`) + /// method. + /// + /// A `TcpListener` can be turned into a `Stream` with [`TcpListenerStream`]. + /// + /// [`TcpListenerStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.TcpListenerStream.html + /// + /// # Errors + /// + /// Note that accepting a connection can lead to various errors and not all + /// of them are necessarily fatal ‒ for example having too many open file + /// descriptors or the other side closing the connection while it waits in + /// an accept queue. These would terminate the stream if not handled in any + /// way. + /// + /// # Examples + /// + /// Using `accept`: + /// ```no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// + /// async fn process_socket<T>(socket: T) { + /// # drop(socket); + /// // do work with socket here + /// } + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// loop { + /// let (socket, _) = listener.accept().await?; + /// process_socket(socket).await; + /// } + /// } + /// ``` + pub struct TcpListener { + io: PollEvented<mio::net::TcpListener>, + } +} + +impl TcpListener { + /// Creates a new TcpListener, which will be bound to the specified address. + /// + /// The returned listener is ready for accepting connections. + /// + /// Binding with a port number of 0 will request that the OS assigns a port + /// to this listener. The port allocated can be queried via the `local_addr` + /// method. + /// + /// The address type can be any implementor of the [`ToSocketAddrs`] trait. + /// If `addr` yields multiple addresses, bind will be attempted with each of + /// the addresses until one succeeds and returns the listener. If none of + /// the addresses succeed in creating a listener, the error returned from + /// the last attempt (the last address) is returned. + /// + /// This function sets the `SO_REUSEADDR` option on the socket. + /// + /// To configure the socket before binding, you can use the [`TcpSocket`] + /// type. + /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// [`TcpSocket`]: struct@crate::net::TcpSocket + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:2345").await?; + /// + /// // use the listener + /// + /// # let _ = listener; + /// Ok(()) + /// } + /// ``` + pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> { + let addrs = to_socket_addrs(addr).await?; + + let mut last_err = None; + + for addr in addrs { + match TcpListener::bind_addr(addr) { + Ok(listener) => return Ok(listener), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) + } + + fn bind_addr(addr: SocketAddr) -> io::Result<TcpListener> { + let listener = mio::net::TcpListener::bind(addr)?; + TcpListener::new(listener) + } + + /// Accepts a new incoming connection from this listener. + /// + /// This function will yield once a new TCP connection is established. When + /// established, the corresponding [`TcpStream`] and the remote peer's + /// address will be returned. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. + /// + /// [`TcpStream`]: struct@crate::net::TcpStream + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// match listener.accept().await { + /// Ok((_socket, addr)) => println!("new client: {:?}", addr), + /// Err(e) => println!("couldn't get client: {:?}", e), + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let (mio, addr) = self + .io + .registration() + .async_io(Interest::READABLE, || self.io.accept()) + .await?; + + let stream = TcpStream::new(mio)?; + Ok((stream, addr)) + } + + /// Polls to accept a new incoming connection to this listener. + /// + /// If there is no connection to accept, `Poll::Pending` is returned and the + /// current task will be notified by a waker. Note that on multiple calls + /// to `poll_accept`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> { + loop { + let ev = ready!(self.io.registration().poll_read_ready(cx))?; + + match self.io.accept() { + Ok((io, addr)) => { + let io = TcpStream::new(io)?; + return Poll::Ready(Ok((io, addr))); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.registration().clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + /// Creates new `TcpListener` from a `std::net::TcpListener`. + /// + /// This function is intended to be used to wrap a TCP listener from the + /// standard library in the Tokio equivalent. The conversion assumes nothing + /// about the underlying listener; it is left up to the user to set it in + /// non-blocking mode. + /// + /// This API is typically paired with the `socket2` crate and the `Socket` + /// type to build up and customize a listener before it's shipped off to the + /// backing event loop. This allows configuration of options like + /// `SO_REUSEPORT`, binding to multiple addresses, etc. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// use tokio::net::TcpListener; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let std_listener = std::net::TcpListener::bind("127.0.0.1:0")?; + /// std_listener.set_nonblocking(true)?; + /// let listener = TcpListener::from_std(std_listener)?; + /// Ok(()) + /// } + /// ``` + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> { + let io = mio::net::TcpListener::from_std(listener); + let io = PollEvented::new(io)?; + Ok(TcpListener { io }) + } + + /// Turns a [`tokio::net::TcpListener`] into a [`std::net::TcpListener`]. + /// + /// The returned [`std::net::TcpListener`] will have nonblocking mode set as + /// `true`. Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let tokio_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + /// let std_listener = tokio_listener.into_std()?; + /// std_listener.set_nonblocking(false)?; + /// Ok(()) + /// } + /// ``` + /// + /// [`tokio::net::TcpListener`]: TcpListener + /// [`std::net::TcpListener`]: std::net::TcpListener + /// [`set_nonblocking`]: fn@std::net::TcpListener::set_nonblocking + pub fn into_std(self) -> io::Result<std::net::TcpListener> { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::net::TcpListener::from_raw_fd(raw_fd) }) + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + self.io + .into_inner() + .map(|io| io.into_raw_socket()) + .map(|raw_socket| unsafe { std::net::TcpListener::from_raw_socket(raw_socket) }) + } + } + + pub(crate) fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> { + let io = PollEvented::new(listener)?; + Ok(TcpListener { io }) + } + + /// Returns the local address that this listener is bound to. + /// + /// This can be useful, for example, when binding to port 0 to figure out + /// which port was actually bound. + /// + /// # Examples + /// + /// ```rust,no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// assert_eq!(listener.local_addr()?, + /// SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))); + /// + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr() + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: method@Self::set_ttl + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// + /// listener.set_ttl(100).expect("could not set TTL"); + /// assert_eq!(listener.ttl()?, 100); + /// + /// Ok(()) + /// } + /// ``` + pub fn ttl(&self) -> io::Result<u32> { + self.io.ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpListener; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// + /// listener.set_ttl(100).expect("could not set TTL"); + /// + /// Ok(()) + /// } + /// ``` + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io.set_ttl(ttl) + } +} + +impl TryFrom<net::TcpListener> for TcpListener { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`TcpListener::from_std(stream)`](TcpListener::from_std). + fn try_from(stream: net::TcpListener) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +#[cfg(unix)] +mod sys { + use super::TcpListener; + use std::os::unix::prelude::*; + + impl AsRawFd for TcpListener { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } + } +} + +#[cfg(windows)] +mod sys { + use super::TcpListener; + use std::os::windows::prelude::*; + + impl AsRawSocket for TcpListener { + fn as_raw_socket(&self) -> RawSocket { + self.io.as_raw_socket() + } + } +} diff --git a/third_party/rust/tokio/src/net/tcp/mod.rs b/third_party/rust/tokio/src/net/tcp/mod.rs new file mode 100644 index 0000000000..cb8a8b238b --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/mod.rs @@ -0,0 +1,14 @@ +//! TCP utility types. + +pub(crate) mod listener; + +pub(crate) mod socket; + +mod split; +pub use split::{ReadHalf, WriteHalf}; + +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + +pub(crate) mod stream; +pub(crate) use stream::TcpStream; diff --git a/third_party/rust/tokio/src/net/tcp/socket.rs b/third_party/rust/tokio/src/net/tcp/socket.rs new file mode 100644 index 0000000000..171e240189 --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/socket.rs @@ -0,0 +1,690 @@ +use crate::net::{TcpListener, TcpStream}; + +use std::fmt; +use std::io; +use std::net::SocketAddr; + +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::time::Duration; + +cfg_net! { + /// A TCP socket that has not yet been converted to a `TcpStream` or + /// `TcpListener`. + /// + /// `TcpSocket` wraps an operating system socket and enables the caller to + /// configure the socket before establishing a TCP connection or accepting + /// inbound connections. The caller is able to set socket option and explicitly + /// bind the socket with a socket address. + /// + /// The underlying socket is closed when the `TcpSocket` value is dropped. + /// + /// `TcpSocket` should only be used directly if the default configuration used + /// by `TcpStream::connect` and `TcpListener::bind` does not meet the required + /// use case. + /// + /// Calling `TcpStream::connect("127.0.0.1:8080")` is equivalent to: + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// let stream = socket.connect(addr).await?; + /// # drop(stream); + /// + /// Ok(()) + /// } + /// ``` + /// + /// Calling `TcpListener::bind("127.0.0.1:8080")` is equivalent to: + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// // On platforms with Berkeley-derived sockets, this allows to quickly + /// // rebind a socket, without needing to wait for the OS to clean up the + /// // previous one. + /// // + /// // On Windows, this allows rebinding sockets which are actively in use, + /// // which allows “socket hijacking”, so we explicitly don't set it here. + /// // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + /// socket.set_reuseaddr(true)?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + /// + /// Setting socket options not explicitly provided by `TcpSocket` may be done by + /// accessing the `RawFd`/`RawSocket` using [`AsRawFd`]/[`AsRawSocket`] and + /// setting the option with a crate like [`socket2`]. + /// + /// [`RawFd`]: https://doc.rust-lang.org/std/os/unix/io/type.RawFd.html + /// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html + /// [`AsRawFd`]: https://doc.rust-lang.org/std/os/unix/io/trait.AsRawFd.html + /// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html + /// [`socket2`]: https://docs.rs/socket2/ + #[cfg_attr(docsrs, doc(alias = "connect_std"))] + pub struct TcpSocket { + inner: socket2::Socket, + } +} + +impl TcpSocket { + /// Creates a new socket configured for IPv4. + /// + /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv4 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v4() -> io::Result<TcpSocket> { + TcpSocket::new(socket2::Domain::IPV4) + } + + /// Creates a new socket configured for IPv6. + /// + /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv6 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "[::1]:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v6()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v6() -> io::Result<TcpSocket> { + TcpSocket::new(socket2::Domain::IPV6) + } + + fn new(domain: socket2::Domain) -> io::Result<TcpSocket> { + let ty = socket2::Type::STREAM; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + ))] + let ty = ty.nonblocking(); + let inner = socket2::Socket::new(domain, ty, Some(socket2::Protocol::TCP))?; + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + )))] + inner.set_nonblocking(true)?; + Ok(TcpSocket { inner }) + } + + /// Allows the socket to bind to an in-use address. + /// + /// Behavior is platform specific. Refer to the target platform's + /// documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseaddr(true)?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> { + self.inner.set_reuse_address(reuseaddr) + } + + /// Retrieves the value set for `SO_REUSEADDR` on this socket. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseaddr(true)?; + /// assert!(socket.reuseaddr().unwrap()); + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// Ok(()) + /// } + /// ``` + pub fn reuseaddr(&self) -> io::Result<bool> { + self.inner.reuse_address() + } + + /// Allows the socket to bind to an in-use port. Only available for unix systems + /// (excluding Solaris & Illumos). + /// + /// Behavior is platform specific. Refer to the target platform's + /// documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseport(true)?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// Ok(()) + /// } + /// ``` + #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] + #[cfg_attr( + docsrs, + doc(cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))) + )] + pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> { + self.inner.set_reuse_port(reuseport) + } + + /// Allows the socket to bind to an in-use port. Only available for unix systems + /// (excluding Solaris & Illumos). + /// + /// Behavior is platform specific. Refer to the target platform's + /// documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseport(true)?; + /// assert!(socket.reuseport().unwrap()); + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// Ok(()) + /// } + /// ``` + #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] + #[cfg_attr( + docsrs, + doc(cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))) + )] + pub fn reuseport(&self) -> io::Result<bool> { + self.inner.reuse_port() + } + + /// Sets the size of the TCP send buffer on this socket. + /// + /// On most operating systems, this sets the `SO_SNDBUF` socket option. + pub fn set_send_buffer_size(&self, size: u32) -> io::Result<()> { + self.inner.set_send_buffer_size(size as usize) + } + + /// Returns the size of the TCP send buffer for this socket. + /// + /// On most operating systems, this is the value of the `SO_SNDBUF` socket + /// option. + /// + /// Note that if [`set_send_buffer_size`] has been called on this socket + /// previously, the value returned by this function may not be the same as + /// the argument provided to `set_send_buffer_size`. This is for the + /// following reasons: + /// + /// * Most operating systems have minimum and maximum allowed sizes for the + /// send buffer, and will clamp the provided value if it is below the + /// minimum or above the maximum. The minimum and maximum buffer sizes are + /// OS-dependent. + /// * Linux will double the buffer size to account for internal bookkeeping + /// data, and returns the doubled value from `getsockopt(2)`. As per `man + /// 7 socket`: + /// > Sets or gets the maximum socket send buffer in bytes. The + /// > kernel doubles this value (to allow space for bookkeeping + /// > overhead) when it is set using `setsockopt(2)`, and this doubled + /// > value is returned by `getsockopt(2)`. + /// + /// [`set_send_buffer_size`]: #method.set_send_buffer_size + pub fn send_buffer_size(&self) -> io::Result<u32> { + self.inner.send_buffer_size().map(|n| n as u32) + } + + /// Sets the size of the TCP receive buffer on this socket. + /// + /// On most operating systems, this sets the `SO_RCVBUF` socket option. + pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> { + self.inner.set_recv_buffer_size(size as usize) + } + + /// Returns the size of the TCP receive buffer for this socket. + /// + /// On most operating systems, this is the value of the `SO_RCVBUF` socket + /// option. + /// + /// Note that if [`set_recv_buffer_size`] has been called on this socket + /// previously, the value returned by this function may not be the same as + /// the argument provided to `set_send_buffer_size`. This is for the + /// following reasons: + /// + /// * Most operating systems have minimum and maximum allowed sizes for the + /// receive buffer, and will clamp the provided value if it is below the + /// minimum or above the maximum. The minimum and maximum buffer sizes are + /// OS-dependent. + /// * Linux will double the buffer size to account for internal bookkeeping + /// data, and returns the doubled value from `getsockopt(2)`. As per `man + /// 7 socket`: + /// > Sets or gets the maximum socket send buffer in bytes. The + /// > kernel doubles this value (to allow space for bookkeeping + /// > overhead) when it is set using `setsockopt(2)`, and this doubled + /// > value is returned by `getsockopt(2)`. + /// + /// [`set_recv_buffer_size`]: #method.set_recv_buffer_size + pub fn recv_buffer_size(&self) -> io::Result<u32> { + self.inner.recv_buffer_size().map(|n| n as u32) + } + + /// Sets the linger duration of this socket by setting the SO_LINGER option. + /// + /// This option controls the action taken when a stream has unsent messages and the stream is + /// closed. If SO_LINGER is set, the system shall block the process until it can transmit the + /// data or until the time expires. + /// + /// If SO_LINGER is not specified, and the socket is closed, the system handles the call in a + /// way that allows the process to continue as quickly as possible. + pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> { + self.inner.set_linger(dur) + } + + /// Reads the linger duration for this socket by getting the `SO_LINGER` + /// option. + /// + /// For more information about this option, see [`set_linger`]. + /// + /// [`set_linger`]: TcpSocket::set_linger + pub fn linger(&self) -> io::Result<Option<Duration>> { + self.inner.linger() + } + + /// Gets the local address of this socket. + /// + /// Will fail on windows if called before `bind`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// assert_eq!(socket.local_addr().unwrap().to_string(), "127.0.0.1:8080"); + /// let listener = socket.listen(1024)?; + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr().and_then(convert_address) + } + + /// Binds the socket to the given address. + /// + /// This calls the `bind(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Bind a socket before listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.inner.bind(&addr.into()) + } + + /// Establishes a TCP connection with a peer at the specified socket address. + /// + /// The `TcpSocket` is consumed. Once the connection is established, a + /// connected [`TcpStream`] is returned. If the connection fails, the + /// encountered error is returned. + /// + /// [`TcpStream`]: TcpStream + /// + /// This calls the `connect(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Connecting to a peer. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// let stream = socket.connect(addr).await?; + /// # drop(stream); + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> { + if let Err(err) = self.inner.connect(&addr.into()) { + #[cfg(unix)] + if err.raw_os_error() != Some(libc::EINPROGRESS) { + return Err(err); + } + #[cfg(windows)] + if err.kind() != io::ErrorKind::WouldBlock { + return Err(err); + } + } + #[cfg(unix)] + let mio = { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = self.inner.into_raw_fd(); + unsafe { mio::net::TcpStream::from_raw_fd(raw_fd) } + }; + + #[cfg(windows)] + let mio = { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let raw_socket = self.inner.into_raw_socket(); + unsafe { mio::net::TcpStream::from_raw_socket(raw_socket) } + }; + + TcpStream::connect_mio(mio).await + } + + /// Converts the socket into a `TcpListener`. + /// + /// `backlog` defines the maximum number of pending connections are queued + /// by the operating system at any given time. Connection are removed from + /// the queue with [`TcpListener::accept`]. When the queue is full, the + /// operating-system will start rejecting connections. + /// + /// [`TcpListener::accept`]: TcpListener::accept + /// + /// This calls the `listen(2)` operating-system function, marking the socket + /// as a passive socket. Behavior is platform specific. Refer to the target + /// platform's documentation for more details. + /// + /// # Examples + /// + /// Create a `TcpListener`. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn listen(self, backlog: u32) -> io::Result<TcpListener> { + self.inner.listen(backlog as i32)?; + #[cfg(unix)] + let mio = { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = self.inner.into_raw_fd(); + unsafe { mio::net::TcpListener::from_raw_fd(raw_fd) } + }; + + #[cfg(windows)] + let mio = { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let raw_socket = self.inner.into_raw_socket(); + unsafe { mio::net::TcpListener::from_raw_socket(raw_socket) } + }; + + TcpListener::new(mio) + } + + /// Converts a [`std::net::TcpStream`] into a `TcpSocket`. The provided + /// socket must not have been connected prior to calling this function. This + /// function is typically used together with crates such as [`socket2`] to + /// configure socket options that are not available on `TcpSocket`. + /// + /// [`std::net::TcpStream`]: struct@std::net::TcpStream + /// [`socket2`]: https://docs.rs/socket2/ + /// + /// # Examples + /// + /// ``` + /// use tokio::net::TcpSocket; + /// use socket2::{Domain, Socket, Type}; + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// + /// let socket2_socket = Socket::new(Domain::IPV4, Type::STREAM, None)?; + /// + /// let socket = TcpSocket::from_std_stream(socket2_socket.into()); + /// + /// Ok(()) + /// } + /// ``` + pub fn from_std_stream(std_stream: std::net::TcpStream) -> TcpSocket { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = std_stream.into_raw_fd(); + unsafe { TcpSocket::from_raw_fd(raw_fd) } + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let raw_socket = std_stream.into_raw_socket(); + unsafe { TcpSocket::from_raw_socket(raw_socket) } + } + } +} + +fn convert_address(address: socket2::SockAddr) -> io::Result<SocketAddr> { + match address.as_socket() { + Some(address) => Ok(address), + None => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid address family (not IPv4 or IPv6)", + )), + } +} + +impl fmt::Debug for TcpSocket { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(fmt) + } +} + +#[cfg(unix)] +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(unix)] +impl FromRawFd for TcpSocket { + /// Converts a `RawFd` to a `TcpSocket`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket { + let inner = socket2::Socket::from_raw_fd(fd); + TcpSocket { inner } + } +} + +#[cfg(unix)] +impl IntoRawFd for TcpSocket { + fn into_raw_fd(self) -> RawFd { + self.inner.into_raw_fd() + } +} + +#[cfg(windows)] +impl IntoRawSocket for TcpSocket { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_raw_socket() + } +} + +#[cfg(windows)] +impl AsRawSocket for TcpSocket { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +impl FromRawSocket for TcpSocket { + /// Converts a `RawSocket` to a `TcpStream`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket { + let inner = socket2::Socket::from_raw_socket(socket); + TcpSocket { inner } + } +} diff --git a/third_party/rust/tokio/src/net/tcp/split.rs b/third_party/rust/tokio/src/net/tcp/split.rs new file mode 100644 index 0000000000..0e02928495 --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/split.rs @@ -0,0 +1,401 @@ +//! `TcpStream` split support. +//! +//! A `TcpStream` can be split into a `ReadHalf` and a +//! `WriteHalf` with the `TcpStream::split` method. `ReadHalf` +//! implements `AsyncRead` while `WriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; +use crate::net::TcpStream; + +use std::io; +use std::net::{Shutdown, SocketAddr}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + use bytes::BufMut; +} + +/// Borrowed read half of a [`TcpStream`], created by [`split`]. +/// +/// Reading from a `ReadHalf` is usually done using the convenience methods found on the +/// [`AsyncReadExt`] trait. +/// +/// [`TcpStream`]: TcpStream +/// [`split`]: TcpStream::split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +#[derive(Debug)] +pub struct ReadHalf<'a>(&'a TcpStream); + +/// Borrowed write half of a [`TcpStream`], created by [`split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. +/// +/// Writing to an `WriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. +/// +/// [`TcpStream`]: TcpStream +/// [`split`]: TcpStream::split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +#[derive(Debug)] +pub struct WriteHalf<'a>(&'a TcpStream); + +pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) { + (ReadHalf(&*stream), WriteHalf(&*stream)) +} + +impl ReadHalf<'_> { + /// Attempts to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// Note that on multiple calls to `poll_peek` or `poll_read`, only the + /// `Waker` from the `Context` passed to the most recent call is scheduled + /// to receive a wakeup. + /// + /// See the [`TcpStream::poll_peek`] level documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, ReadBuf}; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let (mut read_half, _) = stream.split(); + /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); + /// + /// poll_fn(|cx| { + /// read_half.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::poll_peek`]: TcpStream::poll_peek + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { + self.0.poll_peek(cx, buf) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// See the [`TcpStream::peek`] level documentation for more details. + /// + /// [`TcpStream::peek`]: TcpStream::peek + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::AsyncReadExt; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let (mut read_half, _) = stream.split(); + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = read_half.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// This function is also equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.0.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.0.try_read(buf) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.0.try_read_vectored(bufs) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.0.try_read_buf(buf) + } + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl WriteHalf<'_> { + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.0.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.0.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.0.try_write_vectored(bufs) + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl AsyncRead for ReadHalf<'_> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.0.poll_read_priv(cx, buf) + } +} + +impl AsyncWrite for WriteHalf<'_> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.0.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.0.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + self.0.shutdown_std(Shutdown::Write).into() + } +} + +impl AsRef<TcpStream> for ReadHalf<'_> { + fn as_ref(&self) -> &TcpStream { + self.0 + } +} + +impl AsRef<TcpStream> for WriteHalf<'_> { + fn as_ref(&self) -> &TcpStream { + self.0 + } +} diff --git a/third_party/rust/tokio/src/net/tcp/split_owned.rs b/third_party/rust/tokio/src/net/tcp/split_owned.rs new file mode 100644 index 0000000000..ef4e7b5361 --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/split_owned.rs @@ -0,0 +1,485 @@ +//! `TcpStream` owned split support. +//! +//! A `TcpStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; +use crate::net::TcpStream; + +use std::error::Error; +use std::net::{Shutdown, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +cfg_io_util! { + use bytes::BufMut; +} + +/// Owned read half of a [`TcpStream`], created by [`into_split`]. +/// +/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found +/// on the [`AsyncReadExt`] trait. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc<TcpStream>, +} + +/// Owned write half of a [`TcpStream`], created by [`into_split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. Dropping the write half +/// will also shut down the write half of the TCP stream. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. +/// +/// [`TcpStream`]: TcpStream +/// [`into_split`]: TcpStream::into_split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc<TcpStream>, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result<TcpStream, ReuniteError> { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating that two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> { + reunite(self, other) + } + + /// Attempt to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// Note that on multiple calls to `poll_peek` or `poll_read`, only the + /// `Waker` from the `Context` passed to the most recent call is scheduled + /// to receive a wakeup. + /// + /// See the [`TcpStream::poll_peek`] level documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, ReadBuf}; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let (mut read_half, _) = stream.into_split(); + /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); + /// + /// poll_fn(|cx| { + /// read_half.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`TcpStream::poll_peek`]: TcpStream::poll_peek + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { + self.inner.poll_peek(cx, buf) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// See the [`TcpStream::peek`] level documentation for more details. + /// + /// [`TcpStream::peek`]: TcpStream::peek + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::AsyncReadExt; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let (mut read_half, _) = stream.into_split(); + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = read_half.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// This function is also equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.inner.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.try_read(buf) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.try_read_vectored(bufs) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.inner.try_read_buf(buf) + } + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } +} + +impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `TcpStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> { + reunite(other, self) + } + + /// Destroys the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.inner.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.inner.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.inner.try_write_vectored(bufs) + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown_std(Shutdown::Write); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + let res = self.inner.shutdown_std(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() + } +} + +impl AsRef<TcpStream> for OwnedReadHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} + +impl AsRef<TcpStream> for OwnedWriteHalf { + fn as_ref(&self) -> &TcpStream { + &*self.inner + } +} diff --git a/third_party/rust/tokio/src/net/tcp/stream.rs b/third_party/rust/tokio/src/net/tcp/stream.rs new file mode 100644 index 0000000000..ebb67b84d1 --- /dev/null +++ b/third_party/rust/tokio/src/net/tcp/stream.rs @@ -0,0 +1,1310 @@ +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; +use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; +use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; +use crate::net::{to_socket_addrs, ToSocketAddrs}; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::{Shutdown, SocketAddr}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +cfg_io_util! { + use bytes::BufMut; +} + +cfg_net! { + /// A TCP stream between a local and a remote socket. + /// + /// A TCP stream can either be created by connecting to an endpoint, via the + /// [`connect`] method, or by [accepting] a connection from a [listener]. A + /// TCP stream can also be created via the [`TcpSocket`] type. + /// + /// Reading and writing to a `TcpStream` is usually done using the + /// convenience methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] + /// traits. + /// + /// [`connect`]: method@TcpStream::connect + /// [accepting]: method@crate::net::TcpListener::accept + /// [listener]: struct@crate::net::TcpListener + /// [`TcpSocket`]: struct@crate::net::TcpSocket + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::AsyncWriteExt; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// // Write some data. + /// stream.write_all(b"hello world!").await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + /// + /// To shut down the stream in the write direction, you can call the + /// [`shutdown()`] method. This will cause the other peer to receive a read of + /// length 0, indicating that no more data will be sent. This only closes + /// the stream in one direction. + /// + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown + pub struct TcpStream { + io: PollEvented<mio::net::TcpStream>, + } +} + +impl TcpStream { + /// Opens a TCP connection to a remote host. + /// + /// `addr` is an address of the remote host. Anything which implements the + /// [`ToSocketAddrs`] trait can be supplied as the address. If `addr` + /// yields multiple addresses, connect will be attempted with each of the + /// addresses until a connection is successful. If none of the addresses + /// result in a successful connection, the error returned from the last + /// connection attempt (the last address) is returned. + /// + /// To configure the socket before connecting, you can use the [`TcpSocket`] + /// type. + /// + /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs + /// [`TcpSocket`]: struct@crate::net::TcpSocket + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::AsyncWriteExt; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// // Write some data. + /// stream.write_all(b"hello world!").await?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. + /// + /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all + /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> { + let addrs = to_socket_addrs(addr).await?; + + let mut last_err = None; + + for addr in addrs { + match TcpStream::connect_addr(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) + } + + /// Establishes a connection to the specified `addr`. + async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> { + let sys = mio::net::TcpStream::connect(addr)?; + TcpStream::connect_mio(sys).await + } + + pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> { + let stream = TcpStream::new(sys)?; + + // Once we've connected, wait for the stream to be writable as + // that's when the actual connection has been initiated. Once we're + // writable we check for `take_socket_error` to see if the connect + // actually hit an error or not. + // + // If all that succeeded then we ship everything on up. + poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; + + if let Some(e) = stream.io.take_error()? { + return Err(e); + } + + Ok(stream) + } + + pub(crate) fn new(connected: mio::net::TcpStream) -> io::Result<TcpStream> { + let io = PollEvented::new(connected)?; + Ok(TcpStream { io }) + } + + /// Creates new `TcpStream` from a `std::net::TcpStream`. + /// + /// This function is intended to be used to wrap a TCP stream from the + /// standard library in the Tokio equivalent. The conversion assumes nothing + /// about the underlying stream; it is left up to the user to set it in + /// non-blocking mode. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// use tokio::net::TcpStream; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let std_stream = std::net::TcpStream::connect("127.0.0.1:34254")?; + /// std_stream.set_nonblocking(true)?; + /// let stream = TcpStream::from_std(std_stream)?; + /// Ok(()) + /// } + /// ``` + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn from_std(stream: std::net::TcpStream) -> io::Result<TcpStream> { + let io = mio::net::TcpStream::from_std(stream); + let io = PollEvented::new(io)?; + Ok(TcpStream { io }) + } + + /// Turns a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. + /// + /// The returned [`std::net::TcpStream`] will have nonblocking mode set as `true`. + /// Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ``` + /// use std::error::Error; + /// use std::io::Read; + /// use tokio::net::TcpListener; + /// # use tokio::net::TcpStream; + /// # use tokio::io::AsyncWriteExt; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let mut data = [0u8; 12]; + /// let listener = TcpListener::bind("127.0.0.1:34254").await?; + /// # let handle = tokio::spawn(async { + /// # let mut stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + /// # stream.write(b"Hello world!").await.unwrap(); + /// # }); + /// let (tokio_tcp_stream, _) = listener.accept().await?; + /// let mut std_tcp_stream = tokio_tcp_stream.into_std()?; + /// # handle.await.expect("The task being joined has panicked"); + /// std_tcp_stream.set_nonblocking(false)?; + /// std_tcp_stream.read_exact(&mut data)?; + /// # assert_eq!(b"Hello world!", &data); + /// Ok(()) + /// } + /// ``` + /// [`tokio::net::TcpStream`]: TcpStream + /// [`std::net::TcpStream`]: std::net::TcpStream + /// [`set_nonblocking`]: fn@std::net::TcpStream::set_nonblocking + pub fn into_std(self) -> io::Result<std::net::TcpStream> { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::net::TcpStream::from_raw_fd(raw_fd) }) + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + self.io + .into_inner() + .map(|io| io.into_raw_socket()) + .map(|raw_socket| unsafe { std::net::TcpStream::from_raw_socket(raw_socket) }) + } + } + + /// Returns the local address that this stream is bound to. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.local_addr()?); + /// # Ok(()) + /// # } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr() + } + + /// Returns the remote address that this stream is connected to. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.peer_addr()?); + /// # Ok(()) + /// # } + /// ``` + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.io.peer_addr() + } + + /// Attempts to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. + /// + /// Note that on multiple calls to `poll_peek`, `poll_read` or + /// `poll_read_ready`, only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// `poll_write` retains a second, independent waker.) + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if data is not yet available. + /// * `Poll::Ready(Ok(n))` if data is available. `n` is the number of bytes peeked. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io::{self, ReadBuf}; + /// use tokio::net::TcpStream; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); + /// + /// poll_fn(|cx| { + /// stream.poll_peek(cx, &mut buf) + /// }).await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn poll_peek( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { + loop { + let ev = ready!(self.io.registration().poll_read_ready(cx))?; + + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + match self.io.peek(b) { + Ok(ret) => { + unsafe { buf.assume_init(ret) }; + buf.advance(ret); + return Poll::Ready(Ok(ret)); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.registration().clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// Concurrently read and write to the stream on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// let ready = stream.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the tcp stream is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the tcp + /// stream becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_read_ready`, `poll_read` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_write_ready` retains a second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the tcp stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the tcp stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`ready()`]: TcpStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + use std::io::Read; + + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: TcpStream::try_read() + /// [`readable()`]: TcpStream::readable() + /// [`ready()`]: TcpStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + use std::io::Read; + + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`ready()`]: TcpStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `TcpStream::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the tcp stream is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the tcp + /// stream becomes ready for writing, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the tcp stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the tcp stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Try to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + use std::io::Write; + + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: TcpStream::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + use std::io::Write; + + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(bufs)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `TcpStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`writable()`]: TcpStream::writable() + /// [`ready()`]: TcpStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// Successive calls return the same data. This is accomplished by passing + /// `MSG_PEEK` as a flag to the underlying recv system call. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use tokio::io::AsyncReadExt; + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let mut b1 = [0; 10]; + /// let mut b2 = [0; 10]; + /// + /// // Peek at the data + /// let n = stream.peek(&mut b1).await?; + /// + /// // Read the data + /// assert_eq!(n, stream.read(&mut b2[..n]).await?); + /// assert_eq!(&b1[..n], &b2[..n]); + /// + /// Ok(()) + /// } + /// ``` + /// + /// The [`read`] method is defined on the [`AsyncReadExt`] trait. + /// + /// [`read`]: fn@crate::io::AsyncReadExt::read + /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt + pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::READABLE, || self.io.peek(buf)) + .await + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O on the specified + /// portions to return immediately with an appropriate value (see the + /// documentation of `Shutdown`). + pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> { + self.io.shutdown(how) + } + + /// Gets the value of the `TCP_NODELAY` option on this socket. + /// + /// For more information about this option, see [`set_nodelay`]. + /// + /// [`set_nodelay`]: TcpStream::set_nodelay + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.nodelay()?); + /// # Ok(()) + /// # } + /// ``` + pub fn nodelay(&self) -> io::Result<bool> { + self.io.nodelay() + } + + /// Sets the value of the `TCP_NODELAY` option on this socket. + /// + /// If set, this option disables the Nagle algorithm. This means that + /// segments are always sent as soon as possible, even if there is only a + /// small amount of data. When not set, data is buffered until there is a + /// sufficient amount to send out, thereby avoiding the frequent sending of + /// small packets. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_nodelay(true)?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { + self.io.set_nodelay(nodelay) + } + + /// Reads the linger duration for this socket by getting the `SO_LINGER` + /// option. + /// + /// For more information about this option, see [`set_linger`]. + /// + /// [`set_linger`]: TcpStream::set_linger + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.linger()?); + /// # Ok(()) + /// # } + /// ``` + pub fn linger(&self) -> io::Result<Option<Duration>> { + socket2::SockRef::from(self).linger() + } + + /// Sets the linger duration of this socket by setting the SO_LINGER option. + /// + /// This option controls the action taken when a stream has unsent messages and the stream is + /// closed. If SO_LINGER is set, the system shall block the process until it can transmit the + /// data or until the time expires. + /// + /// If SO_LINGER is not specified, and the stream is closed, the system handles the call in a + /// way that allows the process to continue as quickly as possible. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_linger(None)?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> { + socket2::SockRef::from(self).set_linger(dur) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: TcpStream::set_ttl + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.ttl()?); + /// # Ok(()) + /// # } + /// ``` + pub fn ttl(&self) -> io::Result<u32> { + self.io.ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_ttl(123)?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io.set_ttl(ttl) + } + + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] + /// Splits a `TcpStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: TcpStream::into_split() + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { + split(self) + } + + /// Splits a `TcpStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. + /// + /// **Note:** Dropping the write half will shut down the write half of the TCP + /// stream. This is equivalent to calling [`shutdown()`] on the `TcpStream`. + /// + /// [`split`]: TcpStream::split() + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) + } + + // == Poll IO functions that takes `&self` == + // + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. + + pub(crate) fn poll_read_priv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // Safety: `TcpStream::read` correctly handles reads into uninitialized memory + unsafe { self.io.poll_read(cx, buf) } + } + + pub(super) fn poll_write_priv( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + pub(super) fn poll_write_vectored_priv( + &self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } +} + +impl TryFrom<std::net::TcpStream> for TcpStream { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`TcpStream::from_std(stream)`](TcpStream::from_std). + fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +// ===== impl Read / Write ===== + +impl AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.poll_read_priv(cx, buf) + } +} + +impl AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // tcp flush is a no-op + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + self.shutdown_std(std::net::Shutdown::Write)?; + Poll::Ready(Ok(())) + } +} + +impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +#[cfg(unix)] +mod sys { + use super::TcpStream; + use std::os::unix::prelude::*; + + impl AsRawFd for TcpStream { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } + } +} + +#[cfg(windows)] +mod sys { + use super::TcpStream; + use std::os::windows::prelude::*; + + impl AsRawSocket for TcpStream { + fn as_raw_socket(&self) -> RawSocket { + self.io.as_raw_socket() + } + } +} diff --git a/third_party/rust/tokio/src/net/udp.rs b/third_party/rust/tokio/src/net/udp.rs new file mode 100644 index 0000000000..12af5152c2 --- /dev/null +++ b/third_party/rust/tokio/src/net/udp.rs @@ -0,0 +1,1589 @@ +use crate::io::{Interest, PollEvented, ReadBuf, Ready}; +use crate::net::{to_socket_addrs, ToSocketAddrs}; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; + +cfg_io_util! { + use bytes::BufMut; +} + +cfg_net! { + /// A UDP socket. + /// + /// UDP is "connectionless", unlike TCP. Meaning, regardless of what address you've bound to, a `UdpSocket` + /// is free to communicate with many different remotes. In tokio there are basically two main ways to use `UdpSocket`: + /// + /// * one to many: [`bind`](`UdpSocket::bind`) and use [`send_to`](`UdpSocket::send_to`) + /// and [`recv_from`](`UdpSocket::recv_from`) to communicate with many different addresses + /// * one to one: [`connect`](`UdpSocket::connect`) and associate with a single address, using [`send`](`UdpSocket::send`) + /// and [`recv`](`UdpSocket::recv`) to communicate only with that remote address + /// + /// This type does not provide a `split` method, because this functionality + /// can be achieved by instead wrapping the socket in an [`Arc`]. Note that + /// you do not need a `Mutex` to share the `UdpSocket` — an `Arc<UdpSocket>` + /// is enough. This is because all of the methods take `&self` instead of + /// `&mut self`. Once you have wrapped it in an `Arc`, you can call + /// `.clone()` on the `Arc<UdpSocket>` to get multiple shared handles to the + /// same socket. An example of such usage can be found further down. + /// + /// [`Arc`]: std::sync::Arc + /// + /// # Streams + /// + /// If you need to listen over UDP and produce a [`Stream`], you can look + /// at [`UdpFramed`]. + /// + /// [`UdpFramed`]: https://docs.rs/tokio-util/latest/tokio_util/udp/struct.UdpFramed.html + /// [`Stream`]: https://docs.rs/futures/0.3/futures/stream/trait.Stream.html + /// + /// # Example: one to many (bind) + /// + /// Using `bind` we can create a simple echo server that sends and recv's with many different clients: + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// let mut buf = [0; 1024]; + /// loop { + /// let (len, addr) = sock.recv_from(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, addr); + /// + /// let len = sock.send_to(&buf[..len], addr).await?; + /// println!("{:?} bytes sent", len); + /// } + /// } + /// ``` + /// + /// # Example: one to one (connect) + /// + /// Or using `connect` we can echo with a single remote address using `send` and `recv`: + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// + /// let remote_addr = "127.0.0.1:59611"; + /// sock.connect(remote_addr).await?; + /// let mut buf = [0; 1024]; + /// loop { + /// let len = sock.recv(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, remote_addr); + /// + /// let len = sock.send(&buf[..len]).await?; + /// println!("{:?} bytes sent", len); + /// } + /// } + /// ``` + /// + /// # Example: Splitting with `Arc` + /// + /// Because `send_to` and `recv_from` take `&self`. It's perfectly alright + /// to use an `Arc<UdpSocket>` and share the references to multiple tasks. + /// Here is a similar "echo" example that supports concurrent + /// sending/receiving: + /// + /// ```no_run + /// use tokio::{net::UdpSocket, sync::mpsc}; + /// use std::{io, net::SocketAddr, sync::Arc}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// let r = Arc::new(sock); + /// let s = r.clone(); + /// let (tx, mut rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(1_000); + /// + /// tokio::spawn(async move { + /// while let Some((bytes, addr)) = rx.recv().await { + /// let len = s.send_to(&bytes, &addr).await.unwrap(); + /// println!("{:?} bytes sent", len); + /// } + /// }); + /// + /// let mut buf = [0; 1024]; + /// loop { + /// let (len, addr) = r.recv_from(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, addr); + /// tx.send((buf[..len].to_vec(), addr)).await.unwrap(); + /// } + /// } + /// ``` + /// + pub struct UdpSocket { + io: PollEvented<mio::net::UdpSocket>, + } +} + +impl UdpSocket { + /// This function will create a new UDP socket and attempt to bind it to + /// the `addr` provided. + /// + /// Binding with a port number of 0 will request that the OS assigns a port + /// to this listener. The port allocated can be queried via the `local_addr` + /// method. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// // use `sock` + /// # let _ = sock; + /// Ok(()) + /// } + /// ``` + pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> { + let addrs = to_socket_addrs(addr).await?; + let mut last_err = None; + + for addr in addrs { + match UdpSocket::bind_addr(addr) { + Ok(socket) => return Ok(socket), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) + } + + fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> { + let sys = mio::net::UdpSocket::bind(addr)?; + UdpSocket::new(sys) + } + + fn new(socket: mio::net::UdpSocket) -> io::Result<UdpSocket> { + let io = PollEvented::new(socket)?; + Ok(UdpSocket { io }) + } + + /// Creates new `UdpSocket` from a previously bound `std::net::UdpSocket`. + /// + /// This function is intended to be used to wrap a UDP socket from the + /// standard library in the Tokio equivalent. The conversion assumes nothing + /// about the underlying socket; it is left up to the user to set it in + /// non-blocking mode. + /// + /// This can be used in conjunction with socket2's `Socket` interface to + /// configure a socket before it's handed off, such as setting options like + /// `reuse_address` or binding to multiple addresses. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap(); + /// let std_sock = std::net::UdpSocket::bind(addr)?; + /// std_sock.set_nonblocking(true)?; + /// let sock = UdpSocket::from_std(std_sock)?; + /// // use `sock` + /// # Ok(()) + /// # } + /// ``` + pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> { + let io = mio::net::UdpSocket::from_std(socket); + UdpSocket::new(io) + } + + /// Turns a [`tokio::net::UdpSocket`] into a [`std::net::UdpSocket`]. + /// + /// The returned [`std::net::UdpSocket`] will have nonblocking mode set as + /// `true`. Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let tokio_socket = tokio::net::UdpSocket::bind("127.0.0.1:0").await?; + /// let std_socket = tokio_socket.into_std()?; + /// std_socket.set_nonblocking(false)?; + /// Ok(()) + /// } + /// ``` + /// + /// [`tokio::net::UdpSocket`]: UdpSocket + /// [`std::net::UdpSocket`]: std::net::UdpSocket + /// [`set_nonblocking`]: fn@std::net::UdpSocket::set_nonblocking + pub fn into_std(self) -> io::Result<std::net::UdpSocket> { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::net::UdpSocket::from_raw_fd(raw_fd) }) + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + self.io + .into_inner() + .map(|io| io.into_raw_socket()) + .map(|raw_socket| unsafe { std::net::UdpSocket::from_raw_socket(raw_socket) }) + } + } + + /// Returns the local address that this socket is bound to. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap(); + /// let sock = UdpSocket::bind(addr).await?; + /// // the address the socket is bound to + /// let local_addr = sock.local_addr()?; + /// # Ok(()) + /// # } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr() + } + + /// Connects the UDP socket setting the default destination for send() and + /// limiting packets that are read via recv from the address specified in + /// `addr`. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// + /// let remote_addr = "127.0.0.1:59600".parse::<SocketAddr>().unwrap(); + /// sock.connect(remote_addr).await?; + /// let mut buf = [0u8; 32]; + /// // recv from remote_addr + /// let len = sock.recv(&mut buf).await?; + /// // send to remote_addr + /// let _len = sock.send(&buf[..len]).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> { + let addrs = to_socket_addrs(addr).await?; + let mut last_err = None; + + for addr in addrs { + match self.io.connect(addr) { + Ok(_) => return Ok(()), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_recv()` or `try_send()`. It + /// can be used to concurrently recv / send to the same socket on a single + /// task without splitting the socket. + /// + /// The function may complete without the socket being ready. This is a + /// false-positive and attempting an operation will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// Concurrently receive from and send to the socket on the same task + /// without splitting. + /// + /// ```no_run + /// use tokio::io::{self, Interest}; + /// use tokio::net::UdpSocket; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// let ready = socket.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// // The buffer is **not** included in the async task and will only exist + /// // on the stack. + /// let mut data = [0; 1024]; + /// match socket.try_recv(&mut data[..]) { + /// Ok(n) => { + /// println!("received {:?}", &data[..n]); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Write some data + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// println!("sent {} bytes", n); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is + /// usually paired with `try_send()` or `try_send_to()`. + /// + /// The function may complete without the socket being writable. This is a + /// false-positive and attempting a `try_send()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Bind socket + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write/send readiness. + /// + /// If the udp stream is not currently ready for sending, this method will + /// store a clone of the `Waker` from the provided `Context`. When the udp + /// stream becomes ready for sending, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_send_ready` or `poll_send`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_recv_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the udp stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the udp stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Sends data on the socket to the remote address that the socket is + /// connected to. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// [`connect`]: method@Self::connect + /// + /// # Return + /// + /// On success, the number of bytes sent is returned, otherwise, the + /// encountered error is returned. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::io; + /// use tokio::net::UdpSocket; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Bind socket + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// // Send a message + /// socket.send(b"hello world").await?; + /// + /// Ok(()) + /// } + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::WRITABLE, || self.io.send(buf)) + .await + } + + /// Attempts to send data on the socket to the remote address to which it + /// was previously `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, + /// only the `Waker` from the `Context` passed to the most recent call will + /// be scheduled to receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not available to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { + self.io + .registration() + .poll_write_io(cx, || self.io.send(buf)) + } + + /// Tries to send data on the socket to the remote address to which it is + /// connected. + /// + /// When the socket buffer is full, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `writable()`. + /// + /// # Returns + /// + /// If successful, `Ok(n)` is returned, where `n` is the number of bytes + /// sent. If the socket is not ready to send data, + /// `Err(ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Bind a UDP socket + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// // Connect to a peer + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send(buf)) + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_recv()`. + /// + /// The function may complete without the socket being readable. This is a + /// false-positive and attempting a `try_recv()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read/receive readiness. + /// + /// If the udp stream is not currently ready for receiving, this method will + /// store a clone of the `Waker` from the provided `Context`. When the udp + /// socket becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_recv_ready`, `poll_recv` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_send_ready` retains a second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the udp stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the udp stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Receives a single datagram message on the socket from the remote address + /// to which it is connected. On success, returns the number of bytes read. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// + /// [`connect`]: method@Self::connect + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Bind socket + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// let mut buf = vec![0; 10]; + /// let n = socket.recv(&mut buf).await?; + /// + /// println!("received {} bytes {:?}", n, &buf[..n]); + /// + /// Ok(()) + /// } + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::READABLE, || self.io.recv(buf)) + .await + } + + /// Attempts to receive a single datagram message on the socket from the remote + /// address to which it is `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. This method + /// resolves to an error if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { + let n = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uninitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(())) + } + + /// Tries to receive a single datagram message on the socket from the remote + /// address to which it is connected. On success, returns the number of + /// bytes read. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || self.io.recv(buf)) + } + + cfg_io_util! { + /// Tries to receive data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).recv(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + + /// Tries to receive a single datagram message on the socket. On success, + /// returns the number of bytes read and the origin. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> { + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the + // buffer. + let (n, addr) = (&*self.io).recv_from(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok((n, addr)) + }) + } + } + + /// Sends data on the socket to the given address. On success, returns the + /// number of bytes written. + /// + /// Address type can be any implementor of [`ToSocketAddrs`] trait. See its + /// documentation for concrete examples. + /// + /// It is possible for `addr` to yield multiple addresses, but `send_to` + /// will only send data to the first address yielded by `addr`. + /// + /// This will return an error when the IP version of the local socket does + /// not match that returned from [`ToSocketAddrs`]. + /// + /// [`ToSocketAddrs`]: crate::net::ToSocketAddrs + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// let len = socket.send_to(b"hello world", "127.0.0.1:8081").await?; + /// + /// println!("Sent {} bytes", len); + /// + /// Ok(()) + /// } + /// ``` + pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> io::Result<usize> { + let mut addrs = to_socket_addrs(target).await?; + + match addrs.next() { + Some(target) => self.send_to_addr(buf, target).await, + None => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "no addresses to send data to", + )), + } + } + + /// Attempts to send data on the socket to a given address. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: SocketAddr, + ) -> Poll<io::Result<usize>> { + self.io + .registration() + .poll_write_io(cx, || self.io.send_to(buf, target)) + } + + /// Tries to send data on the socket to the given address, but if the send is + /// blocked this will return right away. + /// + /// This function is usually paired with `writable()`. + /// + /// # Returns + /// + /// If successful, returns the number of bytes sent + /// + /// Users should ensure that when the remote cannot receive, the + /// [`ErrorKind::WouldBlock`] is properly handled. An error can also occur + /// if the IP version of the socket does not match that of `target`. + /// + /// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// let dst = "127.0.0.1:8081".parse()?; + /// + /// loop { + /// socket.writable().await?; + /// + /// match socket.try_send_to(&b"hello world"[..], dst) { + /// Ok(sent) => { + /// println!("sent {} bytes", sent); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// // Writable false positive. + /// continue; + /// } + /// Err(e) => return Err(e.into()), + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send_to(buf, target)) + } + + async fn send_to_addr(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::WRITABLE, || self.io.send_to(buf, target)) + .await + } + + /// Receives a single datagram message on the socket. On success, returns + /// the number of bytes read and the origin. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// let mut buf = vec![0u8; 32]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// + /// println!("received {:?} bytes from {:?}", len, addr); + /// + /// Ok(()) + /// } + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .registration() + .async_io(Interest::READABLE, || self.io.recv_from(buf)) + .await + } + + /// Attempts to receive a single datagram on the socket. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<SocketAddr>> { + let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uninitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv_from(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(addr)) + } + + /// Tries to receive a single datagram message on the socket. On success, + /// returns the number of bytes read and the origin. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .registration() + .try_io(Interest::READABLE, || self.io.recv_from(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UdpSocket` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UdpSocket::readable() + /// [`writable()`]: UdpSocket::writable() + /// [`ready()`]: UdpSocket::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read and the address from whence + /// the data came. + /// + /// # Notes + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek_from returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// let mut buf = vec![0u8; 32]; + /// let (len, addr) = socket.peek_from(&mut buf).await?; + /// + /// println!("peeked {:?} bytes from {:?}", len, addr); + /// + /// Ok(()) + /// } + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .registration() + .async_io(Interest::READABLE, || self.io.peek_from(buf)) + .await + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read. + /// + /// # Notes + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_peek_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<SocketAddr>> { + let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uninitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.peek_from(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(addr)) + } + + /// Gets the value of the `SO_BROADCAST` option for this socket. + /// + /// For more information about this option, see [`set_broadcast`]. + /// + /// [`set_broadcast`]: method@Self::set_broadcast + pub fn broadcast(&self) -> io::Result<bool> { + self.io.broadcast() + } + + /// Sets the value of the `SO_BROADCAST` option for this socket. + /// + /// When enabled, this socket is allowed to send packets to a broadcast + /// address. + pub fn set_broadcast(&self, on: bool) -> io::Result<()> { + self.io.set_broadcast(on) + } + + /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v4`]. + /// + /// [`set_multicast_loop_v4`]: method@Self::set_multicast_loop_v4 + pub fn multicast_loop_v4(&self) -> io::Result<bool> { + self.io.multicast_loop_v4() + } + + /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// If enabled, multicast packets will be looped back to the local socket. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> { + self.io.set_multicast_loop_v4(on) + } + + /// Gets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// For more information about this option, see [`set_multicast_ttl_v4`]. + /// + /// [`set_multicast_ttl_v4`]: method@Self::set_multicast_ttl_v4 + pub fn multicast_ttl_v4(&self) -> io::Result<u32> { + self.io.multicast_ttl_v4() + } + + /// Sets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// Indicates the time-to-live value of outgoing multicast packets for + /// this socket. The default value is 1 which means that multicast packets + /// don't leave the local network unless explicitly requested. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> { + self.io.set_multicast_ttl_v4(ttl) + } + + /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v6`]. + /// + /// [`set_multicast_loop_v6`]: method@Self::set_multicast_loop_v6 + pub fn multicast_loop_v6(&self) -> io::Result<bool> { + self.io.multicast_loop_v6() + } + + /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// Controls whether this socket sees the multicast packets it sends itself. + /// + /// # Note + /// + /// This may not have any affect on IPv4 sockets. + pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> { + self.io.set_multicast_loop_v6(on) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: method@Self::set_ttl + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::io; + /// + /// # async fn dox() -> io::Result<()> { + /// let sock = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// println!("{:?}", sock.ttl()?); + /// # Ok(()) + /// # } + /// ``` + pub fn ttl(&self) -> io::Result<u32> { + self.io.ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::io; + /// + /// # async fn dox() -> io::Result<()> { + /// let sock = UdpSocket::bind("127.0.0.1:8080").await?; + /// sock.set_ttl(60)?; + /// + /// # Ok(()) + /// # } + /// ``` + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io.set_ttl(ttl) + } + + /// Executes an operation of the `IP_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// address of the local interface with which the system should join the + /// multicast group. If it's equal to `INADDR_ANY` then an appropriate + /// interface is chosen by the system. + pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { + self.io.join_multicast_v4(&multiaddr, &interface) + } + + /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// index of the interface to join/leave (or 0 to indicate any interface). + pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io.join_multicast_v6(multiaddr, interface) + } + + /// Executes an operation of the `IP_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v4`]. + /// + /// [`join_multicast_v4`]: method@Self::join_multicast_v4 + pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { + self.io.leave_multicast_v4(&multiaddr, &interface) + } + + /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v6`]. + /// + /// [`join_multicast_v6`]: method@Self::join_multicast_v6 + pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io.leave_multicast_v6(multiaddr, interface) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// ``` + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Create a socket + /// let socket = UdpSocket::bind("0.0.0.0:8080").await?; + /// + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.io.take_error() + } +} + +impl TryFrom<std::net::UdpSocket> for UdpSocket { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`UdpSocket::from_std(stream)`](UdpSocket::from_std). + fn try_from(stream: std::net::UdpSocket) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +impl fmt::Debug for UdpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +#[cfg(all(unix))] +mod sys { + use super::UdpSocket; + use std::os::unix::prelude::*; + + impl AsRawFd for UdpSocket { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } + } +} + +#[cfg(windows)] +mod sys { + use super::UdpSocket; + use std::os::windows::prelude::*; + + impl AsRawSocket for UdpSocket { + fn as_raw_socket(&self) -> RawSocket { + self.io.as_raw_socket() + } + } +} diff --git a/third_party/rust/tokio/src/net/unix/datagram/mod.rs b/third_party/rust/tokio/src/net/unix/datagram/mod.rs new file mode 100644 index 0000000000..6268b4ac90 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/datagram/mod.rs @@ -0,0 +1,3 @@ +//! Unix datagram types. + +pub(crate) mod socket; diff --git a/third_party/rust/tokio/src/net/unix/datagram/socket.rs b/third_party/rust/tokio/src/net/unix/datagram/socket.rs new file mode 100644 index 0000000000..d5b618663d --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/datagram/socket.rs @@ -0,0 +1,1422 @@ +use crate::io::{Interest, PollEvented, ReadBuf, Ready}; +use crate::net::unix::SocketAddr; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::Shutdown; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::unix::net; +use std::path::Path; +use std::task::{Context, Poll}; + +cfg_io_util! { + use bytes::BufMut; +} + +cfg_net_unix! { + /// An I/O object representing a Unix datagram socket. + /// + /// A socket can be either named (associated with a filesystem path) or + /// unnamed. + /// + /// This type does not provide a `split` method, because this functionality + /// can be achieved by wrapping the socket in an [`Arc`]. Note that you do + /// not need a `Mutex` to share the `UnixDatagram` — an `Arc<UnixDatagram>` + /// is enough. This is because all of the methods take `&self` instead of + /// `&mut self`. + /// + /// **Note:** named sockets are persisted even after the object is dropped + /// and the program has exited, and cannot be reconnected. It is advised + /// that you either check for and unlink the existing socket if it exists, + /// or use a temporary file that is guaranteed to not already exist. + /// + /// [`Arc`]: std::sync::Arc + /// + /// # Examples + /// Using named sockets, associated with a filesystem path: + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// Using unnamed sockets, created as a pair + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub struct UnixDatagram { + io: PollEvented<mio::net::UnixDatagram>, + } +} + +impl UnixDatagram { + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_recv()` or `try_send()`. It + /// can be used to concurrently recv / send to the same socket on a single + /// task without splitting the socket. + /// + /// The function may complete without the socket being ready. This is a + /// false-positive and attempting an operation will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// Concurrently receive from and send to the socket on the same task + /// without splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// let ready = socket.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = [0; 1024]; + /// match socket.try_recv(&mut data[..]) { + /// Ok(n) => { + /// println!("received {:?}", &data[..n]); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Write some data + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// println!("sent {} bytes", n); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is + /// usually paired with `try_send()` or `try_send_to()`. + /// + /// The function may complete without the socket being writable. This is a + /// false-positive and attempting a `try_send()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write/send readiness. + /// + /// If the socket is not currently ready for sending, this method will + /// store a clone of the `Waker` from the provided `Context`. When the socket + /// becomes ready for sending, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_send_ready` or `poll_send`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_recv_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the socket is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_recv()`. + /// + /// The function may complete without the socket being readable. This is a + /// false-positive and attempting a `try_recv()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read/receive readiness. + /// + /// If the socket is not currently ready for receiving, this method will + /// store a clone of the `Waker` from the provided `Context`. When the + /// socket becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_recv_ready`, `poll_recv` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_send_ready` retains a second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the socket is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Creates a new `UnixDatagram` bound to the specified path. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind the socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let socket = UnixDatagram::bind(&socket_path)?; + /// + /// # Ok(()) + /// # } + /// ``` + pub fn bind<P>(path: P) -> io::Result<UnixDatagram> + where + P: AsRef<Path>, + { + let socket = mio::net::UnixDatagram::bind(path)?; + UnixDatagram::new(socket) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// This function will create a pair of interconnected Unix sockets for + /// communicating back and forth between one another. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hail eris"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> { + let (a, b) = mio::net::UnixDatagram::pair()?; + let a = UnixDatagram::new(a)?; + let b = UnixDatagram::new(b)?; + + Ok((a, b)) + } + + /// Creates new `UnixDatagram` from a `std::os::unix::net::UnixDatagram`. + /// + /// This function is intended to be used to wrap a UnixDatagram from the + /// standard library in the Tokio equivalent. The conversion assumes + /// nothing about the underlying datagram; it is left up to the user to set + /// it in non-blocking mode. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a Tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use std::os::unix::net::UnixDatagram as StdUDS; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind the socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let std_socket = StdUDS::bind(&socket_path)?; + /// std_socket.set_nonblocking(true)?; + /// let tokio_socket = UnixDatagram::from_std(std_socket)?; + /// + /// # Ok(()) + /// # } + /// ``` + pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> { + let socket = mio::net::UnixDatagram::from_std(datagram); + let io = PollEvented::new(socket)?; + Ok(UnixDatagram { io }) + } + + /// Turns a [`tokio::net::UnixDatagram`] into a [`std::os::unix::net::UnixDatagram`]. + /// + /// The returned [`std::os::unix::net::UnixDatagram`] will have nonblocking + /// mode set as `true`. Use [`set_nonblocking`] to change the blocking mode + /// if needed. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let tokio_socket = tokio::net::UnixDatagram::bind("127.0.0.1:0")?; + /// let std_socket = tokio_socket.into_std()?; + /// std_socket.set_nonblocking(false)?; + /// Ok(()) + /// } + /// ``` + /// + /// [`tokio::net::UnixDatagram`]: UnixDatagram + /// [`std::os::unix::net::UnixDatagram`]: std::os::unix::net::UnixDatagram + /// [`set_nonblocking`]: fn@std::os::unix::net::UnixDatagram::set_nonblocking + pub fn into_std(self) -> io::Result<std::os::unix::net::UnixDatagram> { + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::os::unix::net::UnixDatagram::from_raw_fd(raw_fd) }) + } + + fn new(socket: mio::net::UnixDatagram) -> io::Result<UnixDatagram> { + let io = PollEvented::new(socket)?; + Ok(UnixDatagram { io }) + } + + /// Creates a new `UnixDatagram` which is not bound to any address. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Send to the bound socket + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn unbound() -> io::Result<UnixDatagram> { + let socket = mio::net::UnixDatagram::unbound()?; + UnixDatagram::new(socket) + } + + /// Connects the socket to the specified address. + /// + /// The `send` method may be used to send data to the specified address. + /// `recv` and `recv_from` will only receive data from that address. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Connect to the bound socket + /// tx.connect(&rx_path)?; + /// + /// // Send to the bound socket + /// let bytes = b"hello world"; + /// tx.send(bytes).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> { + self.io.connect(path) + } + + /// Sends data on the socket to the socket's peer. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::WRITABLE, || self.io.send(buf)) + .await + } + + /// Tries to send a datagram to the peer without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send(buf)) + } + + /// Tries to send a datagram to the peer without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send_to(b"hello world", &server_path) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize> + where + P: AsRef<Path>, + { + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send_to(buf, target)) + } + + /// Receives data from the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .async_io(Interest::READABLE, || self.io.recv(buf)) + .await + } + + /// Tries to receive a datagram from the peer without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || self.io.recv(buf)) + } + + cfg_io_util! { + /// Tries to receive data from the socket without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the + // buffer. + let (n, addr) = (&*self.io).recv_from(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok((n, addr)) + })?; + + Ok((n, SocketAddr(addr))) + } + + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `UnixDatagram::recv` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).recv(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + } + + /// Sends data on the socket to the specified address. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize> + where + P: AsRef<Path>, + { + self.io + .registration() + .async_io(Interest::WRITABLE, || self.io.send_to(buf, target.as_ref())) + .await + } + + /// Receives data from the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self + .io + .registration() + .async_io(Interest::READABLE, || self.io.recv_from(buf)) + .await?; + + Ok((n, SocketAddr(addr))) + } + + /// Attempts to receive a single datagram on the specified address. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<SocketAddr>> { + let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uninitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv_from(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(SocketAddr(addr))) + } + + /// Attempts to send data to the specified address. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_send_to<P>( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: P, + ) -> Poll<io::Result<usize>> + where + P: AsRef<Path>, + { + self.io + .registration() + .poll_write_io(cx, || self.io.send_to(buf, target.as_ref())) + } + + /// Attempts to send data on the socket to the remote address to which it + /// was previously `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, + /// only the `Waker` from the `Context` passed to the most recent call will + /// be scheduled to receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not available to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { + self.io + .registration() + .poll_write_io(cx, || self.io.send(buf)) + } + + /// Attempts to receive a single datagram message on the socket from the remote + /// address to which it is `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. This method + /// resolves to an error if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { + let n = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uninitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(())) + } + + /// Tries to receive data from the socket without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self + .io + .registration() + .try_io(Interest::READABLE, || self.io.recv_from(buf))?; + + Ok((n, SocketAddr(addr))) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixDatagram` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixDatagram::readable() + /// [`writable()`]: UnixDatagram::writable() + /// [`ready()`]: UnixDatagram::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + + /// Returns the local address that this socket is bound to. + /// + /// # Examples + /// For a socket bound to a local path + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let socket = UnixDatagram::bind(&socket_path)?; + /// + /// assert_eq!(socket.local_addr()?.as_pathname().unwrap(), &socket_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// For an unbound socket + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create an unbound socket + /// let socket = UnixDatagram::unbound()?; + /// + /// assert!(socket.local_addr()?.is_unnamed()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr().map(SocketAddr) + } + + /// Returns the address of this socket's peer. + /// + /// The `connect` method will connect the socket to a peer. + /// + /// # Examples + /// For a peer with a local path + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Connect to the bound socket + /// tx.connect(&rx_path)?; + /// + /// assert_eq!(tx.peer_addr()?.as_pathname().unwrap(), &rx_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// For an unbound peer + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// assert!(sock1.peer_addr()?.is_unnamed()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.io.peer_addr().map(SocketAddr) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create an unbound socket + /// let socket = UnixDatagram::unbound()?; + /// + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// + /// # Ok(()) + /// # } + /// ``` + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.io.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use std::net::Shutdown; + /// + /// // Create an unbound socket + /// let (socket, other) = UnixDatagram::pair()?; + /// + /// socket.shutdown(Shutdown::Both)?; + /// + /// // NOTE: the following commented out code does NOT work as expected. + /// // Due to an underlying issue, the recv call will block indefinitely. + /// // See: https://github.com/tokio-rs/tokio/issues/1679 + /// //let mut buff = vec![0u8; 24]; + /// //let size = socket.recv(&mut buff).await?; + /// //assert_eq!(size, 0); + /// + /// let send_result = socket.send(b"hello world").await; + /// assert!(send_result.is_err()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.io.shutdown(how) + } +} + +impl TryFrom<std::os::unix::net::UnixDatagram> for UnixDatagram { + type Error = io::Error; + + /// Consumes stream, returning the Tokio I/O object. + /// + /// This is equivalent to + /// [`UnixDatagram::from_std(stream)`](UnixDatagram::from_std). + fn try_from(stream: std::os::unix::net::UnixDatagram) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +impl fmt::Debug for UnixDatagram { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +impl AsRawFd for UnixDatagram { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } +} diff --git a/third_party/rust/tokio/src/net/unix/listener.rs b/third_party/rust/tokio/src/net/unix/listener.rs new file mode 100644 index 0000000000..1785f8b0f7 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/listener.rs @@ -0,0 +1,186 @@ +use crate::io::{Interest, PollEvented}; +use crate::net::unix::{SocketAddr, UnixStream}; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::unix::net; +use std::path::Path; +use std::task::{Context, Poll}; + +cfg_net_unix! { + /// A Unix socket which can accept connections from other Unix sockets. + /// + /// You can accept a new connection by using the [`accept`](`UnixListener::accept`) method. + /// + /// A `UnixListener` can be turned into a `Stream` with [`UnixListenerStream`]. + /// + /// [`UnixListenerStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.UnixListenerStream.html + /// + /// # Errors + /// + /// Note that accepting a connection can lead to various errors and not all + /// of them are necessarily fatal ‒ for example having too many open file + /// descriptors or the other side closing the connection while it waits in + /// an accept queue. These would terminate the stream if not handled in any + /// way. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixListener; + /// + /// #[tokio::main] + /// async fn main() { + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// loop { + /// match listener.accept().await { + /// Ok((stream, _addr)) => { + /// println!("new client!"); + /// } + /// Err(e) => { /* connection failed */ } + /// } + /// } + /// } + /// ``` + pub struct UnixListener { + io: PollEvented<mio::net::UnixListener>, + } +} + +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified path. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn bind<P>(path: P) -> io::Result<UnixListener> + where + P: AsRef<Path>, + { + let listener = mio::net::UnixListener::bind(path)?; + let io = PollEvented::new(listener)?; + Ok(UnixListener { io }) + } + + /// Creates new `UnixListener` from a `std::os::unix::net::UnixListener `. + /// + /// This function is intended to be used to wrap a UnixListener from the + /// standard library in the Tokio equivalent. The conversion assumes + /// nothing about the underlying listener; it is left up to the user to set + /// it in non-blocking mode. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn from_std(listener: net::UnixListener) -> io::Result<UnixListener> { + let listener = mio::net::UnixListener::from_std(listener); + let io = PollEvented::new(listener)?; + Ok(UnixListener { io }) + } + + /// Turns a [`tokio::net::UnixListener`] into a [`std::os::unix::net::UnixListener`]. + /// + /// The returned [`std::os::unix::net::UnixListener`] will have nonblocking mode + /// set as `true`. Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::error::Error; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let tokio_listener = tokio::net::UnixListener::bind("127.0.0.1:0")?; + /// let std_listener = tokio_listener.into_std()?; + /// std_listener.set_nonblocking(false)?; + /// Ok(()) + /// } + /// ``` + /// + /// [`tokio::net::UnixListener`]: UnixListener + /// [`std::os::unix::net::UnixListener`]: std::os::unix::net::UnixListener + /// [`set_nonblocking`]: fn@std::os::unix::net::UnixListener::set_nonblocking + pub fn into_std(self) -> io::Result<std::os::unix::net::UnixListener> { + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { net::UnixListener::from_raw_fd(raw_fd) }) + } + + /// Returns the local socket address of this listener. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr().map(SocketAddr) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.io.take_error() + } + + /// Accepts a new incoming connection to this listener. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. + pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let (mio, addr) = self + .io + .registration() + .async_io(Interest::READABLE, || self.io.accept()) + .await?; + + let addr = SocketAddr(addr); + let stream = UnixStream::new(mio)?; + Ok((stream, addr)) + } + + /// Polls to accept a new incoming connection to this listener. + /// + /// If there is no connection to accept, `Poll::Pending` is returned and the + /// current task will be notified by a waker. Note that on multiple calls + /// to `poll_accept`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UnixStream, SocketAddr)>> { + let (sock, addr) = ready!(self.io.registration().poll_read_io(cx, || self.io.accept()))?; + let addr = SocketAddr(addr); + let sock = UnixStream::new(sock)?; + Poll::Ready(Ok((sock, addr))) + } +} + +impl TryFrom<std::os::unix::net::UnixListener> for UnixListener { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`UnixListener::from_std(stream)`](UnixListener::from_std). + fn try_from(stream: std::os::unix::net::UnixListener) -> io::Result<Self> { + Self::from_std(stream) + } +} + +impl fmt::Debug for UnixListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +impl AsRawFd for UnixListener { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } +} diff --git a/third_party/rust/tokio/src/net/unix/mod.rs b/third_party/rust/tokio/src/net/unix/mod.rs new file mode 100644 index 0000000000..14cb456705 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/mod.rs @@ -0,0 +1,24 @@ +//! Unix domain socket utility types. + +// This module does not currently provide any public API, but it was +// unintentionally defined as a public module. Hide it from the documentation +// instead of changing it to a private module to avoid breakage. +#[doc(hidden)] +pub mod datagram; + +pub(crate) mod listener; + +mod split; +pub use split::{ReadHalf, WriteHalf}; + +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + +mod socketaddr; +pub use socketaddr::SocketAddr; + +pub(crate) mod stream; +pub(crate) use stream::UnixStream; + +mod ucred; +pub use ucred::UCred; diff --git a/third_party/rust/tokio/src/net/unix/socketaddr.rs b/third_party/rust/tokio/src/net/unix/socketaddr.rs new file mode 100644 index 0000000000..48f7b96b8c --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/socketaddr.rs @@ -0,0 +1,31 @@ +use std::fmt; +use std::path::Path; + +/// An address associated with a Tokio Unix socket. +pub struct SocketAddr(pub(super) mio::net::SocketAddr); + +impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn is_unnamed(&self) -> bool { + self.0.is_unnamed() + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn as_pathname(&self) -> Option<&Path> { + self.0.as_pathname() + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(fmt) + } +} diff --git a/third_party/rust/tokio/src/net/unix/split.rs b/third_party/rust/tokio/src/net/unix/split.rs new file mode 100644 index 0000000000..d4686c22d7 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/split.rs @@ -0,0 +1,305 @@ +//! `UnixStream` split support. +//! +//! A `UnixStream` can be split into a read half and a write half with +//! `UnixStream::split`. The read half implements `AsyncRead` while the write +//! half implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; +use crate::net::UnixStream; + +use crate::net::unix::SocketAddr; +use std::io; +use std::net::Shutdown; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + use bytes::BufMut; +} + +/// Borrowed read half of a [`UnixStream`], created by [`split`]. +/// +/// Reading from a `ReadHalf` is usually done using the convenience methods found on the +/// [`AsyncReadExt`] trait. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +#[derive(Debug)] +pub struct ReadHalf<'a>(&'a UnixStream); + +/// Borrowed write half of a [`UnixStream`], created by [`split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will +/// shut down the UnixStream stream in the write direction. +/// +/// Writing to an `WriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +#[derive(Debug)] +pub struct WriteHalf<'a>(&'a UnixStream); + +pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) { + (ReadHalf(stream), WriteHalf(stream)) +} + +impl ReadHalf<'_> { + /// Wait for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.0.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.0.try_read(buf) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.0.try_read_buf(buf) + } + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.0.try_read_vectored(bufs) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl WriteHalf<'_> { + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.0.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.0.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.0.try_write_vectored(buf) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl AsyncRead for ReadHalf<'_> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.0.poll_read_priv(cx, buf) + } +} + +impl AsyncWrite for WriteHalf<'_> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.0.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.0.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + self.0.shutdown_std(Shutdown::Write).into() + } +} + +impl AsRef<UnixStream> for ReadHalf<'_> { + fn as_ref(&self) -> &UnixStream { + self.0 + } +} + +impl AsRef<UnixStream> for WriteHalf<'_> { + fn as_ref(&self) -> &UnixStream { + self.0 + } +} diff --git a/third_party/rust/tokio/src/net/unix/split_owned.rs b/third_party/rust/tokio/src/net/unix/split_owned.rs new file mode 100644 index 0000000000..9c3a2a4177 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/split_owned.rs @@ -0,0 +1,393 @@ +//! `UnixStream` owned split support. +//! +//! A `UnixStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `UnixStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; +use crate::net::UnixStream; + +use crate::net::unix::SocketAddr; +use std::error::Error; +use std::net::Shutdown; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +cfg_io_util! { + use bytes::BufMut; +} + +/// Owned read half of a [`UnixStream`], created by [`into_split`]. +/// +/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found +/// on the [`AsyncReadExt`] trait. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc<UnixStream>, +} + +/// Owned write half of a [`UnixStream`], created by [`into_split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, +/// [`poll_shutdown`] will shut down the stream in the write direction. +/// Dropping the write half will also shut down the write half of the stream. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods +/// found on the [`AsyncWriteExt`] trait. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc<UnixStream>, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: UnixStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result<UnixStream, ReuniteError> { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("UnixStream: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating that two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result<UnixStream, ReuniteError> { + reunite(self, other) + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.inner.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.try_read(buf) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.inner.try_read_buf(buf) + } + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.try_read_vectored(bufs) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } +} + +impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result<UnixStream, ReuniteError> { + reunite(other, self) + } + + /// Destroys the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.inner.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.inner.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.inner.try_write_vectored(buf) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown_std(Shutdown::Write); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + let res = self.inner.shutdown_std(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() + } +} + +impl AsRef<UnixStream> for OwnedReadHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} + +impl AsRef<UnixStream> for OwnedWriteHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} diff --git a/third_party/rust/tokio/src/net/unix/stream.rs b/third_party/rust/tokio/src/net/unix/stream.rs new file mode 100644 index 0000000000..4e7ef87b41 --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/stream.rs @@ -0,0 +1,960 @@ +use crate::future::poll_fn; +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; +use crate::net::unix::split::{split, ReadHalf, WriteHalf}; +use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; +use crate::net::unix::ucred::{self, UCred}; +use crate::net::unix::SocketAddr; + +use std::convert::TryFrom; +use std::fmt; +use std::io::{self, Read, Write}; +use std::net::Shutdown; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::unix::net; +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + use bytes::BufMut; +} + +cfg_net_unix! { + /// A structure representing a connected Unix socket. + /// + /// This socket can be connected directly with `UnixStream::connect` or accepted + /// from a listener with `UnixListener::incoming`. Additionally, a pair of + /// anonymous Unix sockets can be created with `UnixStream::pair`. + /// + /// To shut down the stream in the write direction, you can call the + /// [`shutdown()`] method. This will cause the other peer to receive a read of + /// length 0, indicating that no more data will be sent. This only closes + /// the stream in one direction. + /// + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown + pub struct UnixStream { + io: PollEvented<mio::net::UnixStream>, + } +} + +impl UnixStream { + /// Connects to the socket named by `path`. + /// + /// This function will create a new Unix socket and connect to the path + /// specified, associating the returned stream with the default event loop's + /// handle. + pub async fn connect<P>(path: P) -> io::Result<UnixStream> + where + P: AsRef<Path>, + { + let stream = mio::net::UnixStream::connect(path)?; + let stream = UnixStream::new(stream)?; + + poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; + + if let Some(e) = stream.io.take_error()? { + return Err(e); + } + + Ok(stream) + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// Concurrently read and write to the stream on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// let ready = stream.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the unix stream is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the unix + /// stream becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the unix stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the unix stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Try to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: UnixStream::try_read() + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; + + // Safety: We trust `UnixStream::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the unix stream is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the unix + /// stream becomes ready for writing, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the unix stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the unix stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: UnixStream::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`writable()`]: UnixStream::writable() + /// [`ready()`]: UnixStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + + /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`. + /// + /// This function is intended to be used to wrap a UnixStream from the + /// standard library in the Tokio equivalent. The conversion assumes + /// nothing about the underlying stream; it is left up to the user to set + /// it in non-blocking mode. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn from_std(stream: net::UnixStream) -> io::Result<UnixStream> { + let stream = mio::net::UnixStream::from_std(stream); + let io = PollEvented::new(stream)?; + + Ok(UnixStream { io }) + } + + /// Turns a [`tokio::net::UnixStream`] into a [`std::os::unix::net::UnixStream`]. + /// + /// The returned [`std::os::unix::net::UnixStream`] will have nonblocking + /// mode set as `true`. Use [`set_nonblocking`] to change the blocking + /// mode if needed. + /// + /// # Examples + /// + /// ``` + /// use std::error::Error; + /// use std::io::Read; + /// use tokio::net::UnixListener; + /// # use tokio::net::UnixStream; + /// # use tokio::io::AsyncWriteExt; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// + /// let mut data = [0u8; 12]; + /// let listener = UnixListener::bind(&bind_path)?; + /// # let handle = tokio::spawn(async { + /// # let mut stream = UnixStream::connect(bind_path).await.unwrap(); + /// # stream.write(b"Hello world!").await.unwrap(); + /// # }); + /// let (tokio_unix_stream, _) = listener.accept().await?; + /// let mut std_unix_stream = tokio_unix_stream.into_std()?; + /// # handle.await.expect("The task being joined has panicked"); + /// std_unix_stream.set_nonblocking(false)?; + /// std_unix_stream.read_exact(&mut data)?; + /// # assert_eq!(b"Hello world!", &data); + /// Ok(()) + /// } + /// ``` + /// [`tokio::net::UnixStream`]: UnixStream + /// [`std::os::unix::net::UnixStream`]: std::os::unix::net::UnixStream + /// [`set_nonblocking`]: fn@std::os::unix::net::UnixStream::set_nonblocking + pub fn into_std(self) -> io::Result<std::os::unix::net::UnixStream> { + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::os::unix::net::UnixStream::from_raw_fd(raw_fd) }) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// This function will create a pair of interconnected Unix sockets for + /// communicating back and forth between one another. Each socket will + /// be associated with the default event loop's handle. + pub fn pair() -> io::Result<(UnixStream, UnixStream)> { + let (a, b) = mio::net::UnixStream::pair()?; + let a = UnixStream::new(a)?; + let b = UnixStream::new(b)?; + + Ok((a, b)) + } + + pub(crate) fn new(stream: mio::net::UnixStream) -> io::Result<UnixStream> { + let io = PollEvented::new(stream)?; + Ok(UnixStream { io }) + } + + /// Returns the socket address of the local half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// println!("{:?}", stream.local_addr()?); + /// # Ok(()) + /// # } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.local_addr().map(SocketAddr) + } + + /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// println!("{:?}", stream.peer_addr()?); + /// # Ok(()) + /// # } + /// ``` + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.io.peer_addr().map(SocketAddr) + } + + /// Returns effective credentials of the process which called `connect` or `pair`. + pub fn peer_cred(&self) -> io::Result<UCred> { + ucred::get_peer_cred(self) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.io.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> { + self.io.shutdown(how) + } + + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] + /// Splits a `UnixStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: Self::into_split() + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { + split(self) + } + + /// Splits a `UnixStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. + /// + /// **Note:** Dropping the write half will shut down the write half of the + /// stream. This is equivalent to calling [`shutdown()`] on the `UnixStream`. + /// + /// [`split`]: Self::split() + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) + } +} + +impl TryFrom<net::UnixStream> for UnixStream { + type Error = io::Error; + + /// Consumes stream, returning the tokio I/O object. + /// + /// This is equivalent to + /// [`UnixStream::from_std(stream)`](UnixStream::from_std). + fn try_from(stream: net::UnixStream) -> io::Result<Self> { + Self::from_std(stream) + } +} + +impl AsyncRead for UnixStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.poll_read_priv(cx, buf) + } +} + +impl AsyncWrite for UnixStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.poll_write_priv(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.poll_write_vectored_priv(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + self.shutdown_std(std::net::Shutdown::Write)?; + Poll::Ready(Ok(())) + } +} + +impl UnixStream { + // == Poll IO functions that takes `&self` == + // + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. + + pub(crate) fn poll_read_priv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // Safety: `UnixStream::read` correctly handles reads into uninitialized memory + unsafe { self.io.poll_read(cx, buf) } + } + + pub(crate) fn poll_write_priv( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + pub(super) fn poll_write_vectored_priv( + &self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } +} + +impl fmt::Debug for UnixStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.fmt(f) + } +} + +impl AsRawFd for UnixStream { + fn as_raw_fd(&self) -> RawFd { + self.io.as_raw_fd() + } +} diff --git a/third_party/rust/tokio/src/net/unix/ucred.rs b/third_party/rust/tokio/src/net/unix/ucred.rs new file mode 100644 index 0000000000..865303b4ce --- /dev/null +++ b/third_party/rust/tokio/src/net/unix/ucred.rs @@ -0,0 +1,252 @@ +use libc::{gid_t, pid_t, uid_t}; + +/// Credentials of a process. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub struct UCred { + /// PID (process ID) of the process. + pid: Option<pid_t>, + /// UID (user ID) of the process. + uid: uid_t, + /// GID (group ID) of the process. + gid: gid_t, +} + +impl UCred { + /// Gets UID (user ID) of the process. + pub fn uid(&self) -> uid_t { + self.uid + } + + /// Gets GID (group ID) of the process. + pub fn gid(&self) -> gid_t { + self.gid + } + + /// Gets PID (process ID) of the process. + /// + /// This is only implemented under Linux, Android, iOS, macOS, Solaris and + /// Illumos. On other platforms this will always return `None`. + pub fn pid(&self) -> Option<pid_t> { + self.pid + } +} + +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] +pub(crate) use self::impl_linux::get_peer_cred; + +#[cfg(any(target_os = "netbsd"))] +pub(crate) use self::impl_netbsd::get_peer_cred; + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] +pub(crate) use self::impl_bsd::get_peer_cred; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) use self::impl_macos::get_peer_cred; + +#[cfg(any(target_os = "solaris", target_os = "illumos"))] +pub(crate) use self::impl_solaris::get_peer_cred; + +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] +pub(crate) mod impl_linux { + use crate::net::unix::UnixStream; + + use libc::{c_void, getsockopt, socklen_t, SOL_SOCKET, SO_PEERCRED}; + use std::{io, mem}; + + #[cfg(target_os = "openbsd")] + use libc::sockpeercred as ucred; + #[cfg(any(target_os = "linux", target_os = "android"))] + use libc::ucred; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + use std::os::unix::io::AsRawFd; + + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut ucred = ucred { + pid: 0, + uid: 0, + gid: 0, + }; + + let ucred_size = mem::size_of::<ucred>(); + + // These paranoid checks should be optimized-out + assert!(mem::size_of::<u32>() <= mem::size_of::<usize>()); + assert!(ucred_size <= u32::MAX as usize); + + let mut ucred_size = ucred_size as socklen_t; + + let ret = getsockopt( + raw_fd, + SOL_SOCKET, + SO_PEERCRED, + &mut ucred as *mut ucred as *mut c_void, + &mut ucred_size, + ); + if ret == 0 && ucred_size as usize == mem::size_of::<ucred>() { + Ok(super::UCred { + uid: ucred.uid, + gid: ucred.gid, + pid: Some(ucred.pid), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "netbsd"))] +pub(crate) mod impl_netbsd { + use crate::net::unix::UnixStream; + + use libc::{c_void, getsockopt, socklen_t, unpcbid, LOCAL_PEEREID, SOL_SOCKET}; + use std::io; + use std::mem::size_of; + use std::os::unix::io::AsRawFd; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut unpcbid = unpcbid { + unp_pid: 0, + unp_euid: 0, + unp_egid: 0, + }; + + let unpcbid_size = size_of::<unpcbid>(); + let mut unpcbid_size = unpcbid_size as socklen_t; + + let ret = getsockopt( + raw_fd, + SOL_SOCKET, + LOCAL_PEEREID, + &mut unpcbid as *mut unpcbid as *mut c_void, + &mut unpcbid_size, + ); + if ret == 0 && unpcbid_size as usize == size_of::<unpcbid>() { + Ok(super::UCred { + uid: unpcbid.unp_euid, + gid: unpcbid.unp_egid, + pid: Some(unpcbid.unp_pid), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] +pub(crate) mod impl_bsd { + use crate::net::unix::UnixStream; + + use libc::getpeereid; + use std::io; + use std::mem::MaybeUninit; + use std::os::unix::io::AsRawFd; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut uid = MaybeUninit::uninit(); + let mut gid = MaybeUninit::uninit(); + + let ret = getpeereid(raw_fd, uid.as_mut_ptr(), gid.as_mut_ptr()); + + if ret == 0 { + Ok(super::UCred { + uid: uid.assume_init(), + gid: gid.assume_init(), + pid: None, + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) mod impl_macos { + use crate::net::unix::UnixStream; + + use libc::{c_void, getpeereid, getsockopt, pid_t, LOCAL_PEEREPID, SOL_LOCAL}; + use std::io; + use std::mem::size_of; + use std::mem::MaybeUninit; + use std::os::unix::io::AsRawFd; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut uid = MaybeUninit::uninit(); + let mut gid = MaybeUninit::uninit(); + let mut pid: MaybeUninit<pid_t> = MaybeUninit::uninit(); + let mut pid_size: MaybeUninit<u32> = MaybeUninit::new(size_of::<pid_t>() as u32); + + if getsockopt( + raw_fd, + SOL_LOCAL, + LOCAL_PEEREPID, + pid.as_mut_ptr() as *mut c_void, + pid_size.as_mut_ptr(), + ) != 0 + { + return Err(io::Error::last_os_error()); + } + + assert!(pid_size.assume_init() == (size_of::<pid_t>() as u32)); + + let ret = getpeereid(raw_fd, uid.as_mut_ptr(), gid.as_mut_ptr()); + + if ret == 0 { + Ok(super::UCred { + uid: uid.assume_init(), + gid: gid.assume_init(), + pid: Some(pid.assume_init()), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "solaris", target_os = "illumos"))] +pub(crate) mod impl_solaris { + use crate::net::unix::UnixStream; + use std::io; + use std::os::unix::io::AsRawFd; + use std::ptr; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut cred = ptr::null_mut(); + let ret = libc::getpeerucred(raw_fd, &mut cred); + + if ret == 0 { + let uid = libc::ucred_geteuid(cred); + let gid = libc::ucred_getegid(cred); + let pid = libc::ucred_getpid(cred); + + libc::ucred_free(cred); + + Ok(super::UCred { + uid, + gid, + pid: Some(pid), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} diff --git a/third_party/rust/tokio/src/net/windows/mod.rs b/third_party/rust/tokio/src/net/windows/mod.rs new file mode 100644 index 0000000000..060b68e663 --- /dev/null +++ b/third_party/rust/tokio/src/net/windows/mod.rs @@ -0,0 +1,3 @@ +//! Windows specific network types. + +pub mod named_pipe; diff --git a/third_party/rust/tokio/src/net/windows/named_pipe.rs b/third_party/rust/tokio/src/net/windows/named_pipe.rs new file mode 100644 index 0000000000..550fd4df2b --- /dev/null +++ b/third_party/rust/tokio/src/net/windows/named_pipe.rs @@ -0,0 +1,2250 @@ +//! Tokio support for [Windows named pipes]. +//! +//! [Windows named pipes]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes + +use std::ffi::c_void; +use std::ffi::OsStr; +use std::io::{self, Read, Write}; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; +use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; + +// Hide imports which are not used when generating documentation. +#[cfg(not(docsrs))] +mod doc { + pub(super) use crate::os::windows::ffi::OsStrExt; + pub(super) use crate::winapi::shared::minwindef::{DWORD, FALSE}; + pub(super) use crate::winapi::um::fileapi; + pub(super) use crate::winapi::um::handleapi; + pub(super) use crate::winapi::um::namedpipeapi; + pub(super) use crate::winapi::um::winbase; + pub(super) use crate::winapi::um::winnt; + + pub(super) use mio::windows as mio_windows; +} + +// NB: none of these shows up in public API, so don't document them. +#[cfg(docsrs)] +mod doc { + pub type DWORD = crate::doc::NotDefinedHere; + + pub(super) mod mio_windows { + pub type NamedPipe = crate::doc::NotDefinedHere; + } +} + +use self::doc::*; + +/// A [Windows named pipe] server. +/// +/// Accepting client connections involves creating a server with +/// [`ServerOptions::create`] and waiting for clients to connect using +/// [`NamedPipeServer::connect`]. +/// +/// To avoid having clients sporadically fail with +/// [`std::io::ErrorKind::NotFound`] when they connect to a server, we must +/// ensure that at least one server instance is available at all times. This +/// means that the typical listen loop for a server is a bit involved, because +/// we have to ensure that we never drop a server accidentally while a client +/// might connect. +/// +/// So a correctly implemented server looks like this: +/// +/// ```no_run +/// use std::io; +/// use tokio::net::windows::named_pipe::ServerOptions; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-server"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// // The first server needs to be constructed early so that clients can +/// // be correctly connected. Otherwise calling .wait will cause the client to +/// // error. +/// // +/// // Here we also make use of `first_pipe_instance`, which will ensure that +/// // there are no other servers up and running already. +/// let mut server = ServerOptions::new() +/// .first_pipe_instance(true) +/// .create(PIPE_NAME)?; +/// +/// // Spawn the server loop. +/// let server = tokio::spawn(async move { +/// loop { +/// // Wait for a client to connect. +/// let connected = server.connect().await?; +/// +/// // Construct the next server to be connected before sending the one +/// // we already have of onto a task. This ensures that the server +/// // isn't closed (after it's done in the task) before a new one is +/// // available. Otherwise the client might error with +/// // `io::ErrorKind::NotFound`. +/// server = ServerOptions::new().create(PIPE_NAME)?; +/// +/// let client = tokio::spawn(async move { +/// /* use the connected client */ +/// # Ok::<_, std::io::Error>(()) +/// }); +/// # if true { break } // needed for type inference to work +/// } +/// +/// Ok::<_, io::Error>(()) +/// }); +/// +/// /* do something else not server related here */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeServer { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeServer { + /// Constructs a new named pipe server from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the server is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{PipeEnd, PipeMode, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .pipe_mode(PipeMode::Message) + /// .max_instances(5) + /// .create(PIPE_NAME)?; + /// + /// let server_info = server.info()?; + /// + /// assert_eq!(server_info.end, PipeEnd::Server); + /// assert_eq!(server_info.mode, PipeMode::Message); + /// assert_eq!(server_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } + + /// Enables a named pipe server process to wait for a client process to + /// connect to an instance of a named pipe. A client process connects by + /// creating a named pipe with the same name. + /// + /// This corresponds to the [`ConnectNamedPipe`] system call. + /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as the + /// event in a [`select!`](crate::select) statement and some other branch + /// completes first, then no connection events have been lost. + /// + /// [`ConnectNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-connectnamedpipe + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let pipe = ServerOptions::new().create(PIPE_NAME)?; + /// + /// // Wait for a client to connect. + /// pipe.connect().await?; + /// + /// // Use the connected client... + /// # Ok(()) } + /// ``` + pub async fn connect(&self) -> io::Result<()> { + loop { + match self.io.connect() { + Ok(()) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.registration().readiness(Interest::WRITABLE).await?; + } + Err(e) => return Err(e), + } + } + + Ok(()) + } + + /// Disconnects the server end of a named pipe instance from a client + /// process. + /// + /// ``` + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-disconnect"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// // Wait for a client to become connected. + /// server.connect().await?; + /// + /// // Forcibly disconnect the client. + /// server.disconnect()?; + /// + /// // Write fails with an OS-specific error after client has been + /// // disconnected. + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_NOT_CONNECTED as i32)); + /// # Ok(()) } + /// ``` + pub fn disconnect(&self) -> io::Result<()> { + self.io.disconnect() + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// let ready = server.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the pipe to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Tries to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: NamedPipeServer::try_read() + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Waits for the pipe to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Tries to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the pipe, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeServer::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the + /// methods defined on the Tokio `NamedPipeServer` type, as this will mess with + /// the readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`writable()`]: NamedPipeServer::writable() + /// [`ready()`]: NamedPipeServer::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } +} + +impl AsyncRead for NamedPipeServer { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeServer { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeServer { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +/// A [Windows named pipe] client. +/// +/// Constructed using [`ClientOptions::open`]. +/// +/// Connecting a client correctly involves a few steps. When connecting through +/// [`ClientOptions::open`], it might error indicating one of two things: +/// +/// * [`std::io::ErrorKind::NotFound`] - There is no server available. +/// * [`ERROR_PIPE_BUSY`] - There is a server available, but it is busy. Sleep +/// for a while and try again. +/// +/// So a correctly implemented client looks like this: +/// +/// ```no_run +/// use std::time::Duration; +/// use tokio::net::windows::named_pipe::ClientOptions; +/// use tokio::time; +/// use winapi::shared::winerror; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-client"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// let client = loop { +/// match ClientOptions::new().open(PIPE_NAME) { +/// Ok(client) => break client, +/// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), +/// Err(e) => return Err(e), +/// } +/// +/// time::sleep(Duration::from_millis(50)).await; +/// }; +/// +/// /* use the connected client */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeClient { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeClient { + /// Constructs a new named pipe client from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the client is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{ClientOptions, PipeEnd, PipeMode}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// let client_info = client.info()?; + /// + /// assert_eq!(client_info.end, PipeEnd::Client); + /// assert_eq!(client_info.mode, PipeMode::Message); + /// assert_eq!(client_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// let ready = client.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the pipe to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Tries to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: NamedPipeClient::try_read() + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Waits for the pipe to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Tries to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the pipe, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeClient::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `NamedPipeClient` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`writable()`]: NamedPipeClient::writable() + /// [`ready()`]: NamedPipeClient::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } +} + +impl AsyncRead for NamedPipeClient { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeClient { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeClient { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +// Helper to set a boolean flag as a bitfield. +macro_rules! bool_flag { + ($f:expr, $t:expr, $flag:expr) => {{ + let current = $f; + + if $t { + $f = current | $flag; + } else { + $f = current & !$flag; + }; + }}; +} + +/// A builder structure for construct a named pipe with named pipe-specific +/// options. This is required to use for named pipe servers who wants to modify +/// pipe-related options. +/// +/// See [`ServerOptions::create`]. +#[derive(Debug, Clone)] +pub struct ServerOptions { + open_mode: DWORD, + pipe_mode: DWORD, + max_instances: DWORD, + out_buffer_size: DWORD, + in_buffer_size: DWORD, + default_timeout: DWORD, +} + +impl ServerOptions { + /// Creates a new named pipe builder with the default settings. + /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> ServerOptions { + ServerOptions { + open_mode: winbase::PIPE_ACCESS_DUPLEX | winbase::FILE_FLAG_OVERLAPPED, + pipe_mode: winbase::PIPE_TYPE_BYTE | winbase::PIPE_REJECT_REMOTE_CLIENTS, + max_instances: winbase::PIPE_UNLIMITED_INSTANCES, + out_buffer_size: 65536, + in_buffer_size: 65536, + default_timeout: 0, + } + } + + /// The pipe mode. + /// + /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for + /// documentation of what each mode means. + /// + /// This corresponding to specifying [`dwPipeMode`]. + /// + /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { + self.pipe_mode = match pipe_mode { + PipeMode::Byte => winbase::PIPE_TYPE_BYTE, + PipeMode::Message => winbase::PIPE_TYPE_MESSAGE, + }; + + self + } + + /// The flow of data in the pipe goes from client to server only. + /// + /// This corresponds to setting [`PIPE_ACCESS_INBOUND`]. + /// + /// [`PIPE_ACCESS_INBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_inbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying inbound access, client errors + /// with [`std::io::ErrorKind::PermissionDenied`] when attempting to create + /// the connection. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let _server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling writing allows a client to connect, but errors with + /// [`std::io::ErrorKind::PermissionDenied`] if a write is attempted. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports server-to-client + /// communication. + /// + /// ``` + /// use std::io; + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = server.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = client.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_inbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_INBOUND); + self + } + + /// The flow of data in the pipe goes from server to client only. + /// + /// This corresponds to setting [`PIPE_ACCESS_OUTBOUND`]. + /// + /// [`PIPE_ACCESS_OUTBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_outbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying outbound access, client + /// errors with [`std::io::ErrorKind::PermissionDenied`] when attempting to + /// create the connection. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling reading allows a client to connect, but attempting to read + /// will error with [`std::io::ErrorKind::PermissionDenied`]. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncReadExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let mut buf = [0u8; 4]; + /// let e = client.read(&mut buf).await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports client-to-server + /// communication. + /// + /// ``` + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = client.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = server.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// println!("done reading and writing"); + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_outbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_OUTBOUND); + self + } + + /// If you attempt to create multiple instances of a pipe with this flag + /// set, creation of the first server instance succeeds, but creation of any + /// subsequent instances will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. + /// + /// This option is intended to be used with servers that want to ensure that + /// they are the only process listening for clients on a given named pipe. + /// This is accomplished by enabling it for the first server instance + /// created in a process. + /// + /// This corresponds to setting [`FILE_FLAG_FIRST_PIPE_INSTANCE`]. + /// + /// # Errors + /// + /// If this option is set and more than one instance of the server for a + /// given named pipe exists, calling [`create`] will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance-error"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server1 = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME)?; + /// + /// // Second server errs, since it's not the first instance. + /// let e = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut builder = ServerOptions::new(); + /// builder.first_pipe_instance(true); + /// + /// let server = builder.create(PIPE_NAME)?; + /// let e = builder.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// drop(server); + /// + /// // OK: since, we've closed the other instance. + /// let _server2 = builder.create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + /// + /// [`create`]: ServerOptions::create + /// [`FILE_FLAG_FIRST_PIPE_INSTANCE`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_first_pipe_instance + pub fn first_pipe_instance(&mut self, first: bool) -> &mut Self { + bool_flag!( + self.open_mode, + first, + winbase::FILE_FLAG_FIRST_PIPE_INSTANCE + ); + self + } + + /// Indicates whether this server can accept remote clients or not. Remote + /// clients are disabled by default. + /// + /// This corresponds to setting [`PIPE_REJECT_REMOTE_CLIENTS`]. + /// + /// [`PIPE_REJECT_REMOTE_CLIENTS`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_reject_remote_clients + pub fn reject_remote_clients(&mut self, reject: bool) -> &mut Self { + bool_flag!(self.pipe_mode, reject, winbase::PIPE_REJECT_REMOTE_CLIENTS); + self + } + + /// The maximum number of instances that can be created for this pipe. The + /// first instance of the pipe can specify this value; the same number must + /// be specified for other instances of the pipe. Acceptable values are in + /// the range 1 through 254. The default value is unlimited. + /// + /// This corresponds to specifying [`nMaxInstances`]. + /// + /// [`nMaxInstances`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// The same numbers of `max_instances` have to be used by all servers. Any + /// additional servers trying to be built which uses a mismatching value + /// might error. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-max-instances"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new(); + /// server.max_instances(2); + /// + /// let s1 = server.create(PIPE_NAME)?; + /// let c1 = ClientOptions::new().open(PIPE_NAME); + /// + /// let s2 = server.create(PIPE_NAME)?; + /// let c2 = ClientOptions::new().open(PIPE_NAME); + /// + /// // Too many servers! + /// let e = server.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// + /// // Still too many servers even if we specify a higher value! + /// let e = server.max_instances(100).create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// # Ok(()) } + /// ``` + /// + /// # Panics + /// + /// This function will panic if more than 254 instances are specified. If + /// you do not wish to set an instance limit, leave it unspecified. + /// + /// ```should_panic + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let builder = ServerOptions::new().max_instances(255); + /// # Ok(()) } + /// ``` + pub fn max_instances(&mut self, instances: usize) -> &mut Self { + assert!(instances < 255, "cannot specify more than 254 instances"); + self.max_instances = instances as DWORD; + self + } + + /// The number of bytes to reserve for the output buffer. + /// + /// This corresponds to specifying [`nOutBufferSize`]. + /// + /// [`nOutBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn out_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.out_buffer_size = buffer as DWORD; + self + } + + /// The number of bytes to reserve for the input buffer. + /// + /// This corresponds to specifying [`nInBufferSize`]. + /// + /// [`nInBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn in_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.in_buffer_size = buffer as DWORD; + self + } + + /// Creates the named pipe identified by `addr` for use as a server. + /// + /// This uses the [`CreateNamedPipe`] function. + /// + /// [`CreateNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Examples + /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-create"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn create(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeServer> { + // Safety: We're calling create_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.create_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Creates the named pipe identified by `addr` for use as a server. + /// + /// This is the same as [`create`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`create`] method. + /// + /// [`create`]: ServerOptions::create + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn create_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeServer> { + let addr = encode_addr(addr); + + let h = namedpipeapi::CreateNamedPipeW( + addr.as_ptr(), + self.open_mode, + self.pipe_mode, + self.max_instances, + self.out_buffer_size, + self.in_buffer_size, + self.default_timeout, + attrs as *mut _, + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeServer::from_raw_handle(h) + } +} + +/// A builder suitable for building and interacting with named pipes from the +/// client side. +/// +/// See [`ClientOptions::open`]. +#[derive(Debug, Clone)] +pub struct ClientOptions { + desired_access: DWORD, + security_qos_flags: DWORD, +} + +impl ClientOptions { + /// Creates a new named pipe builder with the default settings. + /// + /// ``` + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// // Server must be created in order for the client creation to succeed. + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// let client = ClientOptions::new().open(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> Self { + Self { + desired_access: winnt::GENERIC_READ | winnt::GENERIC_WRITE, + security_qos_flags: winbase::SECURITY_IDENTIFICATION | winbase::SECURITY_SQOS_PRESENT, + } + } + + /// If the client supports reading data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_READ`] in the call to [`CreateFile`]. + /// + /// [`GENERIC_READ`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn read(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_READ); + self + } + + /// If the created pipe supports writing data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_WRITE`] in the call to [`CreateFile`]. + /// + /// [`GENERIC_WRITE`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn write(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_WRITE); + self + } + + /// Sets qos flags which are combined with other flags and attributes in the + /// call to [`CreateFile`]. + /// + /// By default `security_qos_flags` is set to [`SECURITY_IDENTIFICATION`], + /// calling this function would override that value completely with the + /// argument specified. + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`SECURITY_IDENTIFICATION`]: crate::winapi::um::winbase::SECURITY_IDENTIFICATION + /// [Impersonation Levels]: https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + pub fn security_qos_flags(&mut self, flags: u32) -> &mut Self { + // See: https://github.com/rust-lang/rust/pull/58216 + self.security_qos_flags = flags | winbase::SECURITY_SQOS_PRESENT; + self + } + + /// Opens the named pipe identified by `addr`. + /// + /// This opens the client using [`CreateFile`] with the + /// `dwCreationDisposition` option set to `OPEN_EXISTING`. + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// There are a few errors you need to take into account when creating a + /// named pipe on the client side: + /// + /// * [`std::io::ErrorKind::NotFound`] - This indicates that the named pipe + /// does not exist. Presumably the server is not up. + /// * [`ERROR_PIPE_BUSY`] - This error is raised when the named pipe exists, + /// but the server is not currently waiting for a connection. Please see the + /// examples for how to check for this error. + /// + /// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY + /// [`winapi`]: crate::winapi + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// [Tokio Runtime]: crate::runtime::Runtime + /// + /// A connect loop that waits until a pipe becomes available looks like + /// this: + /// + /// ```no_run + /// use std::time::Duration; + /// use tokio::net::windows::named_pipe::ClientOptions; + /// use tokio::time; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = loop { + /// match ClientOptions::new().open(PIPE_NAME) { + /// Ok(client) => break client, + /// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + /// Err(e) => return Err(e), + /// } + /// + /// time::sleep(Duration::from_millis(50)).await; + /// }; + /// + /// // use the connected client. + /// # Ok(()) } + /// ``` + pub fn open(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeClient> { + // Safety: We're calling open_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.open_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Opens the named pipe identified by `addr`. + /// + /// This is the same as [`open`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`open`] method. + /// + /// [`open`]: ClientOptions::open + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn open_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeClient> { + let addr = encode_addr(addr); + + // NB: We could use a platform specialized `OpenOptions` here, but since + // we have access to winapi it ultimately doesn't hurt to use + // `CreateFile` explicitly since it allows the use of our already + // well-structured wide `addr` to pass into CreateFileW. + let h = fileapi::CreateFileW( + addr.as_ptr(), + self.desired_access, + 0, + attrs as *mut _, + fileapi::OPEN_EXISTING, + self.get_flags(), + ptr::null_mut(), + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeClient::from_raw_handle(h) + } + + fn get_flags(&self) -> u32 { + self.security_qos_flags | winbase::FILE_FLAG_OVERLAPPED + } +} + +/// The pipe mode of a named pipe. +/// +/// Set through [`ServerOptions::pipe_mode`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeMode { + /// Data is written to the pipe as a stream of bytes. The pipe does not + /// distinguish bytes written during different write operations. + /// + /// Corresponds to [`PIPE_TYPE_BYTE`][crate::winapi::um::winbase::PIPE_TYPE_BYTE]. + Byte, + /// Data is written to the pipe as a stream of messages. The pipe treats the + /// bytes written during each write operation as a message unit. Any reading + /// on a named pipe returns [`ERROR_MORE_DATA`] when a message is not read + /// completely. + /// + /// Corresponds to [`PIPE_TYPE_MESSAGE`][crate::winapi::um::winbase::PIPE_TYPE_MESSAGE]. + /// + /// [`ERROR_MORE_DATA`]: crate::winapi::shared::winerror::ERROR_MORE_DATA + Message, +} + +/// Indicates the end of a named pipe. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeEnd { + /// The named pipe refers to the client end of a named pipe instance. + /// + /// Corresponds to [`PIPE_CLIENT_END`][crate::winapi::um::winbase::PIPE_CLIENT_END]. + Client, + /// The named pipe refers to the server end of a named pipe instance. + /// + /// Corresponds to [`PIPE_SERVER_END`][crate::winapi::um::winbase::PIPE_SERVER_END]. + Server, +} + +/// Information about a named pipe. +/// +/// Constructed through [`NamedPipeServer::info`] or [`NamedPipeClient::info`]. +#[derive(Debug)] +#[non_exhaustive] +pub struct PipeInfo { + /// Indicates the mode of a named pipe. + pub mode: PipeMode, + /// Indicates the end of a named pipe. + pub end: PipeEnd, + /// The maximum number of instances that can be created for this pipe. + pub max_instances: u32, + /// The number of bytes to reserve for the output buffer. + pub out_buffer_size: u32, + /// The number of bytes to reserve for the input buffer. + pub in_buffer_size: u32, +} + +/// Encodes an address so that it is a null-terminated wide string. +fn encode_addr(addr: impl AsRef<OsStr>) -> Box<[u16]> { + let len = addr.as_ref().encode_wide().count(); + let mut vec = Vec::with_capacity(len + 1); + vec.extend(addr.as_ref().encode_wide()); + vec.push(0); + vec.into_boxed_slice() +} + +/// Internal function to get the info out of a raw named pipe. +unsafe fn named_pipe_info(handle: RawHandle) -> io::Result<PipeInfo> { + let mut flags = 0; + let mut out_buffer_size = 0; + let mut in_buffer_size = 0; + let mut max_instances = 0; + + let result = namedpipeapi::GetNamedPipeInfo( + handle, + &mut flags, + &mut out_buffer_size, + &mut in_buffer_size, + &mut max_instances, + ); + + if result == FALSE { + return Err(io::Error::last_os_error()); + } + + let mut end = PipeEnd::Client; + let mut mode = PipeMode::Byte; + + if flags & winbase::PIPE_SERVER_END != 0 { + end = PipeEnd::Server; + } + + if flags & winbase::PIPE_TYPE_MESSAGE != 0 { + mode = PipeMode::Message; + } + + Ok(PipeInfo { + end, + mode, + out_buffer_size, + in_buffer_size, + max_instances, + }) +} diff --git a/third_party/rust/tokio/src/park/either.rs b/third_party/rust/tokio/src/park/either.rs new file mode 100644 index 0000000000..ee02ec158b --- /dev/null +++ b/third_party/rust/tokio/src/park/either.rs @@ -0,0 +1,74 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + +use crate::park::{Park, Unpark}; + +use std::fmt; +use std::time::Duration; + +pub(crate) enum Either<A, B> { + A(A), + B(B), +} + +impl<A, B> Park for Either<A, B> +where + A: Park, + B: Park, +{ + type Unpark = Either<A::Unpark, B::Unpark>; + type Error = Either<A::Error, B::Error>; + + fn unpark(&self) -> Self::Unpark { + match self { + Either::A(a) => Either::A(a.unpark()), + Either::B(b) => Either::B(b.unpark()), + } + } + + fn park(&mut self) -> Result<(), Self::Error> { + match self { + Either::A(a) => a.park().map_err(Either::A), + Either::B(b) => b.park().map_err(Either::B), + } + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + match self { + Either::A(a) => a.park_timeout(duration).map_err(Either::A), + Either::B(b) => b.park_timeout(duration).map_err(Either::B), + } + } + + fn shutdown(&mut self) { + match self { + Either::A(a) => a.shutdown(), + Either::B(b) => b.shutdown(), + } + } +} + +impl<A, B> Unpark for Either<A, B> +where + A: Unpark, + B: Unpark, +{ + fn unpark(&self) { + match self { + Either::A(a) => a.unpark(), + Either::B(b) => b.unpark(), + } + } +} + +impl<A, B> fmt::Debug for Either<A, B> +where + A: fmt::Debug, + B: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Either::A(a) => a.fmt(fmt), + Either::B(b) => b.fmt(fmt), + } + } +} diff --git a/third_party/rust/tokio/src/park/mod.rs b/third_party/rust/tokio/src/park/mod.rs new file mode 100644 index 0000000000..87d04ff78e --- /dev/null +++ b/third_party/rust/tokio/src/park/mod.rs @@ -0,0 +1,117 @@ +//! Abstraction over blocking and unblocking the current thread. +//! +//! Provides an abstraction over blocking the current thread. This is similar to +//! the park / unpark constructs provided by `std` but made generic. This allows +//! embedding custom functionality to perform when the thread is blocked. +//! +//! A blocked `Park` instance is unblocked by calling `unpark` on its +//! `Unpark` handle. +//! +//! The `ParkThread` struct implements `Park` using `thread::park` to put the +//! thread to sleep. The Tokio reactor also implements park, but uses +//! `mio::Poll` to block the thread instead. +//! +//! The `Park` trait is composable. A timer implementation might decorate a +//! `Park` implementation by checking if any timeouts have elapsed after the +//! inner `Park` implementation unblocks. +//! +//! # Model +//! +//! Conceptually, each `Park` instance has an associated token, which is +//! initially not present: +//! +//! * The `park` method blocks the current thread unless or until the token is +//! available, at which point it atomically consumes the token. +//! * The `unpark` method atomically makes the token available if it wasn't +//! already. +//! +//! Some things to note: +//! +//! * If `unpark` is called before `park`, the next call to `park` will +//! **not** block the thread. +//! * **Spurious** wakeups are permitted, i.e., the `park` method may unblock +//! even if `unpark` was not called. +//! * `park_timeout` does the same as `park` but allows specifying a maximum +//! time to block the thread for. + +cfg_rt! { + pub(crate) mod either; +} + +#[cfg(any(feature = "rt", feature = "sync"))] +pub(crate) mod thread; + +use std::fmt::Debug; +use std::sync::Arc; +use std::time::Duration; + +/// Blocks the current thread. +pub(crate) trait Park { + /// Unpark handle type for the `Park` implementation. + type Unpark: Unpark; + + /// Error returned by `park`. + type Error: Debug; + + /// Gets a new `Unpark` handle associated with this `Park` instance. + fn unpark(&self) -> Self::Unpark; + + /// Blocks the current thread unless or until the token is available. + /// + /// A call to `park` does not guarantee that the thread will remain blocked + /// forever, and callers should be prepared for this possibility. This + /// function may wakeup spuriously for any reason. + /// + /// # Panics + /// + /// This function **should** not panic, but ultimately, panics are left as + /// an implementation detail. Refer to the documentation for the specific + /// `Park` implementation. + fn park(&mut self) -> Result<(), Self::Error>; + + /// Parks the current thread for at most `duration`. + /// + /// This function is the same as `park` but allows specifying a maximum time + /// to block the thread for. + /// + /// Same as `park`, there is no guarantee that the thread will remain + /// blocked for any amount of time. Spurious wakeups are permitted for any + /// reason. + /// + /// # Panics + /// + /// This function **should** not panic, but ultimately, panics are left as + /// an implementation detail. Refer to the documentation for the specific + /// `Park` implementation. + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>; + + /// Releases all resources holded by the parker for proper leak-free shutdown. + fn shutdown(&mut self); +} + +/// Unblock a thread blocked by the associated `Park` instance. +pub(crate) trait Unpark: Sync + Send + 'static { + /// Unblocks a thread that is blocked by the associated `Park` handle. + /// + /// Calling `unpark` atomically makes available the unpark token, if it is + /// not already available. + /// + /// # Panics + /// + /// This function **should** not panic, but ultimately, panics are left as + /// an implementation detail. Refer to the documentation for the specific + /// `Unpark` implementation. + fn unpark(&self); +} + +impl Unpark for Box<dyn Unpark> { + fn unpark(&self) { + (**self).unpark() + } +} + +impl Unpark for Arc<dyn Unpark> { + fn unpark(&self) { + (**self).unpark() + } +} diff --git a/third_party/rust/tokio/src/park/thread.rs b/third_party/rust/tokio/src/park/thread.rs new file mode 100644 index 0000000000..27ce202439 --- /dev/null +++ b/third_party/rust/tokio/src/park/thread.rs @@ -0,0 +1,346 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::park::{Park, Unpark}; + +use std::sync::atomic::Ordering::SeqCst; +use std::time::Duration; + +#[derive(Debug)] +pub(crate) struct ParkThread { + inner: Arc<Inner>, +} + +pub(crate) type ParkError = (); + +/// Unblocks a thread that was blocked by `ParkThread`. +#[derive(Clone, Debug)] +pub(crate) struct UnparkThread { + inner: Arc<Inner>, +} + +#[derive(Debug)] +struct Inner { + state: AtomicUsize, + mutex: Mutex<()>, + condvar: Condvar, +} + +const EMPTY: usize = 0; +const PARKED: usize = 1; +const NOTIFIED: usize = 2; + +thread_local! { + static CURRENT_PARKER: ParkThread = ParkThread::new(); +} + +// ==== impl ParkThread ==== + +impl ParkThread { + pub(crate) fn new() -> Self { + Self { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + mutex: Mutex::new(()), + condvar: Condvar::new(), + }), + } + } +} + +impl Park for ParkThread { + type Unpark = UnparkThread; + type Error = ParkError; + + fn unpark(&self) -> Self::Unpark { + let inner = self.inner.clone(); + UnparkThread { inner } + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.inner.park_timeout(duration); + Ok(()) + } + + fn shutdown(&mut self) { + self.inner.shutdown(); + } +} + +// ==== impl Inner ==== + +impl Inner { + /// Parks the current thread for at most `dur`. + fn park(&self) { + // If we were previously notified then we consume this notification and + // return quickly. + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + + // Otherwise we need to coordinate going to sleep + let mut m = self.mutex.lock(); + + match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park state; actual = {}", actual), + } + + loop { + m = self.condvar.wait(m).unwrap(); + + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + // got a notification + return; + } + + // spurious wakeup, go back to sleep + } + } + + fn park_timeout(&self, dur: Duration) { + // Like `park` above we have a fast path for an already-notified thread, + // and afterwards we start coordinating for a sleep. Return quickly. + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + + if dur == Duration::from_millis(0) { + return; + } + + let m = self.mutex.lock(); + + match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read again here, see `park`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park_timeout state; actual = {}", actual), + } + + // Wait with a timeout, and if we spuriously wake up or otherwise wake up + // from a notification, we just want to unconditionally set the state back to + // empty, either consuming a notification or un-flagging ourselves as + // parked. + let (_m, _result) = self.condvar.wait_timeout(m, dur).unwrap(); + + match self.state.swap(EMPTY, SeqCst) { + NOTIFIED => {} // got a notification, hurray! + PARKED => {} // no notification, alas + n => panic!("inconsistent park_timeout state: {}", n), + } + } + + fn unpark(&self) { + // To ensure the unparked thread will observe any writes we made before + // this call, we must perform a release operation that `park` can + // synchronize with. To do that we must write `NOTIFIED` even if `state` + // is already `NOTIFIED`. That is why this must be a swap rather than a + // compare-and-swap that returns if it reads `NOTIFIED` on failure. + match self.state.swap(NOTIFIED, SeqCst) { + EMPTY => return, // no one was waiting + NOTIFIED => return, // already unparked + PARKED => {} // gotta go wake someone up + _ => panic!("inconsistent state in unpark"), + } + + // There is a period between when the parked thread sets `state` to + // `PARKED` (or last checked `state` in the case of a spurious wake + // up) and when it actually waits on `cvar`. If we were to notify + // during this period it would be ignored and then when the parked + // thread went to sleep it would never wake up. Fortunately, it has + // `lock` locked at this stage so we can acquire `lock` to wait until + // it is ready to receive the notification. + // + // Releasing `lock` before the call to `notify_one` means that when the + // parked thread wakes it doesn't get woken only to have to wait for us + // to release `lock`. + drop(self.mutex.lock()); + + self.condvar.notify_one() + } + + fn shutdown(&self) { + self.condvar.notify_all(); + } +} + +impl Default for ParkThread { + fn default() -> Self { + Self::new() + } +} + +// ===== impl UnparkThread ===== + +impl Unpark for UnparkThread { + fn unpark(&self) { + self.inner.unpark(); + } +} + +use std::future::Future; +use std::marker::PhantomData; +use std::mem; +use std::rc::Rc; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +/// Blocks the current thread using a condition variable. +#[derive(Debug)] +pub(crate) struct CachedParkThread { + _anchor: PhantomData<Rc<()>>, +} + +impl CachedParkThread { + /// Creates a new `ParkThread` handle for the current thread. + /// + /// This type cannot be moved to other threads, so it should be created on + /// the thread that the caller intends to park. + pub(crate) fn new() -> CachedParkThread { + CachedParkThread { + _anchor: PhantomData, + } + } + + pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> { + self.with_current(|park_thread| park_thread.unpark()) + } + + /// Gets a reference to the `ParkThread` handle for this thread. + fn with_current<F, R>(&self, f: F) -> Result<R, ParkError> + where + F: FnOnce(&ParkThread) -> R, + { + CURRENT_PARKER.try_with(|inner| f(inner)).map_err(|_| ()) + } + + pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError> { + use std::task::Context; + use std::task::Poll::Ready; + + // `get_unpark()` should not return a Result + let waker = self.get_unpark()?.into_waker(); + let mut cx = Context::from_waker(&waker); + + pin!(f); + + loop { + if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { + return Ok(v); + } + + self.park()?; + } + } +} + +impl Park for CachedParkThread { + type Unpark = UnparkThread; + type Error = ParkError; + + fn unpark(&self) -> Self::Unpark { + self.get_unpark().unwrap() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.with_current(|park_thread| park_thread.inner.park())?; + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?; + Ok(()) + } + + fn shutdown(&mut self) { + let _ = self.with_current(|park_thread| park_thread.inner.shutdown()); + } +} + +impl UnparkThread { + pub(crate) fn into_waker(self) -> Waker { + unsafe { + let raw = unparker_to_raw_waker(self.inner); + Waker::from_raw(raw) + } + } +} + +impl Inner { + #[allow(clippy::wrong_self_convention)] + fn into_raw(this: Arc<Inner>) -> *const () { + Arc::into_raw(this) as *const () + } + + unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> { + Arc::from_raw(ptr as *const Inner) + } +} + +unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker { + RawWaker::new( + Inner::into_raw(unparker), + &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker), + ) +} + +unsafe fn clone(raw: *const ()) -> RawWaker { + let unparker = Inner::from_raw(raw); + + // Increment the ref count + mem::forget(unparker.clone()); + + unparker_to_raw_waker(unparker) +} + +unsafe fn drop_waker(raw: *const ()) { + let _ = Inner::from_raw(raw); +} + +unsafe fn wake(raw: *const ()) { + let unparker = Inner::from_raw(raw); + unparker.unpark(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let unparker = Inner::from_raw(raw); + unparker.unpark(); + + // We don't actually own a reference to the unparker + mem::forget(unparker); +} diff --git a/third_party/rust/tokio/src/process/kill.rs b/third_party/rust/tokio/src/process/kill.rs new file mode 100644 index 0000000000..a1f1652281 --- /dev/null +++ b/third_party/rust/tokio/src/process/kill.rs @@ -0,0 +1,13 @@ +use std::io; + +/// An interface for killing a running process. +pub(crate) trait Kill { + /// Forcefully kills the process. + fn kill(&mut self) -> io::Result<()>; +} + +impl<T: Kill> Kill for &mut T { + fn kill(&mut self) -> io::Result<()> { + (**self).kill() + } +} diff --git a/third_party/rust/tokio/src/process/mod.rs b/third_party/rust/tokio/src/process/mod.rs new file mode 100644 index 0000000000..4e1a21dd44 --- /dev/null +++ b/third_party/rust/tokio/src/process/mod.rs @@ -0,0 +1,1534 @@ +//! An implementation of asynchronous process management for Tokio. +//! +//! This module provides a [`Command`] struct that imitates the interface of the +//! [`std::process::Command`] type in the standard library, but provides asynchronous versions of +//! functions that create processes. These functions (`spawn`, `status`, `output` and their +//! variants) return "future aware" types that interoperate with Tokio. The asynchronous process +//! support is provided through signal handling on Unix and system APIs on Windows. +//! +//! [`std::process::Command`]: std::process::Command +//! +//! # Examples +//! +//! Here's an example program which will spawn `echo hello world` and then wait +//! for it complete. +//! +//! ```no_run +//! use tokio::process::Command; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // The usage is similar as with the standard library's `Command` type +//! let mut child = Command::new("echo") +//! .arg("hello") +//! .arg("world") +//! .spawn() +//! .expect("failed to spawn"); +//! +//! // Await until the command completes +//! let status = child.wait().await?; +//! println!("the command exited with: {}", status); +//! Ok(()) +//! } +//! ``` +//! +//! Next, let's take a look at an example where we not only spawn `echo hello +//! world` but we also capture its output. +//! +//! ```no_run +//! use tokio::process::Command; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Like above, but use `output` which returns a future instead of +//! // immediately returning the `Child`. +//! let output = Command::new("echo").arg("hello").arg("world") +//! .output(); +//! +//! let output = output.await?; +//! +//! assert!(output.status.success()); +//! assert_eq!(output.stdout, b"hello world\n"); +//! Ok(()) +//! } +//! ``` +//! +//! We can also read input line by line. +//! +//! ```no_run +//! use tokio::io::{BufReader, AsyncBufReadExt}; +//! use tokio::process::Command; +//! +//! use std::process::Stdio; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut cmd = Command::new("cat"); +//! +//! // Specify that we want the command's standard output piped back to us. +//! // By default, standard input/output/error will be inherited from the +//! // current process (for example, this means that standard input will +//! // come from the keyboard and standard output/error will go directly to +//! // the terminal if this process is invoked from the command line). +//! cmd.stdout(Stdio::piped()); +//! +//! let mut child = cmd.spawn() +//! .expect("failed to spawn command"); +//! +//! let stdout = child.stdout.take() +//! .expect("child did not have a handle to stdout"); +//! +//! let mut reader = BufReader::new(stdout).lines(); +//! +//! // Ensure the child process is spawned in the runtime so it can +//! // make progress on its own while we await for any output. +//! tokio::spawn(async move { +//! let status = child.wait().await +//! .expect("child process encountered an error"); +//! +//! println!("child status was: {}", status); +//! }); +//! +//! while let Some(line) = reader.next_line().await? { +//! println!("Line: {}", line); +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! Here is another example using `sort` writing into the child process +//! standard input, capturing the output of the sorted text. +//! +//! ```no_run +//! use tokio::io::AsyncWriteExt; +//! use tokio::process::Command; +//! +//! use std::process::Stdio; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut cmd = Command::new("sort"); +//! +//! // Specifying that we want pipe both the output and the input. +//! // Similarily to capturing the output, by configuring the pipe +//! // to stdin it can now be used as an asynchronous writer. +//! cmd.stdout(Stdio::piped()); +//! cmd.stdin(Stdio::piped()); +//! +//! let mut child = cmd.spawn().expect("failed to spawn command"); +//! +//! // These are the animals we want to sort +//! let animals: &[&str] = &["dog", "bird", "frog", "cat", "fish"]; +//! +//! let mut stdin = child +//! .stdin +//! .take() +//! .expect("child did not have a handle to stdin"); +//! +//! // Write our animals to the child process +//! // Note that the behavior of `sort` is to buffer _all input_ before writing any output. +//! // In the general sense, it is recommended to write to the child in a separate task as +//! // awaiting its exit (or output) to avoid deadlocks (for example, the child tries to write +//! // some output but gets stuck waiting on the parent to read from it, meanwhile the parent +//! // is stuck waiting to write its input completely before reading the output). +//! stdin +//! .write(animals.join("\n").as_bytes()) +//! .await +//! .expect("could not write to stdin"); +//! +//! // We drop the handle here which signals EOF to the child process. +//! // This tells the child process that it there is no more data on the pipe. +//! drop(stdin); +//! +//! let op = child.wait_with_output().await?; +//! +//! // Results should come back in sorted order +//! assert_eq!(op.stdout, "bird\ncat\ndog\nfish\nfrog\n".as_bytes()); +//! +//! Ok(()) +//! } +//! ``` +//! +//! With some coordination, we can also pipe the output of one command into +//! another. +//! +//! ```no_run +//! use tokio::join; +//! use tokio::process::Command; +//! use std::convert::TryInto; +//! use std::process::Stdio; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let mut echo = Command::new("echo") +//! .arg("hello world!") +//! .stdout(Stdio::piped()) +//! .spawn() +//! .expect("failed to spawn echo"); +//! +//! let tr_stdin: Stdio = echo +//! .stdout +//! .take() +//! .unwrap() +//! .try_into() +//! .expect("failed to convert to Stdio"); +//! +//! let tr = Command::new("tr") +//! .arg("a-z") +//! .arg("A-Z") +//! .stdin(tr_stdin) +//! .stdout(Stdio::piped()) +//! .spawn() +//! .expect("failed to spawn tr"); +//! +//! let (echo_result, tr_output) = join!(echo.wait(), tr.wait_with_output()); +//! +//! assert!(echo_result.unwrap().success()); +//! +//! let tr_output = tr_output.expect("failed to await tr"); +//! assert!(tr_output.status.success()); +//! +//! assert_eq!(tr_output.stdout, b"HELLO WORLD!\n"); +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Caveats +//! +//! ## Dropping/Cancellation +//! +//! Similar to the behavior to the standard library, and unlike the futures +//! paradigm of dropping-implies-cancellation, a spawned process will, by +//! default, continue to execute even after the `Child` handle has been dropped. +//! +//! The [`Command::kill_on_drop`] method can be used to modify this behavior +//! and kill the child process if the `Child` wrapper is dropped before it +//! has exited. +//! +//! ## Unix Processes +//! +//! On Unix platforms processes must be "reaped" by their parent process after +//! they have exited in order to release all OS resources. A child process which +//! has exited, but has not yet been reaped by its parent is considered a "zombie" +//! process. Such processes continue to count against limits imposed by the system, +//! and having too many zombie processes present can prevent additional processes +//! from being spawned. +//! +//! The tokio runtime will, on a best-effort basis, attempt to reap and clean up +//! any process which it has spawned. No additional guarantees are made with regards +//! how quickly or how often this procedure will take place. +//! +//! It is recommended to avoid dropping a [`Child`] process handle before it has been +//! fully `await`ed if stricter cleanup guarantees are required. +//! +//! [`Command`]: crate::process::Command +//! [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop +//! [`Child`]: crate::process::Child + +#[path = "unix/mod.rs"] +#[cfg(unix)] +mod imp; + +#[cfg(unix)] +pub(crate) mod unix { + pub(crate) use super::imp::*; +} + +#[path = "windows.rs"] +#[cfg(windows)] +mod imp; + +mod kill; + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::process::kill::Kill; + +use std::convert::TryInto; +use std::ffi::OsStr; +use std::future::Future; +use std::io; +#[cfg(unix)] +use std::os::unix::process::CommandExt; +#[cfg(windows)] +use std::os::windows::io::{AsRawHandle, RawHandle}; +#[cfg(windows)] +use std::os::windows::process::CommandExt; +use std::path::Path; +use std::pin::Pin; +use std::process::{Command as StdCommand, ExitStatus, Output, Stdio}; +use std::task::Context; +use std::task::Poll; + +/// This structure mimics the API of [`std::process::Command`] found in the standard library, but +/// replaces functions that create a process with an asynchronous variant. The main provided +/// asynchronous functions are [spawn](Command::spawn), [status](Command::status), and +/// [output](Command::output). +/// +/// `Command` uses asynchronous versions of some `std` types (for example [`Child`]). +/// +/// [`std::process::Command`]: std::process::Command +/// [`Child`]: struct@Child +#[derive(Debug)] +pub struct Command { + std: StdCommand, + kill_on_drop: bool, +} + +pub(crate) struct SpawnedChild { + child: imp::Child, + stdin: Option<imp::ChildStdio>, + stdout: Option<imp::ChildStdio>, + stderr: Option<imp::ChildStdio>, +} + +impl Command { + /// Constructs a new `Command` for launching the program at + /// path `program`, with the following default configuration: + /// + /// * No arguments to the program + /// * Inherit the current process's environment + /// * Inherit the current process's working directory + /// * Inherit stdin/stdout/stderr for `spawn` or `status`, but create pipes for `output` + /// + /// Builder methods are provided to change these defaults and + /// otherwise configure the process. + /// + /// If `program` is not an absolute path, the `PATH` will be searched in + /// an OS-defined way. + /// + /// The search path to be used may be controlled by setting the + /// `PATH` environment variable on the Command, + /// but this has some implementation limitations on Windows + /// (see issue [rust-lang/rust#37519]). + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// let command = Command::new("sh"); + /// ``` + /// + /// [rust-lang/rust#37519]: https://github.com/rust-lang/rust/issues/37519 + pub fn new<S: AsRef<OsStr>>(program: S) -> Command { + Self::from(StdCommand::new(program)) + } + + /// Cheaply convert to a `&std::process::Command` for places where the type from the standard + /// library is expected. + pub fn as_std(&self) -> &StdCommand { + &self.std + } + + /// Adds an argument to pass to the program. + /// + /// Only one argument can be passed per use. So instead of: + /// + /// ```no_run + /// tokio::process::Command::new("sh") + /// .arg("-C /path/to/repo"); + /// ``` + /// + /// usage would be: + /// + /// ```no_run + /// tokio::process::Command::new("sh") + /// .arg("-C") + /// .arg("/path/to/repo"); + /// ``` + /// + /// To pass multiple arguments see [`args`]. + /// + /// [`args`]: method@Self::args + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .arg("-l") + /// .arg("-a"); + /// ``` + pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Command { + self.std.arg(arg); + self + } + + /// Adds multiple arguments to pass to the program. + /// + /// To pass a single argument see [`arg`]. + /// + /// [`arg`]: method@Self::arg + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .args(&["-l", "-a"]); + /// ``` + pub fn args<I, S>(&mut self, args: I) -> &mut Command + where + I: IntoIterator<Item = S>, + S: AsRef<OsStr>, + { + self.std.args(args); + self + } + + /// Inserts or updates an environment variable mapping. + /// + /// Note that environment variable names are case-insensitive (but case-preserving) on Windows, + /// and case-sensitive on all other platforms. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .env("PATH", "/bin"); + /// ``` + pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Command + where + K: AsRef<OsStr>, + V: AsRef<OsStr>, + { + self.std.env(key, val); + self + } + + /// Adds or updates multiple environment variable mappings. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// use std::process::{Stdio}; + /// use std::env; + /// use std::collections::HashMap; + /// + /// let filtered_env : HashMap<String, String> = + /// env::vars().filter(|&(ref k, _)| + /// k == "TERM" || k == "TZ" || k == "LANG" || k == "PATH" + /// ).collect(); + /// + /// let command = Command::new("printenv") + /// .stdin(Stdio::null()) + /// .stdout(Stdio::inherit()) + /// .env_clear() + /// .envs(&filtered_env); + /// ``` + pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Command + where + I: IntoIterator<Item = (K, V)>, + K: AsRef<OsStr>, + V: AsRef<OsStr>, + { + self.std.envs(vars); + self + } + + /// Removes an environment variable mapping. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .env_remove("PATH"); + /// ``` + pub fn env_remove<K: AsRef<OsStr>>(&mut self, key: K) -> &mut Command { + self.std.env_remove(key); + self + } + + /// Clears the entire environment map for the child process. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .env_clear(); + /// ``` + pub fn env_clear(&mut self) -> &mut Command { + self.std.env_clear(); + self + } + + /// Sets the working directory for the child process. + /// + /// # Platform-specific behavior + /// + /// If the program path is relative (e.g., `"./script.sh"`), it's ambiguous + /// whether it should be interpreted relative to the parent's working + /// directory or relative to `current_dir`. The behavior in this case is + /// platform specific and unstable, and it's recommended to use + /// [`canonicalize`] to get an absolute program path instead. + /// + /// [`canonicalize`]: crate::fs::canonicalize() + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .current_dir("/bin"); + /// ``` + pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Command { + self.std.current_dir(dir); + self + } + + /// Sets configuration for the child process's standard input (stdin) handle. + /// + /// Defaults to [`inherit`] when used with `spawn` or `status`, and + /// defaults to [`piped`] when used with `output`. + /// + /// [`inherit`]: std::process::Stdio::inherit + /// [`piped`]: std::process::Stdio::piped + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use std::process::{Stdio}; + /// use tokio::process::Command; + /// + /// let command = Command::new("ls") + /// .stdin(Stdio::null()); + /// ``` + pub fn stdin<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Command { + self.std.stdin(cfg); + self + } + + /// Sets configuration for the child process's standard output (stdout) handle. + /// + /// Defaults to [`inherit`] when used with `spawn` or `status`, and + /// defaults to [`piped`] when used with `output`. + /// + /// [`inherit`]: std::process::Stdio::inherit + /// [`piped`]: std::process::Stdio::piped + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// use std::process::Stdio; + /// + /// let command = Command::new("ls") + /// .stdout(Stdio::null()); + /// ``` + pub fn stdout<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Command { + self.std.stdout(cfg); + self + } + + /// Sets configuration for the child process's standard error (stderr) handle. + /// + /// Defaults to [`inherit`] when used with `spawn` or `status`, and + /// defaults to [`piped`] when used with `output`. + /// + /// [`inherit`]: std::process::Stdio::inherit + /// [`piped`]: std::process::Stdio::piped + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// use std::process::{Stdio}; + /// + /// let command = Command::new("ls") + /// .stderr(Stdio::null()); + /// ``` + pub fn stderr<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Command { + self.std.stderr(cfg); + self + } + + /// Controls whether a `kill` operation should be invoked on a spawned child + /// process when its corresponding `Child` handle is dropped. + /// + /// By default, this value is assumed to be `false`, meaning the next spawned + /// process will not be killed on drop, similar to the behavior of the standard + /// library. + /// + /// # Caveats + /// + /// On Unix platforms processes must be "reaped" by their parent process after + /// they have exited in order to release all OS resources. A child process which + /// has exited, but has not yet been reaped by its parent is considered a "zombie" + /// process. Such processes continue to count against limits imposed by the system, + /// and having too many zombie processes present can prevent additional processes + /// from being spawned. + /// + /// Although issuing a `kill` signal to the child process is a synchronous + /// operation, the resulting zombie process cannot be `.await`ed inside of the + /// destructor to avoid blocking other tasks. The tokio runtime will, on a + /// best-effort basis, attempt to reap and clean up such processes in the + /// background, but makes no additional guarantees are made with regards + /// how quickly or how often this procedure will take place. + /// + /// If stronger guarantees are required, it is recommended to avoid dropping + /// a [`Child`] handle where possible, and instead utilize `child.wait().await` + /// or `child.kill().await` where possible. + pub fn kill_on_drop(&mut self, kill_on_drop: bool) -> &mut Command { + self.kill_on_drop = kill_on_drop; + self + } + + /// Sets the [process creation flags][1] to be passed to `CreateProcess`. + /// + /// These will always be ORed with `CREATE_UNICODE_ENVIRONMENT`. + /// + /// [1]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms684863(v=vs.85).aspx + #[cfg(windows)] + #[cfg_attr(docsrs, doc(cfg(windows)))] + pub fn creation_flags(&mut self, flags: u32) -> &mut Command { + self.std.creation_flags(flags); + self + } + + /// Sets the child process's user ID. This translates to a + /// `setuid` call in the child process. Failure in the `setuid` + /// call will cause the spawn to fail. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub fn uid(&mut self, id: u32) -> &mut Command { + self.std.uid(id); + self + } + + /// Similar to `uid` but sets the group ID of the child process. This has + /// the same semantics as the `uid` field. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub fn gid(&mut self, id: u32) -> &mut Command { + self.std.gid(id); + self + } + + /// Sets executable argument. + /// + /// Set the first process argument, `argv[0]`, to something other than the + /// default executable path. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub fn arg0<S>(&mut self, arg: S) -> &mut Command + where + S: AsRef<OsStr>, + { + self.std.arg0(arg); + self + } + + /// Schedules a closure to be run just before the `exec` function is + /// invoked. + /// + /// The closure is allowed to return an I/O error whose OS error code will + /// be communicated back to the parent and returned as an error from when + /// the spawn was requested. + /// + /// Multiple closures can be registered and they will be called in order of + /// their registration. If a closure returns `Err` then no further closures + /// will be called and the spawn operation will immediately return with a + /// failure. + /// + /// # Safety + /// + /// This closure will be run in the context of the child process after a + /// `fork`. This primarily means that any modifications made to memory on + /// behalf of this closure will **not** be visible to the parent process. + /// This is often a very constrained environment where normal operations + /// like `malloc` or acquiring a mutex are not guaranteed to work (due to + /// other threads perhaps still running when the `fork` was run). + /// + /// This also means that all resources such as file descriptors and + /// memory-mapped regions got duplicated. It is your responsibility to make + /// sure that the closure does not violate library invariants by making + /// invalid use of these duplicates. + /// + /// When this closure is run, aspects such as the stdio file descriptors and + /// working directory have successfully been changed, so output to these + /// locations may not appear where intended. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Command + where + F: FnMut() -> io::Result<()> + Send + Sync + 'static, + { + self.std.pre_exec(f); + self + } + + /// Executes the command as a child process, returning a handle to it. + /// + /// By default, stdin, stdout and stderr are inherited from the parent. + /// + /// This method will spawn the child process synchronously and return a + /// handle to a future-aware child process. The `Child` returned implements + /// `Future` itself to acquire the `ExitStatus` of the child, and otherwise + /// the `Child` has methods to acquire handles to the stdin, stdout, and + /// stderr streams. + /// + /// All I/O this child does will be associated with the current default + /// event loop. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// async fn run_ls() -> std::process::ExitStatus { + /// Command::new("ls") + /// .spawn() + /// .expect("ls command failed to start") + /// .wait() + /// .await + /// .expect("ls command failed to run") + /// } + /// ``` + /// + /// # Caveats + /// + /// ## Dropping/Cancellation + /// + /// Similar to the behavior to the standard library, and unlike the futures + /// paradigm of dropping-implies-cancellation, a spawned process will, by + /// default, continue to execute even after the `Child` handle has been dropped. + /// + /// The [`Command::kill_on_drop`] method can be used to modify this behavior + /// and kill the child process if the `Child` wrapper is dropped before it + /// has exited. + /// + /// ## Unix Processes + /// + /// On Unix platforms processes must be "reaped" by their parent process after + /// they have exited in order to release all OS resources. A child process which + /// has exited, but has not yet been reaped by its parent is considered a "zombie" + /// process. Such processes continue to count against limits imposed by the system, + /// and having too many zombie processes present can prevent additional processes + /// from being spawned. + /// + /// The tokio runtime will, on a best-effort basis, attempt to reap and clean up + /// any process which it has spawned. No additional guarantees are made with regards + /// how quickly or how often this procedure will take place. + /// + /// It is recommended to avoid dropping a [`Child`] process handle before it has been + /// fully `await`ed if stricter cleanup guarantees are required. + /// + /// [`Command`]: crate::process::Command + /// [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop + /// [`Child`]: crate::process::Child + /// + /// # Errors + /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). + pub fn spawn(&mut self) -> io::Result<Child> { + imp::spawn_child(&mut self.std).map(|spawned_child| Child { + child: FusedChild::Child(ChildDropGuard { + inner: spawned_child.child, + kill_on_drop: self.kill_on_drop, + }), + stdin: spawned_child.stdin.map(|inner| ChildStdin { inner }), + stdout: spawned_child.stdout.map(|inner| ChildStdout { inner }), + stderr: spawned_child.stderr.map(|inner| ChildStderr { inner }), + }) + } + + /// Executes the command as a child process, waiting for it to finish and + /// collecting its exit status. + /// + /// By default, stdin, stdout and stderr are inherited from the parent. + /// If any input/output handles are set to a pipe then they will be immediately + /// closed after the child is spawned. + /// + /// All I/O this child does will be associated with the current default + /// event loop. + /// + /// The destructor of the future returned by this function will kill + /// the child if [`kill_on_drop`] is set to true. + /// + /// [`kill_on_drop`]: fn@Self::kill_on_drop + /// + /// # Errors + /// + /// This future will return an error if the child process cannot be spawned + /// or if there is an error while awaiting its status. + /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). + /// + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// async fn run_ls() -> std::process::ExitStatus { + /// Command::new("ls") + /// .status() + /// .await + /// .expect("ls command failed to run") + /// } + /// ``` + pub fn status(&mut self) -> impl Future<Output = io::Result<ExitStatus>> { + let child = self.spawn(); + + async { + let mut child = child?; + + // Ensure we close any stdio handles so we can't deadlock + // waiting on the child which may be waiting to read/write + // to a pipe we're holding. + child.stdin.take(); + child.stdout.take(); + child.stderr.take(); + + child.wait().await + } + } + + /// Executes the command as a child process, waiting for it to finish and + /// collecting all of its output. + /// + /// > **Note**: this method, unlike the standard library, will + /// > unconditionally configure the stdout/stderr handles to be pipes, even + /// > if they have been previously configured. If this is not desired then + /// > the `spawn` method should be used in combination with the + /// > `wait_with_output` method on child. + /// + /// This method will return a future representing the collection of the + /// child process's stdout/stderr. It will resolve to + /// the `Output` type in the standard library, containing `stdout` and + /// `stderr` as `Vec<u8>` along with an `ExitStatus` representing how the + /// process exited. + /// + /// All I/O this child does will be associated with the current default + /// event loop. + /// + /// The destructor of the future returned by this function will kill + /// the child if [`kill_on_drop`] is set to true. + /// + /// [`kill_on_drop`]: fn@Self::kill_on_drop + /// + /// # Errors + /// + /// This future will return an error if the child process cannot be spawned + /// or if there is an error while awaiting its status. + /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). + /// # Examples + /// + /// Basic usage: + /// + /// ```no_run + /// use tokio::process::Command; + /// + /// async fn run_ls() { + /// let output: std::process::Output = Command::new("ls") + /// .output() + /// .await + /// .expect("ls command failed to run"); + /// println!("stderr of ls: {:?}", output.stderr); + /// } + /// ``` + pub fn output(&mut self) -> impl Future<Output = io::Result<Output>> { + self.std.stdout(Stdio::piped()); + self.std.stderr(Stdio::piped()); + + let child = self.spawn(); + + async { child?.wait_with_output().await } + } +} + +impl From<StdCommand> for Command { + fn from(std: StdCommand) -> Command { + Command { + std, + kill_on_drop: false, + } + } +} + +/// A drop guard which can ensure the child process is killed on drop if specified. +#[derive(Debug)] +struct ChildDropGuard<T: Kill> { + inner: T, + kill_on_drop: bool, +} + +impl<T: Kill> Kill for ChildDropGuard<T> { + fn kill(&mut self) -> io::Result<()> { + let ret = self.inner.kill(); + + if ret.is_ok() { + self.kill_on_drop = false; + } + + ret + } +} + +impl<T: Kill> Drop for ChildDropGuard<T> { + fn drop(&mut self) { + if self.kill_on_drop { + drop(self.kill()); + } + } +} + +impl<T, E, F> Future for ChildDropGuard<F> +where + F: Future<Output = Result<T, E>> + Kill + Unpin, +{ + type Output = Result<T, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + let ret = Pin::new(&mut self.inner).poll(cx); + + if let Poll::Ready(Ok(_)) = ret { + // Avoid the overhead of trying to kill a reaped process + self.kill_on_drop = false; + } + + if ret.is_ready() { + coop.made_progress(); + } + + ret + } +} + +/// Keeps track of the exit status of a child process without worrying about +/// polling the underlying futures even after they have completed. +#[derive(Debug)] +enum FusedChild { + Child(ChildDropGuard<imp::Child>), + Done(ExitStatus), +} + +/// Representation of a child process spawned onto an event loop. +/// +/// # Caveats +/// Similar to the behavior to the standard library, and unlike the futures +/// paradigm of dropping-implies-cancellation, a spawned process will, by +/// default, continue to execute even after the `Child` handle has been dropped. +/// +/// The `Command::kill_on_drop` method can be used to modify this behavior +/// and kill the child process if the `Child` wrapper is dropped before it +/// has exited. +#[derive(Debug)] +pub struct Child { + child: FusedChild, + + /// The handle for writing to the child's standard input (stdin), if it has + /// been captured. To avoid partially moving the `child` and thus blocking + /// yourself from calling functions on `child` while using `stdin`, you might + /// find it helpful to do: + /// + /// ```no_run + /// # let mut child = tokio::process::Command::new("echo").spawn().unwrap(); + /// let stdin = child.stdin.take().unwrap(); + /// ``` + pub stdin: Option<ChildStdin>, + + /// The handle for reading from the child's standard output (stdout), if it + /// has been captured. You might find it helpful to do + /// + /// ```no_run + /// # let mut child = tokio::process::Command::new("echo").spawn().unwrap(); + /// let stdout = child.stdout.take().unwrap(); + /// ``` + /// + /// to avoid partially moving the `child` and thus blocking yourself from calling + /// functions on `child` while using `stdout`. + pub stdout: Option<ChildStdout>, + + /// The handle for reading from the child's standard error (stderr), if it + /// has been captured. You might find it helpful to do + /// + /// ```no_run + /// # let mut child = tokio::process::Command::new("echo").spawn().unwrap(); + /// let stderr = child.stderr.take().unwrap(); + /// ``` + /// + /// to avoid partially moving the `child` and thus blocking yourself from calling + /// functions on `child` while using `stderr`. + pub stderr: Option<ChildStderr>, +} + +impl Child { + /// Returns the OS-assigned process identifier associated with this child + /// while it is still running. + /// + /// Once the child has been polled to completion this will return `None`. + /// This is done to avoid confusion on platforms like Unix where the OS + /// identifier could be reused once the process has completed. + pub fn id(&self) -> Option<u32> { + match &self.child { + FusedChild::Child(child) => Some(child.inner.id()), + FusedChild::Done(_) => None, + } + } + + /// Extracts the raw handle of the process associated with this child while + /// it is still running. Returns `None` if the child has exited. + #[cfg(windows)] + pub fn raw_handle(&self) -> Option<RawHandle> { + match &self.child { + FusedChild::Child(c) => Some(c.inner.as_raw_handle()), + FusedChild::Done(_) => None, + } + } + + /// Attempts to force the child to exit, but does not wait for the request + /// to take effect. + /// + /// On Unix platforms, this is the equivalent to sending a SIGKILL. Note + /// that on Unix platforms it is possible for a zombie process to remain + /// after a kill is sent; to avoid this, the caller should ensure that either + /// `child.wait().await` or `child.try_wait()` is invoked successfully. + pub fn start_kill(&mut self) -> io::Result<()> { + match &mut self.child { + FusedChild::Child(child) => child.kill(), + FusedChild::Done(_) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument: can't kill an exited process", + )), + } + } + + /// Forces the child to exit. + /// + /// This is equivalent to sending a SIGKILL on unix platforms. + /// + /// If the child has to be killed remotely, it is possible to do it using + /// a combination of the select! macro and a oneshot channel. In the following + /// example, the child will run until completion unless a message is sent on + /// the oneshot channel. If that happens, the child is killed immediately + /// using the `.kill()` method. + /// + /// ```no_run + /// use tokio::process::Command; + /// use tokio::sync::oneshot::channel; + /// + /// #[tokio::main] + /// async fn main() { + /// let (send, recv) = channel::<()>(); + /// let mut child = Command::new("sleep").arg("1").spawn().unwrap(); + /// tokio::spawn(async move { send.send(()) }); + /// tokio::select! { + /// _ = child.wait() => {} + /// _ = recv => child.kill().await.expect("kill failed"), + /// } + /// } + /// ``` + pub async fn kill(&mut self) -> io::Result<()> { + self.start_kill()?; + self.wait().await?; + Ok(()) + } + + /// Waits for the child to exit completely, returning the status that it + /// exited with. This function will continue to have the same return value + /// after it has been called at least once. + /// + /// The stdin handle to the child process, if any, will be closed + /// before waiting. This helps avoid deadlock: it ensures that the + /// child does not block waiting for input from the parent, while + /// the parent waits for the child to exit. + /// + /// If the caller wishes to explicitly control when the child's stdin + /// handle is closed, they may `.take()` it before calling `.wait()`: + /// + /// ``` + /// # #[cfg(not(unix))]fn main(){} + /// # #[cfg(unix)] + /// use tokio::io::AsyncWriteExt; + /// # #[cfg(unix)] + /// use tokio::process::Command; + /// # #[cfg(unix)] + /// use std::process::Stdio; + /// + /// # #[cfg(unix)] + /// #[tokio::main] + /// async fn main() { + /// let mut child = Command::new("cat") + /// .stdin(Stdio::piped()) + /// .spawn() + /// .unwrap(); + /// + /// let mut stdin = child.stdin.take().unwrap(); + /// tokio::spawn(async move { + /// // do something with stdin here... + /// stdin.write_all(b"hello world\n").await.unwrap(); + /// + /// // then drop when finished + /// drop(stdin); + /// }); + /// + /// // wait for the process to complete + /// let _ = child.wait().await; + /// } + /// ``` + pub async fn wait(&mut self) -> io::Result<ExitStatus> { + // Ensure stdin is closed so the child isn't stuck waiting on + // input while the parent is waiting for it to exit. + drop(self.stdin.take()); + + match &mut self.child { + FusedChild::Done(exit) => Ok(*exit), + FusedChild::Child(child) => { + let ret = child.await; + + if let Ok(exit) = ret { + self.child = FusedChild::Done(exit); + } + + ret + } + } + } + + /// Attempts to collect the exit status of the child if it has already + /// exited. + /// + /// This function will not block the calling thread and will only + /// check to see if the child process has exited or not. If the child has + /// exited then on Unix the process ID is reaped. This function is + /// guaranteed to repeatedly return a successful exit status so long as the + /// child has already exited. + /// + /// If the child has exited, then `Ok(Some(status))` is returned. If the + /// exit status is not available at this time then `Ok(None)` is returned. + /// If an error occurs, then that error is returned. + /// + /// Note that unlike `wait`, this function will not attempt to drop stdin, + /// nor will it wake the current task if the child exits. + pub fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + match &mut self.child { + FusedChild::Done(exit) => Ok(Some(*exit)), + FusedChild::Child(guard) => { + let ret = guard.inner.try_wait(); + + if let Ok(Some(exit)) = ret { + // Avoid the overhead of trying to kill a reaped process + guard.kill_on_drop = false; + self.child = FusedChild::Done(exit); + } + + ret + } + } + } + + /// Returns a future that will resolve to an `Output`, containing the exit + /// status, stdout, and stderr of the child process. + /// + /// The returned future will simultaneously waits for the child to exit and + /// collect all remaining output on the stdout/stderr handles, returning an + /// `Output` instance. + /// + /// The stdin handle to the child process, if any, will be closed before + /// waiting. This helps avoid deadlock: it ensures that the child does not + /// block waiting for input from the parent, while the parent waits for the + /// child to exit. + /// + /// By default, stdin, stdout and stderr are inherited from the parent. In + /// order to capture the output into this `Output` it is necessary to create + /// new pipes between parent and child. Use `stdout(Stdio::piped())` or + /// `stderr(Stdio::piped())`, respectively, when creating a `Command`. + pub async fn wait_with_output(mut self) -> io::Result<Output> { + use crate::future::try_join3; + + async fn read_to_end<A: AsyncRead + Unpin>(io: &mut Option<A>) -> io::Result<Vec<u8>> { + let mut vec = Vec::new(); + if let Some(io) = io.as_mut() { + crate::io::util::read_to_end(io, &mut vec).await?; + } + Ok(vec) + } + + let mut stdout_pipe = self.stdout.take(); + let mut stderr_pipe = self.stderr.take(); + + let stdout_fut = read_to_end(&mut stdout_pipe); + let stderr_fut = read_to_end(&mut stderr_pipe); + + let (status, stdout, stderr) = try_join3(self.wait(), stdout_fut, stderr_fut).await?; + + // Drop happens after `try_join` due to <https://github.com/tokio-rs/tokio/issues/4309> + drop(stdout_pipe); + drop(stderr_pipe); + + Ok(Output { + status, + stdout, + stderr, + }) + } +} + +/// The standard input stream for spawned children. +/// +/// This type implements the `AsyncWrite` trait to pass data to the stdin handle of +/// handle of a child process asynchronously. +#[derive(Debug)] +pub struct ChildStdin { + inner: imp::ChildStdio, +} + +/// The standard output stream for spawned children. +/// +/// This type implements the `AsyncRead` trait to read data from the stdout +/// handle of a child process asynchronously. +#[derive(Debug)] +pub struct ChildStdout { + inner: imp::ChildStdio, +} + +/// The standard error stream for spawned children. +/// +/// This type implements the `AsyncRead` trait to read data from the stderr +/// handle of a child process asynchronously. +#[derive(Debug)] +pub struct ChildStderr { + inner: imp::ChildStdio, +} + +impl ChildStdin { + /// Creates an asynchronous `ChildStdin` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStdin) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } +} + +impl ChildStdout { + /// Creates an asynchronous `ChildStderr` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStdout) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } +} + +impl ChildStderr { + /// Creates an asynchronous `ChildStderr` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStderr) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } +} + +impl AsyncWrite for ChildStdin { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +impl AsyncRead for ChildStdout { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // Safety: pipes support reading into uninitialized memory + unsafe { self.inner.poll_read(cx, buf) } + } +} + +impl AsyncRead for ChildStderr { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // Safety: pipes support reading into uninitialized memory + unsafe { self.inner.poll_read(cx, buf) } + } +} + +impl TryInto<Stdio> for ChildStdin { + type Error = io::Error; + + fn try_into(self) -> Result<Stdio, Self::Error> { + imp::convert_to_stdio(self.inner) + } +} + +impl TryInto<Stdio> for ChildStdout { + type Error = io::Error; + + fn try_into(self) -> Result<Stdio, Self::Error> { + imp::convert_to_stdio(self.inner) + } +} + +impl TryInto<Stdio> for ChildStderr { + type Error = io::Error; + + fn try_into(self) -> Result<Stdio, Self::Error> { + imp::convert_to_stdio(self.inner) + } +} + +#[cfg(unix)] +mod sys { + use std::os::unix::io::{AsRawFd, RawFd}; + + use super::{ChildStderr, ChildStdin, ChildStdout}; + + impl AsRawFd for ChildStdin { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } + } + + impl AsRawFd for ChildStdout { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } + } + + impl AsRawFd for ChildStderr { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } + } +} + +#[cfg(windows)] +mod sys { + use std::os::windows::io::{AsRawHandle, RawHandle}; + + use super::{ChildStderr, ChildStdin, ChildStdout}; + + impl AsRawHandle for ChildStdin { + fn as_raw_handle(&self) -> RawHandle { + self.inner.as_raw_handle() + } + } + + impl AsRawHandle for ChildStdout { + fn as_raw_handle(&self) -> RawHandle { + self.inner.as_raw_handle() + } + } + + impl AsRawHandle for ChildStderr { + fn as_raw_handle(&self) -> RawHandle { + self.inner.as_raw_handle() + } + } +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::kill::Kill; + use super::ChildDropGuard; + + use futures::future::FutureExt; + use std::future::Future; + use std::io; + use std::pin::Pin; + use std::task::{Context, Poll}; + + struct Mock { + num_kills: usize, + num_polls: usize, + poll_result: Poll<Result<(), ()>>, + } + + impl Mock { + fn new() -> Self { + Self::with_result(Poll::Pending) + } + + fn with_result(result: Poll<Result<(), ()>>) -> Self { + Self { + num_kills: 0, + num_polls: 0, + poll_result: result, + } + } + } + + impl Kill for Mock { + fn kill(&mut self) -> io::Result<()> { + self.num_kills += 1; + Ok(()) + } + } + + impl Future for Mock { + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + let inner = Pin::get_mut(self); + inner.num_polls += 1; + inner.poll_result + } + } + + #[test] + fn kills_on_drop_if_specified() { + let mut mock = Mock::new(); + + { + let guard = ChildDropGuard { + inner: &mut mock, + kill_on_drop: true, + }; + drop(guard); + } + + assert_eq!(1, mock.num_kills); + assert_eq!(0, mock.num_polls); + } + + #[test] + fn no_kill_on_drop_by_default() { + let mut mock = Mock::new(); + + { + let guard = ChildDropGuard { + inner: &mut mock, + kill_on_drop: false, + }; + drop(guard); + } + + assert_eq!(0, mock.num_kills); + assert_eq!(0, mock.num_polls); + } + + #[test] + fn no_kill_if_already_killed() { + let mut mock = Mock::new(); + + { + let mut guard = ChildDropGuard { + inner: &mut mock, + kill_on_drop: true, + }; + let _ = guard.kill(); + drop(guard); + } + + assert_eq!(1, mock.num_kills); + assert_eq!(0, mock.num_polls); + } + + #[test] + fn no_kill_if_reaped() { + let mut mock_pending = Mock::with_result(Poll::Pending); + let mut mock_reaped = Mock::with_result(Poll::Ready(Ok(()))); + let mut mock_err = Mock::with_result(Poll::Ready(Err(()))); + + let waker = futures::task::noop_waker(); + let mut context = Context::from_waker(&waker); + { + let mut guard = ChildDropGuard { + inner: &mut mock_pending, + kill_on_drop: true, + }; + let _ = guard.poll_unpin(&mut context); + + let mut guard = ChildDropGuard { + inner: &mut mock_reaped, + kill_on_drop: true, + }; + let _ = guard.poll_unpin(&mut context); + + let mut guard = ChildDropGuard { + inner: &mut mock_err, + kill_on_drop: true, + }; + let _ = guard.poll_unpin(&mut context); + } + + assert_eq!(1, mock_pending.num_kills); + assert_eq!(1, mock_pending.num_polls); + + assert_eq!(0, mock_reaped.num_kills); + assert_eq!(1, mock_reaped.num_polls); + + assert_eq!(1, mock_err.num_kills); + assert_eq!(1, mock_err.num_polls); + } +} diff --git a/third_party/rust/tokio/src/process/unix/driver.rs b/third_party/rust/tokio/src/process/unix/driver.rs new file mode 100644 index 0000000000..84dc8fbd02 --- /dev/null +++ b/third_party/rust/tokio/src/process/unix/driver.rs @@ -0,0 +1,58 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Process driver. + +use crate::park::Park; +use crate::process::unix::GlobalOrphanQueue; +use crate::signal::unix::driver::{Driver as SignalDriver, Handle as SignalHandle}; + +use std::io; +use std::time::Duration; + +/// Responsible for cleaning up orphaned child processes on Unix platforms. +#[derive(Debug)] +pub(crate) struct Driver { + park: SignalDriver, + signal_handle: SignalHandle, +} + +// ===== impl Driver ===== + +impl Driver { + /// Creates a new signal `Driver` instance that delegates wakeups to `park`. + pub(crate) fn new(park: SignalDriver) -> Self { + let signal_handle = park.handle(); + + Self { + park, + signal_handle, + } + } +} + +// ===== impl Park for Driver ===== + +impl Park for Driver { + type Unpark = <SignalDriver as Park>::Unpark; + type Error = io::Error; + + fn unpark(&self) -> Self::Unpark { + self.park.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park.park()?; + GlobalOrphanQueue::reap_orphans(&self.signal_handle); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park.park_timeout(duration)?; + GlobalOrphanQueue::reap_orphans(&self.signal_handle); + Ok(()) + } + + fn shutdown(&mut self) { + self.park.shutdown() + } +} diff --git a/third_party/rust/tokio/src/process/unix/mod.rs b/third_party/rust/tokio/src/process/unix/mod.rs new file mode 100644 index 0000000000..576fe6cb47 --- /dev/null +++ b/third_party/rust/tokio/src/process/unix/mod.rs @@ -0,0 +1,250 @@ +//! Unix handling of child processes. +//! +//! Right now the only "fancy" thing about this is how we implement the +//! `Future` implementation on `Child` to get the exit status. Unix offers +//! no way to register a child with epoll, and the only real way to get a +//! notification when a process exits is the SIGCHLD signal. +//! +//! Signal handling in general is *super* hairy and complicated, and it's even +//! more complicated here with the fact that signals are coalesced, so we may +//! not get a SIGCHLD-per-child. +//! +//! Our best approximation here is to check *all spawned processes* for all +//! SIGCHLD signals received. To do that we create a `Signal`, implemented in +//! the `tokio-net` crate, which is a stream over signals being received. +//! +//! Later when we poll the process's exit status we simply check to see if a +//! SIGCHLD has happened since we last checked, and while that returns "yes" we +//! keep trying. +//! +//! Note that this means that this isn't really scalable, but then again +//! processes in general aren't scalable (e.g. millions) so it shouldn't be that +//! bad in theory... + +pub(crate) mod driver; + +pub(crate) mod orphan; +use orphan::{OrphanQueue, OrphanQueueImpl, Wait}; + +mod reap; +use reap::Reaper; + +use crate::io::PollEvented; +use crate::process::kill::Kill; +use crate::process::SpawnedChild; +use crate::signal::unix::driver::Handle as SignalHandle; +use crate::signal::unix::{signal, Signal, SignalKind}; + +use mio::event::Source; +use mio::unix::SourceFd; +use once_cell::sync::Lazy; +use std::fmt; +use std::fs::File; +use std::future::Future; +use std::io; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::pin::Pin; +use std::process::{Child as StdChild, ExitStatus, Stdio}; +use std::task::Context; +use std::task::Poll; + +impl Wait for StdChild { + fn id(&self) -> u32 { + self.id() + } + + fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + self.try_wait() + } +} + +impl Kill for StdChild { + fn kill(&mut self) -> io::Result<()> { + self.kill() + } +} + +static ORPHAN_QUEUE: Lazy<OrphanQueueImpl<StdChild>> = Lazy::new(OrphanQueueImpl::new); + +pub(crate) struct GlobalOrphanQueue; + +impl fmt::Debug for GlobalOrphanQueue { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + ORPHAN_QUEUE.fmt(fmt) + } +} + +impl GlobalOrphanQueue { + fn reap_orphans(handle: &SignalHandle) { + ORPHAN_QUEUE.reap_orphans(handle) + } +} + +impl OrphanQueue<StdChild> for GlobalOrphanQueue { + fn push_orphan(&self, orphan: StdChild) { + ORPHAN_QUEUE.push_orphan(orphan) + } +} + +#[must_use = "futures do nothing unless polled"] +pub(crate) struct Child { + inner: Reaper<StdChild, GlobalOrphanQueue, Signal>, +} + +impl fmt::Debug for Child { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Child") + .field("pid", &self.inner.id()) + .finish() + } +} + +pub(crate) fn spawn_child(cmd: &mut std::process::Command) -> io::Result<SpawnedChild> { + let mut child = cmd.spawn()?; + let stdin = child.stdin.take().map(stdio).transpose()?; + let stdout = child.stdout.take().map(stdio).transpose()?; + let stderr = child.stderr.take().map(stdio).transpose()?; + + let signal = signal(SignalKind::child())?; + + Ok(SpawnedChild { + child: Child { + inner: Reaper::new(child, GlobalOrphanQueue, signal), + }, + stdin, + stdout, + stderr, + }) +} + +impl Child { + pub(crate) fn id(&self) -> u32 { + self.inner.id() + } + + pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + self.inner.inner_mut().try_wait() + } +} + +impl Kill for Child { + fn kill(&mut self) -> io::Result<()> { + self.inner.kill() + } +} + +impl Future for Child { + type Output = io::Result<ExitStatus>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + Pin::new(&mut self.inner).poll(cx) + } +} + +#[derive(Debug)] +pub(crate) struct Pipe { + // Actually a pipe and not a File. However, we are reusing `File` to get + // close on drop. This is a similar trick as `mio`. + fd: File, +} + +impl<T: IntoRawFd> From<T> for Pipe { + fn from(fd: T) -> Self { + let fd = unsafe { File::from_raw_fd(fd.into_raw_fd()) }; + Self { fd } + } +} + +impl<'a> io::Read for &'a Pipe { + fn read(&mut self, bytes: &mut [u8]) -> io::Result<usize> { + (&self.fd).read(bytes) + } +} + +impl<'a> io::Write for &'a Pipe { + fn write(&mut self, bytes: &[u8]) -> io::Result<usize> { + (&self.fd).write(bytes) + } + + fn flush(&mut self) -> io::Result<()> { + (&self.fd).flush() + } +} + +impl AsRawFd for Pipe { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +pub(crate) fn convert_to_stdio(io: PollEvented<Pipe>) -> io::Result<Stdio> { + let mut fd = io.into_inner()?.fd; + + // Ensure that the fd to be inherited is set to *blocking* mode, as this + // is the default that virtually all programs expect to have. Those + // programs that know how to work with nonblocking stdio will know how to + // change it to nonblocking mode. + set_nonblocking(&mut fd, false)?; + + Ok(Stdio::from(fd)) +} + +impl Source for Pipe { + fn register( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> io::Result<()> { + SourceFd(&self.as_raw_fd()).register(registry, token, interest) + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> io::Result<()> { + SourceFd(&self.as_raw_fd()).reregister(registry, token, interest) + } + + fn deregister(&mut self, registry: &mio::Registry) -> io::Result<()> { + SourceFd(&self.as_raw_fd()).deregister(registry) + } +} + +pub(crate) type ChildStdio = PollEvented<Pipe>; + +fn set_nonblocking<T: AsRawFd>(fd: &mut T, nonblocking: bool) -> io::Result<()> { + unsafe { + let fd = fd.as_raw_fd(); + let previous = libc::fcntl(fd, libc::F_GETFL); + if previous == -1 { + return Err(io::Error::last_os_error()); + } + + let new = if nonblocking { + previous | libc::O_NONBLOCK + } else { + previous & !libc::O_NONBLOCK + }; + + let r = libc::fcntl(fd, libc::F_SETFL, new); + if r == -1 { + return Err(io::Error::last_os_error()); + } + } + + Ok(()) +} + +pub(super) fn stdio<T>(io: T) -> io::Result<PollEvented<Pipe>> +where + T: IntoRawFd, +{ + // Set the fd to nonblocking before we pass it to the event loop + let mut pipe = Pipe::from(io); + set_nonblocking(&mut pipe, true)?; + + PollEvented::new(pipe) +} diff --git a/third_party/rust/tokio/src/process/unix/orphan.rs b/third_party/rust/tokio/src/process/unix/orphan.rs new file mode 100644 index 0000000000..0e52530c37 --- /dev/null +++ b/third_party/rust/tokio/src/process/unix/orphan.rs @@ -0,0 +1,321 @@ +use crate::loom::sync::{Mutex, MutexGuard}; +use crate::signal::unix::driver::Handle as SignalHandle; +use crate::signal::unix::{signal_with_handle, SignalKind}; +use crate::sync::watch; +use std::io; +use std::process::ExitStatus; + +/// An interface for waiting on a process to exit. +pub(crate) trait Wait { + /// Get the identifier for this process or diagnostics. + fn id(&self) -> u32; + /// Try waiting for a process to exit in a non-blocking manner. + fn try_wait(&mut self) -> io::Result<Option<ExitStatus>>; +} + +impl<T: Wait> Wait for &mut T { + fn id(&self) -> u32 { + (**self).id() + } + + fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + (**self).try_wait() + } +} + +/// An interface for queueing up an orphaned process so that it can be reaped. +pub(crate) trait OrphanQueue<T> { + /// Adds an orphan to the queue. + fn push_orphan(&self, orphan: T); +} + +impl<T, O: OrphanQueue<T>> OrphanQueue<T> for &O { + fn push_orphan(&self, orphan: T) { + (**self).push_orphan(orphan); + } +} + +/// An implementation of `OrphanQueue`. +#[derive(Debug)] +pub(crate) struct OrphanQueueImpl<T> { + sigchild: Mutex<Option<watch::Receiver<()>>>, + queue: Mutex<Vec<T>>, +} + +impl<T> OrphanQueueImpl<T> { + pub(crate) fn new() -> Self { + Self { + sigchild: Mutex::new(None), + queue: Mutex::new(Vec::new()), + } + } + + #[cfg(test)] + fn len(&self) -> usize { + self.queue.lock().len() + } + + pub(crate) fn push_orphan(&self, orphan: T) + where + T: Wait, + { + self.queue.lock().push(orphan) + } + + /// Attempts to reap every process in the queue, ignoring any errors and + /// enqueueing any orphans which have not yet exited. + pub(crate) fn reap_orphans(&self, handle: &SignalHandle) + where + T: Wait, + { + // If someone else is holding the lock, they will be responsible for draining + // the queue as necessary, so we can safely bail if that happens + if let Some(mut sigchild_guard) = self.sigchild.try_lock() { + match &mut *sigchild_guard { + Some(sigchild) => { + if sigchild.try_has_changed().and_then(Result::ok).is_some() { + drain_orphan_queue(self.queue.lock()); + } + } + None => { + let queue = self.queue.lock(); + + // Be lazy and only initialize the SIGCHLD listener if there + // are any orphaned processes in the queue. + if !queue.is_empty() { + // An errors shouldn't really happen here, but if it does it + // means that the signal driver isn't running, in + // which case there isn't anything we can + // register/initialize here, so we can try again later + if let Ok(sigchild) = signal_with_handle(SignalKind::child(), handle) { + *sigchild_guard = Some(sigchild); + drain_orphan_queue(queue); + } + } + } + } + } + } +} + +fn drain_orphan_queue<T>(mut queue: MutexGuard<'_, Vec<T>>) +where + T: Wait, +{ + for i in (0..queue.len()).rev() { + match queue[i].try_wait() { + Ok(None) => {} + Ok(Some(_)) | Err(_) => { + // The stdlib handles interruption errors (EINTR) when polling a child process. + // All other errors represent invalid inputs or pids that have already been + // reaped, so we can drop the orphan in case an error is raised. + queue.swap_remove(i); + } + } + } + + drop(queue); +} + +#[cfg(all(test, not(loom)))] +pub(crate) mod test { + use super::*; + use crate::io::driver::Driver as IoDriver; + use crate::signal::unix::driver::{Driver as SignalDriver, Handle as SignalHandle}; + use crate::sync::watch; + use std::cell::{Cell, RefCell}; + use std::io; + use std::os::unix::process::ExitStatusExt; + use std::process::ExitStatus; + use std::rc::Rc; + + pub(crate) struct MockQueue<W> { + pub(crate) all_enqueued: RefCell<Vec<W>>, + } + + impl<W> MockQueue<W> { + pub(crate) fn new() -> Self { + Self { + all_enqueued: RefCell::new(Vec::new()), + } + } + } + + impl<W> OrphanQueue<W> for MockQueue<W> { + fn push_orphan(&self, orphan: W) { + self.all_enqueued.borrow_mut().push(orphan); + } + } + + struct MockWait { + total_waits: Rc<Cell<usize>>, + num_wait_until_status: usize, + return_err: bool, + } + + impl MockWait { + fn new(num_wait_until_status: usize) -> Self { + Self { + total_waits: Rc::new(Cell::new(0)), + num_wait_until_status, + return_err: false, + } + } + + fn with_err() -> Self { + Self { + total_waits: Rc::new(Cell::new(0)), + num_wait_until_status: 0, + return_err: true, + } + } + } + + impl Wait for MockWait { + fn id(&self) -> u32 { + 42 + } + + fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + let waits = self.total_waits.get(); + + let ret = if self.num_wait_until_status == waits { + if self.return_err { + Ok(Some(ExitStatus::from_raw(0))) + } else { + Err(io::Error::new(io::ErrorKind::Other, "mock err")) + } + } else { + Ok(None) + }; + + self.total_waits.set(waits + 1); + ret + } + } + + #[test] + fn drain_attempts_a_single_reap_of_all_queued_orphans() { + let first_orphan = MockWait::new(0); + let second_orphan = MockWait::new(1); + let third_orphan = MockWait::new(2); + let fourth_orphan = MockWait::with_err(); + + let first_waits = first_orphan.total_waits.clone(); + let second_waits = second_orphan.total_waits.clone(); + let third_waits = third_orphan.total_waits.clone(); + let fourth_waits = fourth_orphan.total_waits.clone(); + + let orphanage = OrphanQueueImpl::new(); + orphanage.push_orphan(first_orphan); + orphanage.push_orphan(third_orphan); + orphanage.push_orphan(second_orphan); + orphanage.push_orphan(fourth_orphan); + + assert_eq!(orphanage.len(), 4); + + drain_orphan_queue(orphanage.queue.lock()); + assert_eq!(orphanage.len(), 2); + assert_eq!(first_waits.get(), 1); + assert_eq!(second_waits.get(), 1); + assert_eq!(third_waits.get(), 1); + assert_eq!(fourth_waits.get(), 1); + + drain_orphan_queue(orphanage.queue.lock()); + assert_eq!(orphanage.len(), 1); + assert_eq!(first_waits.get(), 1); + assert_eq!(second_waits.get(), 2); + assert_eq!(third_waits.get(), 2); + assert_eq!(fourth_waits.get(), 1); + + drain_orphan_queue(orphanage.queue.lock()); + assert_eq!(orphanage.len(), 0); + assert_eq!(first_waits.get(), 1); + assert_eq!(second_waits.get(), 2); + assert_eq!(third_waits.get(), 3); + assert_eq!(fourth_waits.get(), 1); + + // Safe to reap when empty + drain_orphan_queue(orphanage.queue.lock()); + } + + #[test] + fn no_reap_if_no_signal_received() { + let (tx, rx) = watch::channel(()); + + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + *orphanage.sigchild.lock() = Some(rx); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + tx.send(()).unwrap(); + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 1); + } + + #[test] + fn no_reap_if_signal_lock_held() { + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + let signal_guard = orphanage.sigchild.lock(); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + drop(signal_guard); + } + + #[cfg_attr(miri, ignore)] // Miri does not support epoll. + #[test] + fn does_not_register_signal_if_queue_empty() { + let signal_driver = IoDriver::new().and_then(SignalDriver::new).unwrap(); + let handle = signal_driver.handle(); + + let orphanage = OrphanQueueImpl::new(); + assert!(orphanage.sigchild.lock().is_none()); // Sanity + + // No register when queue empty + orphanage.reap_orphans(&handle); + assert!(orphanage.sigchild.lock().is_none()); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert!(orphanage.sigchild.lock().is_some()); + assert_eq!(waits.get(), 1); // Eager reap when registering listener + } + + #[test] + fn does_nothing_if_signal_could_not_be_registered() { + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + assert!(orphanage.sigchild.lock().is_none()); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + // Signal handler has "gone away", nothing to register or reap + orphanage.reap_orphans(&handle); + assert!(orphanage.sigchild.lock().is_none()); + assert_eq!(waits.get(), 0); + } +} diff --git a/third_party/rust/tokio/src/process/unix/reap.rs b/third_party/rust/tokio/src/process/unix/reap.rs new file mode 100644 index 0000000000..f7f4d3cc70 --- /dev/null +++ b/third_party/rust/tokio/src/process/unix/reap.rs @@ -0,0 +1,298 @@ +use crate::process::imp::orphan::{OrphanQueue, Wait}; +use crate::process::kill::Kill; +use crate::signal::unix::InternalStream; + +use std::future::Future; +use std::io; +use std::ops::Deref; +use std::pin::Pin; +use std::process::ExitStatus; +use std::task::Context; +use std::task::Poll; + +/// Orchestrates between registering interest for receiving signals when a +/// child process has exited, and attempting to poll for process completion. +#[derive(Debug)] +pub(crate) struct Reaper<W, Q, S> +where + W: Wait, + Q: OrphanQueue<W>, +{ + inner: Option<W>, + orphan_queue: Q, + signal: S, +} + +impl<W, Q, S> Deref for Reaper<W, Q, S> +where + W: Wait, + Q: OrphanQueue<W>, +{ + type Target = W; + + fn deref(&self) -> &Self::Target { + self.inner() + } +} + +impl<W, Q, S> Reaper<W, Q, S> +where + W: Wait, + Q: OrphanQueue<W>, +{ + pub(crate) fn new(inner: W, orphan_queue: Q, signal: S) -> Self { + Self { + inner: Some(inner), + orphan_queue, + signal, + } + } + + fn inner(&self) -> &W { + self.inner.as_ref().expect("inner has gone away") + } + + pub(crate) fn inner_mut(&mut self) -> &mut W { + self.inner.as_mut().expect("inner has gone away") + } +} + +impl<W, Q, S> Future for Reaper<W, Q, S> +where + W: Wait + Unpin, + Q: OrphanQueue<W> + Unpin, + S: InternalStream + Unpin, +{ + type Output = io::Result<ExitStatus>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + // If the child hasn't exited yet, then it's our responsibility to + // ensure the current task gets notified when it might be able to + // make progress. We can use the delivery of a SIGCHLD signal as a + // sign that we can potentially make progress. + // + // However, we will register for a notification on the next signal + // BEFORE we poll the child. Otherwise it is possible that the child + // can exit and the signal can arrive after we last polled the child, + // but before we've registered for a notification on the next signal + // (this can cause a deadlock if there are no more spawned children + // which can generate a different signal for us). A side effect of + // pre-registering for signal notifications is that when the child + // exits, we will have already registered for an additional + // notification we don't need to consume. If another signal arrives, + // this future's task will be notified/woken up again. Since the + // futures model allows for spurious wake ups this extra wakeup + // should not cause significant issues with parent futures. + let registered_interest = self.signal.poll_recv(cx).is_pending(); + + if let Some(status) = self.inner_mut().try_wait()? { + return Poll::Ready(Ok(status)); + } + + // If our attempt to poll for the next signal was not ready, then + // we've arranged for our task to get notified and we can bail out. + if registered_interest { + return Poll::Pending; + } else { + // Otherwise, if the signal stream delivered a signal to us, we + // won't get notified at the next signal, so we'll loop and try + // again. + continue; + } + } + } +} + +impl<W, Q, S> Kill for Reaper<W, Q, S> +where + W: Kill + Wait, + Q: OrphanQueue<W>, +{ + fn kill(&mut self) -> io::Result<()> { + self.inner_mut().kill() + } +} + +impl<W, Q, S> Drop for Reaper<W, Q, S> +where + W: Wait, + Q: OrphanQueue<W>, +{ + fn drop(&mut self) { + if let Ok(Some(_)) = self.inner_mut().try_wait() { + return; + } + + let orphan = self.inner.take().unwrap(); + self.orphan_queue.push_orphan(orphan); + } +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + use crate::process::unix::orphan::test::MockQueue; + use futures::future::FutureExt; + use std::os::unix::process::ExitStatusExt; + use std::process::ExitStatus; + use std::task::Context; + use std::task::Poll; + + #[derive(Debug)] + struct MockWait { + total_kills: usize, + total_waits: usize, + num_wait_until_status: usize, + status: ExitStatus, + } + + impl MockWait { + fn new(status: ExitStatus, num_wait_until_status: usize) -> Self { + Self { + total_kills: 0, + total_waits: 0, + num_wait_until_status, + status, + } + } + } + + impl Wait for MockWait { + fn id(&self) -> u32 { + 0 + } + + fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + let ret = if self.num_wait_until_status == self.total_waits { + Some(self.status) + } else { + None + }; + + self.total_waits += 1; + Ok(ret) + } + } + + impl Kill for MockWait { + fn kill(&mut self) -> io::Result<()> { + self.total_kills += 1; + Ok(()) + } + } + + struct MockStream { + total_polls: usize, + values: Vec<Option<()>>, + } + + impl MockStream { + fn new(values: Vec<Option<()>>) -> Self { + Self { + total_polls: 0, + values, + } + } + } + + impl InternalStream for MockStream { + fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll<Option<()>> { + self.total_polls += 1; + match self.values.remove(0) { + Some(()) => Poll::Ready(Some(())), + None => Poll::Pending, + } + } + } + + #[test] + fn reaper() { + let exit = ExitStatus::from_raw(0); + let mock = MockWait::new(exit, 3); + let mut grim = Reaper::new( + mock, + MockQueue::new(), + MockStream::new(vec![None, Some(()), None, None, None]), + ); + + let waker = futures::task::noop_waker(); + let mut context = Context::from_waker(&waker); + + // Not yet exited, interest registered + assert!(grim.poll_unpin(&mut context).is_pending()); + assert_eq!(1, grim.signal.total_polls); + assert_eq!(1, grim.total_waits); + assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); + + // Not yet exited, couldn't register interest the first time + // but managed to register interest the second time around + assert!(grim.poll_unpin(&mut context).is_pending()); + assert_eq!(3, grim.signal.total_polls); + assert_eq!(3, grim.total_waits); + assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); + + // Exited + if let Poll::Ready(r) = grim.poll_unpin(&mut context) { + assert!(r.is_ok()); + let exit_code = r.unwrap(); + assert_eq!(exit_code, exit); + } else { + unreachable!(); + } + assert_eq!(4, grim.signal.total_polls); + assert_eq!(4, grim.total_waits); + assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); + } + + #[test] + fn kill() { + let exit = ExitStatus::from_raw(0); + let mut grim = Reaper::new( + MockWait::new(exit, 0), + MockQueue::new(), + MockStream::new(vec![None]), + ); + + grim.kill().unwrap(); + assert_eq!(1, grim.total_kills); + assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); + } + + #[test] + fn drop_reaps_if_possible() { + let exit = ExitStatus::from_raw(0); + let mut mock = MockWait::new(exit, 0); + + { + let queue = MockQueue::new(); + + let grim = Reaper::new(&mut mock, &queue, MockStream::new(vec![])); + + drop(grim); + + assert!(queue.all_enqueued.borrow().is_empty()); + } + + assert_eq!(1, mock.total_waits); + assert_eq!(0, mock.total_kills); + } + + #[test] + fn drop_enqueues_orphan_if_wait_fails() { + let exit = ExitStatus::from_raw(0); + let mut mock = MockWait::new(exit, 2); + + { + let queue = MockQueue::<&mut MockWait>::new(); + let grim = Reaper::new(&mut mock, &queue, MockStream::new(vec![])); + drop(grim); + + assert_eq!(1, queue.all_enqueued.borrow().len()); + } + + assert_eq!(1, mock.total_waits); + assert_eq!(0, mock.total_kills); + } +} diff --git a/third_party/rust/tokio/src/process/windows.rs b/third_party/rust/tokio/src/process/windows.rs new file mode 100644 index 0000000000..136d5b0cab --- /dev/null +++ b/third_party/rust/tokio/src/process/windows.rs @@ -0,0 +1,205 @@ +//! Windows asynchronous process handling. +//! +//! Like with Unix we don't actually have a way of registering a process with an +//! IOCP object. As a result we similarly need another mechanism for getting a +//! signal when a process has exited. For now this is implemented with the +//! `RegisterWaitForSingleObject` function in the kernel32.dll. +//! +//! This strategy is the same that libuv takes and essentially just queues up a +//! wait for the process in a kernel32-specific thread pool. Once the object is +//! notified (e.g. the process exits) then we have a callback that basically +//! just completes a `Oneshot`. +//! +//! The `poll_exit` implementation will attempt to wait for the process in a +//! nonblocking fashion, but failing that it'll fire off a +//! `RegisterWaitForSingleObject` and then wait on the other end of the oneshot +//! from then on out. + +use crate::io::PollEvented; +use crate::process::kill::Kill; +use crate::process::SpawnedChild; +use crate::sync::oneshot; + +use mio::windows::NamedPipe; +use std::fmt; +use std::future::Future; +use std::io; +use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; +use std::pin::Pin; +use std::process::Stdio; +use std::process::{Child as StdChild, Command as StdCommand, ExitStatus}; +use std::ptr; +use std::task::Context; +use std::task::Poll; +use winapi::shared::minwindef::{DWORD, FALSE}; +use winapi::um::handleapi::{DuplicateHandle, INVALID_HANDLE_VALUE}; +use winapi::um::processthreadsapi::GetCurrentProcess; +use winapi::um::threadpoollegacyapiset::UnregisterWaitEx; +use winapi::um::winbase::{RegisterWaitForSingleObject, INFINITE}; +use winapi::um::winnt::{ + BOOLEAN, DUPLICATE_SAME_ACCESS, HANDLE, PVOID, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE, +}; + +#[must_use = "futures do nothing unless polled"] +pub(crate) struct Child { + child: StdChild, + waiting: Option<Waiting>, +} + +impl fmt::Debug for Child { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Child") + .field("pid", &self.id()) + .field("child", &self.child) + .field("waiting", &"..") + .finish() + } +} + +struct Waiting { + rx: oneshot::Receiver<()>, + wait_object: HANDLE, + tx: *mut Option<oneshot::Sender<()>>, +} + +unsafe impl Sync for Waiting {} +unsafe impl Send for Waiting {} + +pub(crate) fn spawn_child(cmd: &mut StdCommand) -> io::Result<SpawnedChild> { + let mut child = cmd.spawn()?; + let stdin = child.stdin.take().map(stdio).transpose()?; + let stdout = child.stdout.take().map(stdio).transpose()?; + let stderr = child.stderr.take().map(stdio).transpose()?; + + Ok(SpawnedChild { + child: Child { + child, + waiting: None, + }, + stdin, + stdout, + stderr, + }) +} + +impl Child { + pub(crate) fn id(&self) -> u32 { + self.child.id() + } + + pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + self.child.try_wait() + } +} + +impl Kill for Child { + fn kill(&mut self) -> io::Result<()> { + self.child.kill() + } +} + +impl Future for Child { + type Output = io::Result<ExitStatus>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let inner = Pin::get_mut(self); + loop { + if let Some(ref mut w) = inner.waiting { + match Pin::new(&mut w.rx).poll(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(_)) => panic!("should not be canceled"), + Poll::Pending => return Poll::Pending, + } + let status = inner.try_wait()?.expect("not ready yet"); + return Poll::Ready(Ok(status)); + } + + if let Some(e) = inner.try_wait()? { + return Poll::Ready(Ok(e)); + } + let (tx, rx) = oneshot::channel(); + let ptr = Box::into_raw(Box::new(Some(tx))); + let mut wait_object = ptr::null_mut(); + let rc = unsafe { + RegisterWaitForSingleObject( + &mut wait_object, + inner.child.as_raw_handle(), + Some(callback), + ptr as *mut _, + INFINITE, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE, + ) + }; + if rc == 0 { + let err = io::Error::last_os_error(); + drop(unsafe { Box::from_raw(ptr) }); + return Poll::Ready(Err(err)); + } + inner.waiting = Some(Waiting { + rx, + wait_object, + tx: ptr, + }); + } + } +} + +impl AsRawHandle for Child { + fn as_raw_handle(&self) -> RawHandle { + self.child.as_raw_handle() + } +} + +impl Drop for Waiting { + fn drop(&mut self) { + unsafe { + let rc = UnregisterWaitEx(self.wait_object, INVALID_HANDLE_VALUE); + if rc == 0 { + panic!("failed to unregister: {}", io::Error::last_os_error()); + } + drop(Box::from_raw(self.tx)); + } + } +} + +unsafe extern "system" fn callback(ptr: PVOID, _timer_fired: BOOLEAN) { + let complete = &mut *(ptr as *mut Option<oneshot::Sender<()>>); + let _ = complete.take().unwrap().send(()); +} + +pub(crate) type ChildStdio = PollEvented<NamedPipe>; + +pub(super) fn stdio<T>(io: T) -> io::Result<PollEvented<NamedPipe>> +where + T: IntoRawHandle, +{ + let pipe = unsafe { NamedPipe::from_raw_handle(io.into_raw_handle()) }; + PollEvented::new(pipe) +} + +pub(crate) fn convert_to_stdio(io: PollEvented<NamedPipe>) -> io::Result<Stdio> { + let named_pipe = io.into_inner()?; + + // Mio does not implement `IntoRawHandle` for `NamedPipe`, so we'll manually + // duplicate the handle here... + unsafe { + let mut dup_handle = INVALID_HANDLE_VALUE; + let cur_proc = GetCurrentProcess(); + + let status = DuplicateHandle( + cur_proc, + named_pipe.as_raw_handle(), + cur_proc, + &mut dup_handle, + 0 as DWORD, + FALSE, + DUPLICATE_SAME_ACCESS, + ); + + if status == 0 { + return Err(io::Error::last_os_error()); + } + + Ok(Stdio::from_raw_handle(dup_handle)) + } +} diff --git a/third_party/rust/tokio/src/runtime/basic_scheduler.rs b/third_party/rust/tokio/src/runtime/basic_scheduler.rs new file mode 100644 index 0000000000..401f55b3f2 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/basic_scheduler.rs @@ -0,0 +1,574 @@ +use crate::future::poll_fn; +use crate::loom::sync::atomic::AtomicBool; +use crate::loom::sync::{Arc, Mutex}; +use crate::park::{Park, Unpark}; +use crate::runtime::context::EnterGuard; +use crate::runtime::driver::Driver; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::Callback; +use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; +use crate::sync::notify::Notify; +use crate::util::atomic_cell::AtomicCell; +use crate::util::{waker_ref, Wake, WakerRef}; + +use std::cell::RefCell; +use std::collections::VecDeque; +use std::fmt; +use std::future::Future; +use std::sync::atomic::Ordering::{AcqRel, Release}; +use std::task::Poll::{Pending, Ready}; +use std::time::Duration; + +/// Executes tasks on the current thread +pub(crate) struct BasicScheduler { + /// Core scheduler data is acquired by a thread entering `block_on`. + core: AtomicCell<Core>, + + /// Notifier for waking up other threads to steal the + /// driver. + notify: Notify, + + /// Sendable task spawner + spawner: Spawner, + + /// This is usually None, but right before dropping the BasicScheduler, it + /// is changed to `Some` with the context being the runtime's own context. + /// This ensures that any tasks dropped in the `BasicScheduler`s destructor + /// run in that runtime's context. + context_guard: Option<EnterGuard>, +} + +/// Data required for executing the scheduler. The struct is passed around to +/// a function that will perform the scheduling work and acts as a capability token. +struct Core { + /// Scheduler run queue + tasks: VecDeque<task::Notified<Arc<Shared>>>, + + /// Sendable task spawner + spawner: Spawner, + + /// Current tick + tick: u8, + + /// Runtime driver + /// + /// The driver is removed before starting to park the thread + driver: Option<Driver>, + + /// Metrics batch + metrics: MetricsBatch, +} + +#[derive(Clone)] +pub(crate) struct Spawner { + shared: Arc<Shared>, +} + +/// Scheduler state shared between threads. +struct Shared { + /// Remote run queue. None if the `Runtime` has been dropped. + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, + + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + + /// Unpark the blocked thread. + unpark: <Driver as Park>::Unpark, + + /// Indicates whether the blocked on thread was woken. + woken: AtomicBool, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Keeps track of various runtime metrics. + scheduler_metrics: SchedulerMetrics, + + /// This scheduler only has one worker. + worker_metrics: WorkerMetrics, +} + +/// Thread-local context. +struct Context { + /// Handle to the spawner + spawner: Spawner, + + /// Scheduler core, enabling the holder of `Context` to execute the + /// scheduler. + core: RefCell<Option<Box<Core>>>, +} + +/// Initial queue capacity. +const INITIAL_CAPACITY: usize = 64; + +/// Max number of tasks to poll per tick. +#[cfg(loom)] +const MAX_TASKS_PER_TICK: usize = 4; +#[cfg(not(loom))] +const MAX_TASKS_PER_TICK: usize = 61; + +/// How often to check the remote queue first. +const REMOTE_FIRST_INTERVAL: u8 = 31; + +// Tracks the current BasicScheduler. +scoped_thread_local!(static CURRENT: Context); + +impl BasicScheduler { + pub(crate) fn new( + driver: Driver, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> BasicScheduler { + let unpark = driver.unpark(); + + let spawner = Spawner { + shared: Arc::new(Shared { + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), + owned: OwnedTasks::new(), + unpark, + woken: AtomicBool::new(false), + before_park, + after_unpark, + scheduler_metrics: SchedulerMetrics::new(), + worker_metrics: WorkerMetrics::new(), + }), + }; + + let core = AtomicCell::new(Some(Box::new(Core { + tasks: VecDeque::with_capacity(INITIAL_CAPACITY), + spawner: spawner.clone(), + tick: 0, + driver: Some(driver), + metrics: MetricsBatch::new(), + }))); + + BasicScheduler { + core, + notify: Notify::new(), + spawner, + context_guard: None, + } + } + + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output { + pin!(future); + + // Attempt to steal the scheduler core and block_on the future if we can + // there, otherwise, lets select on a notification that the core is + // available or the future is complete. + loop { + if let Some(core) = self.take_core() { + return core.block_on(future); + } else { + let mut enter = crate::runtime::enter(false); + + let notified = self.notify.notified(); + pin!(notified); + + if let Some(out) = enter + .block_on(poll_fn(|cx| { + if notified.as_mut().poll(cx).is_ready() { + return Ready(None); + } + + if let Ready(out) = future.as_mut().poll(cx) { + return Ready(Some(out)); + } + + Pending + })) + .expect("Failed to `Enter::block_on`") + { + return out; + } + } + } + } + + fn take_core(&self) -> Option<CoreGuard<'_>> { + let core = self.core.take()?; + + Some(CoreGuard { + context: Context { + spawner: self.spawner.clone(), + core: RefCell::new(Some(core)), + }, + basic_scheduler: self, + }) + } + + pub(super) fn set_context_guard(&mut self, guard: EnterGuard) { + self.context_guard = Some(guard); + } +} + +impl Drop for BasicScheduler { + fn drop(&mut self) { + // Avoid a double panic if we are currently panicking and + // the lock may be poisoned. + + let core = match self.take_core() { + Some(core) => core, + None if std::thread::panicking() => return, + None => panic!("Oh no! We never placed the Core back, this is a bug!"), + }; + + core.enter(|mut core, context| { + // Drain the OwnedTasks collection. This call also closes the + // collection, ensuring that no tasks are ever pushed after this + // call returns. + context.spawner.shared.owned.close_and_shutdown_all(); + + // Drain local queue + // We already shut down every task, so we just need to drop the task. + while let Some(task) = core.pop_task() { + drop(task); + } + + // Drain remote queue and set it to None + let remote_queue = core.spawner.shared.queue.lock().take(); + + // Using `Option::take` to replace the shared queue with `None`. + // We already shut down every task, so we just need to drop the task. + if let Some(remote_queue) = remote_queue { + for task in remote_queue { + drop(task); + } + } + + assert!(context.spawner.shared.owned.is_empty()); + + // Submit metrics + core.metrics.submit(&core.spawner.shared.worker_metrics); + + (core, ()) + }); + } +} + +impl fmt::Debug for BasicScheduler { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BasicScheduler").finish() + } +} + +// ===== impl Core ===== + +impl Core { + fn pop_task(&mut self) -> Option<task::Notified<Arc<Shared>>> { + let ret = self.tasks.pop_front(); + self.spawner + .shared + .worker_metrics + .set_queue_depth(self.tasks.len()); + ret + } + + fn push_task(&mut self, task: task::Notified<Arc<Shared>>) { + self.tasks.push_back(task); + self.metrics.inc_local_schedule_count(); + self.spawner + .shared + .worker_metrics + .set_queue_depth(self.tasks.len()); + } +} + +// ===== impl Context ===== + +impl Context { + /// Execute the closure with the given scheduler core stored in the + /// thread-local context. + fn run_task<R>(&self, mut core: Box<Core>, f: impl FnOnce() -> R) -> (Box<Core>, R) { + core.metrics.incr_poll_count(); + self.enter(core, || crate::coop::budget(f)) + } + + /// Blocks the current thread until an event is received by the driver, + /// including I/O events, timer events, ... + fn park(&self, mut core: Box<Core>) -> Box<Core> { + let mut driver = core.driver.take().expect("driver missing"); + + if let Some(f) = &self.spawner.shared.before_park { + // Incorrect lint, the closures are actually different types so `f` + // cannot be passed as an argument to `enter`. + #[allow(clippy::redundant_closure)] + let (c, _) = self.enter(core, || f()); + core = c; + } + + // This check will fail if `before_park` spawns a task for us to run + // instead of parking the thread + if core.tasks.is_empty() { + // Park until the thread is signaled + core.metrics.about_to_park(); + core.metrics.submit(&core.spawner.shared.worker_metrics); + + let (c, _) = self.enter(core, || { + driver.park().expect("failed to park"); + }); + + core = c; + core.metrics.returned_from_park(); + } + + if let Some(f) = &self.spawner.shared.after_unpark { + // Incorrect lint, the closures are actually different types so `f` + // cannot be passed as an argument to `enter`. + #[allow(clippy::redundant_closure)] + let (c, _) = self.enter(core, || f()); + core = c; + } + + core.driver = Some(driver); + core + } + + /// Checks the driver for new events without blocking the thread. + fn park_yield(&self, mut core: Box<Core>) -> Box<Core> { + let mut driver = core.driver.take().expect("driver missing"); + + core.metrics.submit(&core.spawner.shared.worker_metrics); + let (mut core, _) = self.enter(core, || { + driver + .park_timeout(Duration::from_millis(0)) + .expect("failed to park"); + }); + + core.driver = Some(driver); + core + } + + fn enter<R>(&self, core: Box<Core>, f: impl FnOnce() -> R) -> (Box<Core>, R) { + // Store the scheduler core in the thread-local context + // + // A drop-guard is employed at a higher level. + *self.core.borrow_mut() = Some(core); + + // Execute the closure while tracking the execution budget + let ret = f(); + + // Take the scheduler core back + let core = self.core.borrow_mut().take().expect("core missing"); + (core, ret) + } +} + +// ===== impl Spawner ===== + +impl Spawner { + /// Spawns a future onto the basic scheduler + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + + if let Some(notified) = notified { + self.shared.schedule(notified); + } + + handle + } + + fn pop(&self) -> Option<task::Notified<Arc<Shared>>> { + match self.shared.queue.lock().as_mut() { + Some(queue) => queue.pop_front(), + None => None, + } + } + + fn waker_ref(&self) -> WakerRef<'_> { + // Set woken to true when enter block_on, ensure outer future + // be polled for the first time when enter loop + self.shared.woken.store(true, Release); + waker_ref(&self.shared) + } + + // reset woken to false and return original value + pub(crate) fn reset_woken(&self) -> bool { + self.shared.woken.swap(false, AcqRel) + } +} + +cfg_metrics! { + impl Spawner { + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + &self.shared.scheduler_metrics + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + // TODO: avoid having to lock. The multi-threaded injection queue + // could probably be used here. + self.shared.queue.lock() + .as_ref() + .map(|queue| queue.len()) + .unwrap_or(0) + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + assert_eq!(0, worker); + &self.shared.worker_metrics + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Spawner").finish() + } +} + +// ===== impl Shared ===== + +impl Schedule for Arc<Shared> { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + self.owned.remove(task) + } + + fn schedule(&self, task: task::Notified<Self>) { + CURRENT.with(|maybe_cx| match maybe_cx { + Some(cx) if Arc::ptr_eq(self, &cx.spawner.shared) => { + let mut core = cx.core.borrow_mut(); + + // If `None`, the runtime is shutting down, so there is no need + // to schedule the task. + if let Some(core) = core.as_mut() { + core.push_task(task); + } + } + _ => { + // Track that a task was scheduled from **outside** of the runtime. + self.scheduler_metrics.inc_remote_schedule_count(); + + // If the queue is None, then the runtime has shut down. We + // don't need to do anything with the notification in that case. + let mut guard = self.queue.lock(); + if let Some(queue) = guard.as_mut() { + queue.push_back(task); + drop(guard); + self.unpark.unpark(); + } + } + }); + } +} + +impl Wake for Shared { + fn wake(arc_self: Arc<Self>) { + Wake::wake_by_ref(&arc_self) + } + + /// Wake by reference + fn wake_by_ref(arc_self: &Arc<Self>) { + arc_self.woken.store(true, Release); + arc_self.unpark.unpark(); + } +} + +// ===== CoreGuard ===== + +/// Used to ensure we always place the `Core` value back into its slot in +/// `BasicScheduler`, even if the future panics. +struct CoreGuard<'a> { + context: Context, + basic_scheduler: &'a BasicScheduler, +} + +impl CoreGuard<'_> { + fn block_on<F: Future>(self, future: F) -> F::Output { + self.enter(|mut core, context| { + let _enter = crate::runtime::enter(false); + let waker = context.spawner.waker_ref(); + let mut cx = std::task::Context::from_waker(&waker); + + pin!(future); + + 'outer: loop { + if core.spawner.reset_woken() { + let (c, res) = context.enter(core, || { + crate::coop::budget(|| future.as_mut().poll(&mut cx)) + }); + + core = c; + + if let Ready(v) = res { + return (core, v); + } + } + + for _ in 0..MAX_TASKS_PER_TICK { + // Get and increment the current tick + let tick = core.tick; + core.tick = core.tick.wrapping_add(1); + + let entry = if tick % REMOTE_FIRST_INTERVAL == 0 { + core.spawner.pop().or_else(|| core.tasks.pop_front()) + } else { + core.tasks.pop_front().or_else(|| core.spawner.pop()) + }; + + let task = match entry { + Some(entry) => entry, + None => { + core = context.park(core); + + // Try polling the `block_on` future next + continue 'outer; + } + }; + + let task = context.spawner.shared.owned.assert_owner(task); + + let (c, _) = context.run_task(core, || { + task.run(); + }); + + core = c; + } + + // Yield to the driver, this drives the timer and pulls any + // pending I/O events. + core = context.park_yield(core); + } + }) + } + + /// Enters the scheduler context. This sets the queue and other necessary + /// scheduler state in the thread-local. + fn enter<F, R>(self, f: F) -> R + where + F: FnOnce(Box<Core>, &Context) -> (Box<Core>, R), + { + // Remove `core` from `context` to pass into the closure. + let core = self.context.core.borrow_mut().take().expect("core missing"); + + // Call the closure and place `core` back + let (core, ret) = CURRENT.set(&self.context, || f(core, &self.context)); + + *self.context.core.borrow_mut() = Some(core); + + ret + } +} + +impl Drop for CoreGuard<'_> { + fn drop(&mut self) { + if let Some(core) = self.context.core.borrow_mut().take() { + // Replace old scheduler back into the state to allow + // other threads to pick it up and drive it. + self.basic_scheduler.core.set(core); + + // Wake up other possible threads that could steal the driver. + self.basic_scheduler.notify.notify_one() + } + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/mod.rs b/third_party/rust/tokio/src/runtime/blocking/mod.rs new file mode 100644 index 0000000000..15fe05c9ad --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/mod.rs @@ -0,0 +1,48 @@ +//! Abstracts out the APIs necessary to `Runtime` for integrating the blocking +//! pool. When the `blocking` feature flag is **not** enabled, these APIs are +//! shells. This isolates the complexity of dealing with conditional +//! compilation. + +mod pool; +pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task}; + +cfg_fs! { + pub(crate) use pool::spawn_mandatory_blocking; +} + +mod schedule; +mod shutdown; +mod task; +pub(crate) use schedule::NoopSchedule; +pub(crate) use task::BlockingTask; + +use crate::runtime::Builder; + +pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool { + BlockingPool::new(builder, thread_cap) +} + +/* +cfg_not_blocking_impl! { + use crate::runtime::Builder; + use std::time::Duration; + + #[derive(Debug, Clone)] + pub(crate) struct BlockingPool {} + + pub(crate) use BlockingPool as Spawner; + + pub(crate) fn create_blocking_pool(_builder: &Builder, _thread_cap: usize) -> BlockingPool { + BlockingPool {} + } + + impl BlockingPool { + pub(crate) fn spawner(&self) -> &BlockingPool { + self + } + + pub(crate) fn shutdown(&mut self, _duration: Option<Duration>) { + } + } +} +*/ diff --git a/third_party/rust/tokio/src/runtime/blocking/pool.rs b/third_party/rust/tokio/src/runtime/blocking/pool.rs new file mode 100644 index 0000000000..daf1f63fac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/pool.rs @@ -0,0 +1,396 @@ +//! Thread pool for blocking operations + +use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::thread; +use crate::runtime::blocking::schedule::NoopSchedule; +use crate::runtime::blocking::shutdown; +use crate::runtime::builder::ThreadNameFn; +use crate::runtime::context; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{Builder, Callback, Handle}; + +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::time::Duration; + +pub(crate) struct BlockingPool { + spawner: Spawner, + shutdown_rx: shutdown::Receiver, +} + +#[derive(Clone)] +pub(crate) struct Spawner { + inner: Arc<Inner>, +} + +struct Inner { + /// State shared between worker threads. + shared: Mutex<Shared>, + + /// Pool threads wait on this. + condvar: Condvar, + + /// Spawned threads use this name. + thread_name: ThreadNameFn, + + /// Spawned thread stack size. + stack_size: Option<usize>, + + /// Call after a thread starts. + after_start: Option<Callback>, + + /// Call before a thread stops. + before_stop: Option<Callback>, + + // Maximum number of threads. + thread_cap: usize, + + // Customizable wait timeout. + keep_alive: Duration, +} + +struct Shared { + queue: VecDeque<Task>, + num_th: usize, + num_idle: u32, + num_notify: u32, + shutdown: bool, + shutdown_tx: Option<shutdown::Sender>, + /// Prior to shutdown, we clean up JoinHandles by having each timed-out + /// thread join on the previous timed-out thread. This is not strictly + /// necessary but helps avoid Valgrind false positives, see + /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666> + /// for more information. + last_exiting_thread: Option<thread::JoinHandle<()>>, + /// This holds the JoinHandles for all running threads; on shutdown, the thread + /// calling shutdown handles joining on these. + worker_threads: HashMap<usize, thread::JoinHandle<()>>, + /// This is a counter used to iterate worker_threads in a consistent order (for loom's + /// benefit). + worker_thread_index: usize, +} + +pub(crate) struct Task { + task: task::UnownedTask<NoopSchedule>, + mandatory: Mandatory, +} + +#[derive(PartialEq, Eq)] +pub(crate) enum Mandatory { + #[cfg_attr(not(fs), allow(dead_code))] + Mandatory, + NonMandatory, +} + +impl Task { + pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task { + Task { task, mandatory } + } + + fn run(self) { + self.task.run(); + } + + fn shutdown_or_run_if_mandatory(self) { + match self.mandatory { + Mandatory::NonMandatory => self.task.shutdown(), + Mandatory::Mandatory => self.task.run(), + } + } +} + +const KEEP_ALIVE: Duration = Duration::from_secs(10); + +/// Runs the provided function on an executor dedicated to blocking operations. +/// Tasks will be scheduled as non-mandatory, meaning they may not get executed +/// in case of runtime shutdown. +pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let rt = context::current(); + rt.spawn_blocking(func) +} + +cfg_fs! { + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + /// Runs the provided function on an executor dedicated to blocking + /// operations. Tasks will be scheduled as mandatory, meaning they are + /// guaranteed to run unless a shutdown is already taking place. In case a + /// shutdown is already taking place, `None` will be returned. + pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let rt = context::current(); + rt.spawn_mandatory_blocking(func) + } +} + +// ===== impl BlockingPool ===== + +impl BlockingPool { + pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool { + let (shutdown_tx, shutdown_rx) = shutdown::channel(); + let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE); + + BlockingPool { + spawner: Spawner { + inner: Arc::new(Inner { + shared: Mutex::new(Shared { + queue: VecDeque::new(), + num_th: 0, + num_idle: 0, + num_notify: 0, + shutdown: false, + shutdown_tx: Some(shutdown_tx), + last_exiting_thread: None, + worker_threads: HashMap::new(), + worker_thread_index: 0, + }), + condvar: Condvar::new(), + thread_name: builder.thread_name.clone(), + stack_size: builder.thread_stack_size, + after_start: builder.after_start.clone(), + before_stop: builder.before_stop.clone(), + thread_cap, + keep_alive, + }), + }, + shutdown_rx, + } + } + + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) { + let mut shared = self.spawner.inner.shared.lock(); + + // The function can be called multiple times. First, by explicitly + // calling `shutdown` then by the drop handler calling `shutdown`. This + // prevents shutting down twice. + if shared.shutdown { + return; + } + + shared.shutdown = true; + shared.shutdown_tx = None; + self.spawner.inner.condvar.notify_all(); + + let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread); + let workers = std::mem::take(&mut shared.worker_threads); + + drop(shared); + + if self.shutdown_rx.wait(timeout) { + let _ = last_exited_thread.map(|th| th.join()); + + // Loom requires that execution be deterministic, so sort by thread ID before joining. + // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic) + let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect(); + workers.sort_by_key(|(id, _)| *id); + + for (_id, handle) in workers.into_iter() { + let _ = handle.join(); + } + } + } +} + +impl Drop for BlockingPool { + fn drop(&mut self) { + self.shutdown(None); + } +} + +impl fmt::Debug for BlockingPool { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BlockingPool").finish() + } +} + +// ===== impl Spawner ===== + +impl Spawner { + pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { + let mut shared = self.inner.shared.lock(); + + if shared.shutdown { + // Shutdown the task: it's fine to shutdown this task (even if + // mandatory) because it was scheduled after the shutdown of the + // runtime began. + task.task.shutdown(); + + // no need to even push this task; it would never get picked up + return Err(()); + } + + shared.queue.push_back(task); + + if shared.num_idle == 0 { + // No threads are able to process the task. + + if shared.num_th == self.inner.thread_cap { + // At max number of threads + } else { + shared.num_th += 1; + assert!(shared.shutdown_tx.is_some()); + let shutdown_tx = shared.shutdown_tx.clone(); + + if let Some(shutdown_tx) = shutdown_tx { + let id = shared.worker_thread_index; + shared.worker_thread_index += 1; + + let handle = self.spawn_thread(shutdown_tx, rt, id); + + shared.worker_threads.insert(id, handle); + } + } + } else { + // Notify an idle worker thread. The notification counter + // is used to count the needed amount of notifications + // exactly. Thread libraries may generate spurious + // wakeups, this counter is used to keep us in a + // consistent state. + shared.num_idle -= 1; + shared.num_notify += 1; + self.inner.condvar.notify_one(); + } + + Ok(()) + } + + fn spawn_thread( + &self, + shutdown_tx: shutdown::Sender, + rt: &Handle, + id: usize, + ) -> thread::JoinHandle<()> { + let mut builder = thread::Builder::new().name((self.inner.thread_name)()); + + if let Some(stack_size) = self.inner.stack_size { + builder = builder.stack_size(stack_size); + } + + let rt = rt.clone(); + + builder + .spawn(move || { + // Only the reference should be moved into the closure + let _enter = crate::runtime::context::enter(rt.clone()); + rt.blocking_spawner.inner.run(id); + drop(shutdown_tx); + }) + .expect("OS can't spawn a new worker thread") + } +} + +impl Inner { + fn run(&self, worker_thread_id: usize) { + if let Some(f) = &self.after_start { + f() + } + + let mut shared = self.shared.lock(); + let mut join_on_thread = None; + + 'main: loop { + // BUSY + while let Some(task) = shared.queue.pop_front() { + drop(shared); + task.run(); + + shared = self.shared.lock(); + } + + // IDLE + shared.num_idle += 1; + + while !shared.shutdown { + let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap(); + + shared = lock_result.0; + let timeout_result = lock_result.1; + + if shared.num_notify != 0 { + // We have received a legitimate wakeup, + // acknowledge it by decrementing the counter + // and transition to the BUSY state. + shared.num_notify -= 1; + break; + } + + // Even if the condvar "timed out", if the pool is entering the + // shutdown phase, we want to perform the cleanup logic. + if !shared.shutdown && timeout_result.timed_out() { + // We'll join the prior timed-out thread's JoinHandle after dropping the lock. + // This isn't done when shutting down, because the thread calling shutdown will + // handle joining everything. + let my_handle = shared.worker_threads.remove(&worker_thread_id); + join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle); + + break 'main; + } + + // Spurious wakeup detected, go back to sleep. + } + + if shared.shutdown { + // Drain the queue + while let Some(task) = shared.queue.pop_front() { + drop(shared); + + task.shutdown_or_run_if_mandatory(); + + shared = self.shared.lock(); + } + + // Work was produced, and we "took" it (by decrementing num_notify). + // This means that num_idle was decremented once for our wakeup. + // But, since we are exiting, we need to "undo" that, as we'll stay idle. + shared.num_idle += 1; + // NOTE: Technically we should also do num_notify++ and notify again, + // but since we're shutting down anyway, that won't be necessary. + break; + } + } + + // Thread exit + shared.num_th -= 1; + + // num_idle should now be tracked exactly, panic + // with a descriptive message if it is not the + // case. + shared.num_idle = shared + .num_idle + .checked_sub(1) + .expect("num_idle underflowed on thread exit"); + + if shared.shutdown && shared.num_th == 0 { + self.condvar.notify_one(); + } + + drop(shared); + + if let Some(f) = &self.before_stop { + f() + } + + if let Some(handle) = join_on_thread { + let _ = handle.join(); + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("blocking::Spawner").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/schedule.rs b/third_party/rust/tokio/src/runtime/blocking/schedule.rs new file mode 100644 index 0000000000..54252241d9 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/schedule.rs @@ -0,0 +1,19 @@ +use crate::runtime::task::{self, Task}; + +/// `task::Schedule` implementation that does nothing. This is unique to the +/// blocking scheduler as tasks scheduled are not really futures but blocking +/// operations. +/// +/// We avoid storing the task by forgetting it in `bind` and re-materializing it +/// in `release. +pub(crate) struct NoopSchedule; + +impl task::Schedule for NoopSchedule { + fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { + None + } + + fn schedule(&self, _task: task::Notified<Self>) { + unreachable!(); + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/shutdown.rs b/third_party/rust/tokio/src/runtime/blocking/shutdown.rs new file mode 100644 index 0000000000..e6f4674183 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/shutdown.rs @@ -0,0 +1,71 @@ +//! A shutdown channel. +//! +//! Each worker holds the `Sender` half. When all the `Sender` halves are +//! dropped, the `Receiver` receives a notification. + +use crate::loom::sync::Arc; +use crate::sync::oneshot; + +use std::time::Duration; + +#[derive(Debug, Clone)] +pub(super) struct Sender { + _tx: Arc<oneshot::Sender<()>>, +} + +#[derive(Debug)] +pub(super) struct Receiver { + rx: oneshot::Receiver<()>, +} + +pub(super) fn channel() -> (Sender, Receiver) { + let (tx, rx) = oneshot::channel(); + let tx = Sender { _tx: Arc::new(tx) }; + let rx = Receiver { rx }; + + (tx, rx) +} + +impl Receiver { + /// Blocks the current thread until all `Sender` handles drop. + /// + /// If `timeout` is `Some`, the thread is blocked for **at most** `timeout` + /// duration. If `timeout` is `None`, then the thread is blocked until the + /// shutdown signal is received. + /// + /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. + pub(crate) fn wait(&mut self, timeout: Option<Duration>) -> bool { + use crate::runtime::enter::try_enter; + + if timeout == Some(Duration::from_nanos(0)) { + return false; + } + + let mut e = match try_enter(false) { + Some(enter) => enter, + _ => { + if std::thread::panicking() { + // Don't panic in a panic + return false; + } else { + panic!( + "Cannot drop a runtime in a context where blocking is not allowed. \ + This happens when a runtime is dropped from within an asynchronous context." + ); + } + } + }; + + // The oneshot completes with an Err + // + // If blocking fails to wait, this indicates a problem parking the + // current thread (usually, shutting down a runtime stored in a + // thread-local). + if let Some(timeout) = timeout { + e.block_on_timeout(&mut self.rx, timeout).is_ok() + } else { + let _ = e.block_on(&mut self.rx); + true + } + } +} diff --git a/third_party/rust/tokio/src/runtime/blocking/task.rs b/third_party/rust/tokio/src/runtime/blocking/task.rs new file mode 100644 index 0000000000..0b7803a6c0 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/blocking/task.rs @@ -0,0 +1,44 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Converts a function to a future that completes on poll. +pub(crate) struct BlockingTask<T> { + func: Option<T>, +} + +impl<T> BlockingTask<T> { + /// Initializes a new blocking task from the given function. + pub(crate) fn new(func: T) -> BlockingTask<T> { + BlockingTask { func: Some(func) } + } +} + +// The closure `F` is never pinned +impl<T> Unpin for BlockingTask<T> {} + +impl<T, R> Future for BlockingTask<T> +where + T: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + type Output = R; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<R> { + let me = &mut *self; + let func = me + .func + .take() + .expect("[internal exception] blocking task ran twice."); + + // This is a little subtle: + // For convenience, we'd like _every_ call tokio ever makes to Task::poll() to be budgeted + // using coop. However, the way things are currently modeled, even running a blocking task + // currently goes through Task::poll(), and so is subject to budgeting. That isn't really + // what we want; a blocking task may itself want to run tasks (it might be a Worker!), so + // we want it to start without any budgeting. + crate::coop::stop(); + + Poll::Ready(func()) + } +} diff --git a/third_party/rust/tokio/src/runtime/builder.rs b/third_party/rust/tokio/src/runtime/builder.rs new file mode 100644 index 0000000000..91c365fd51 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/builder.rs @@ -0,0 +1,718 @@ +use crate::runtime::handle::Handle; +use crate::runtime::{blocking, driver, Callback, Runtime, Spawner}; + +use std::fmt; +use std::io; +use std::time::Duration; + +/// Builds Tokio Runtime with custom configuration values. +/// +/// Methods can be chained in order to set the configuration values. The +/// Runtime is constructed by calling [`build`]. +/// +/// New instances of `Builder` are obtained via [`Builder::new_multi_thread`] +/// or [`Builder::new_current_thread`]. +/// +/// See function level documentation for details on the various configuration +/// settings. +/// +/// [`build`]: method@Self::build +/// [`Builder::new_multi_thread`]: method@Self::new_multi_thread +/// [`Builder::new_current_thread`]: method@Self::new_current_thread +/// +/// # Examples +/// +/// ``` +/// use tokio::runtime::Builder; +/// +/// fn main() { +/// // build runtime +/// let runtime = Builder::new_multi_thread() +/// .worker_threads(4) +/// .thread_name("my-custom-name") +/// .thread_stack_size(3 * 1024 * 1024) +/// .build() +/// .unwrap(); +/// +/// // use runtime ... +/// } +/// ``` +pub struct Builder { + /// Runtime type + kind: Kind, + + /// Whether or not to enable the I/O driver + enable_io: bool, + + /// Whether or not to enable the time driver + enable_time: bool, + + /// Whether or not the clock should start paused. + start_paused: bool, + + /// The number of worker threads, used by Runtime. + /// + /// Only used when not using the current-thread executor. + worker_threads: Option<usize>, + + /// Cap on thread usage. + max_blocking_threads: usize, + + /// Name fn used for threads spawned by the runtime. + pub(super) thread_name: ThreadNameFn, + + /// Stack size used for threads spawned by the runtime. + pub(super) thread_stack_size: Option<usize>, + + /// Callback to run after each thread starts. + pub(super) after_start: Option<Callback>, + + /// To run before each worker thread stops + pub(super) before_stop: Option<Callback>, + + /// To run before each worker thread is parked. + pub(super) before_park: Option<Callback>, + + /// To run after each thread is unparked. + pub(super) after_unpark: Option<Callback>, + + /// Customizable keep alive timeout for BlockingPool + pub(super) keep_alive: Option<Duration>, +} + +pub(crate) type ThreadNameFn = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>; + +pub(crate) enum Kind { + CurrentThread, + #[cfg(feature = "rt-multi-thread")] + MultiThread, +} + +impl Builder { + /// Returns a new builder with the current thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + /// + /// To spawn non-`Send` tasks on the resulting runtime, combine it with a + /// [`LocalSet`]. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn new_current_thread() -> Builder { + Builder::new(Kind::CurrentThread) + } + + /// Returns a new builder with the multi thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new_multi_thread() -> Builder { + Builder::new(Kind::MultiThread) + } + + /// Returns a new runtime builder initialized with default configuration + /// values. + /// + /// Configuration methods can be chained on the return value. + pub(crate) fn new(kind: Kind) -> Builder { + Builder { + kind, + + // I/O defaults to "off" + enable_io: false, + + // Time defaults to "off" + enable_time: false, + + // The clock starts not-paused + start_paused: false, + + // Default to lazy auto-detection (one thread per CPU core) + worker_threads: None, + + max_blocking_threads: 512, + + // Default thread name + thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()), + + // Do not set a stack size by default + thread_stack_size: None, + + // No worker thread callbacks + after_start: None, + before_stop: None, + before_park: None, + after_unpark: None, + + keep_alive: None, + } + } + + /// Enables both I/O and time drivers. + /// + /// Doing this is a shorthand for calling `enable_io` and `enable_time` + /// individually. If additional components are added to Tokio in the future, + /// `enable_all` will include these future components. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_all(&mut self) -> &mut Self { + #[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))] + self.enable_io(); + #[cfg(feature = "time")] + self.enable_time(); + + self + } + + /// Sets the number of worker threads the `Runtime` will use. + /// + /// This can be any number above 0 though it is advised to keep this value + /// on the smaller side. + /// + /// # Default + /// + /// The default value is the number of cores available to the system. + /// + /// # Panic + /// + /// When using the `current_thread` runtime this method will panic, since + /// those variants do not allow setting worker thread counts. + /// + /// + /// # Examples + /// + /// ## Multi threaded runtime with 4 threads + /// + /// ``` + /// use tokio::runtime; + /// + /// // This will spawn a work-stealing runtime with 4 worker threads. + /// let rt = runtime::Builder::new_multi_thread() + /// .worker_threads(4) + /// .build() + /// .unwrap(); + /// + /// rt.spawn(async move {}); + /// ``` + /// + /// ## Current thread runtime (will only run on the current thread via `Runtime::block_on`) + /// + /// ``` + /// use tokio::runtime; + /// + /// // Create a runtime that _must_ be driven from a call + /// // to `Runtime::block_on`. + /// let rt = runtime::Builder::new_current_thread() + /// .build() + /// .unwrap(); + /// + /// // This will run the runtime and future on the current thread + /// rt.block_on(async move {}); + /// ``` + /// + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + pub fn worker_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Worker threads cannot be set to 0"); + self.worker_threads = Some(val); + self + } + + /// Specifies the limit for additional threads spawned by the Runtime. + /// + /// These threads are used for blocking operations like tasks spawned + /// through [`spawn_blocking`]. Unlike the [`worker_threads`], they are not + /// always active and will exit if left idle for too long. You can change + /// this timeout duration with [`thread_keep_alive`]. + /// + /// The default value is 512. + /// + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + /// + /// # Upgrading from 0.x + /// + /// In old versions `max_threads` limited both blocking and worker threads, but the + /// current `max_blocking_threads` does not include async worker threads in the count. + /// + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking + /// [`worker_threads`]: Self::worker_threads + /// [`thread_keep_alive`]: Self::thread_keep_alive + #[cfg_attr(docsrs, doc(alias = "max_threads"))] + pub fn max_blocking_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Max blocking threads cannot be set to 0"); + self.max_blocking_threads = val; + self + } + + /// Sets name of threads spawned by the `Runtime`'s thread pool. + /// + /// The default name is "tokio-runtime-worker". + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_name("my-pool") + /// .build(); + /// # } + /// ``` + pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self { + let val = val.into(); + self.thread_name = std::sync::Arc::new(move || val.clone()); + self + } + + /// Sets a function used to generate the name of threads spawned by the `Runtime`'s thread pool. + /// + /// The default name fn is `|| "tokio-runtime-worker".into()`. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// # use std::sync::atomic::{AtomicUsize, Ordering}; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_name_fn(|| { + /// static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + /// let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + /// format!("my-pool-{}", id) + /// }) + /// .build(); + /// # } + /// ``` + pub fn thread_name_fn<F>(&mut self, f: F) -> &mut Self + where + F: Fn() -> String + Send + Sync + 'static, + { + self.thread_name = std::sync::Arc::new(f); + self + } + + /// Sets the stack size (in bytes) for worker threads. + /// + /// The actual stack size may be greater than this value if the platform + /// specifies minimal stack size. + /// + /// The default stack size for spawned threads is 2 MiB, though this + /// particular stack size is subject to change in the future. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_stack_size(32 * 1024) + /// .build(); + /// # } + /// ``` + pub fn thread_stack_size(&mut self, val: usize) -> &mut Self { + self.thread_stack_size = Some(val); + self + } + + /// Executes function `f` after each thread is started but before it starts + /// doing work. + /// + /// This is intended for bookkeeping and monitoring use cases. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_start(|| { + /// println!("thread started"); + /// }) + /// .build(); + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_start<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.after_start = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` before each thread stops. + /// + /// This is intended for bookkeeping and monitoring use cases. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_stop(|| { + /// println!("thread stopping"); + /// }) + /// .build(); + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_stop<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.before_stop = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` just before a thread is parked (goes idle). + /// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn) + /// can be called, and may result in this thread being unparked immediately. + /// + /// This can be used to start work only when the executor is idle, or for bookkeeping + /// and monitoring purposes. + /// + /// Note: There can only be one park callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ## Multithreaded executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_multi_thread() + /// .worker_threads(1) + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + /// ## Current thread executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_current_thread() + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.before_park = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` just after a thread unparks (starts executing tasks). + /// + /// This is intended for bookkeeping and monitoring use cases; note that work + /// in this callback will increase latencies when the application has allowed one or + /// more runtime threads to go idle. + /// + /// Note: There can only be one unpark callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_unpark(|| { + /// println!("thread unparking"); + /// }) + /// .build(); + /// + /// runtime.unwrap().block_on(async { + /// tokio::task::yield_now().await; + /// println!("Hello from Tokio!"); + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.after_unpark = Some(std::sync::Arc::new(f)); + self + } + + /// Creates the configured `Runtime`. + /// + /// The returned `Runtime` instance is ready to spawn tasks. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Builder; + /// + /// let rt = Builder::new_multi_thread().build().unwrap(); + /// + /// rt.block_on(async { + /// println!("Hello from the Tokio runtime"); + /// }); + /// ``` + pub fn build(&mut self) -> io::Result<Runtime> { + match &self.kind { + Kind::CurrentThread => self.build_basic_runtime(), + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => self.build_threaded_runtime(), + } + } + + fn get_cfg(&self) -> driver::Cfg { + driver::Cfg { + enable_pause_time: match self.kind { + Kind::CurrentThread => true, + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => false, + }, + enable_io: self.enable_io, + enable_time: self.enable_time, + start_paused: self.start_paused, + } + } + + /// Sets a custom timeout for a thread in the blocking pool. + /// + /// By default, the timeout for a thread is set to 10 seconds. This can + /// be overridden using .thread_keep_alive(). + /// + /// # Example + /// + /// ``` + /// # use tokio::runtime; + /// # use std::time::Duration; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_keep_alive(Duration::from_millis(100)) + /// .build(); + /// # } + /// ``` + pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self { + self.keep_alive = Some(duration); + self + } + + fn build_basic_runtime(&mut self) -> io::Result<Runtime> { + use crate::runtime::{BasicScheduler, Kind}; + + let (driver, resources) = driver::Driver::new(self.get_cfg())?; + + // And now put a single-threaded scheduler on top of the timer. When + // there are no futures ready to do something, it'll let the timer or + // the reactor to generate some new stimuli for the futures to continue + // in their life. + let scheduler = + BasicScheduler::new(driver, self.before_park.clone(), self.after_unpark.clone()); + let spawner = Spawner::Basic(scheduler.spawner().clone()); + + // Blocking pool + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); + let blocking_spawner = blocking_pool.spawner().clone(); + + Ok(Runtime { + kind: Kind::CurrentThread(scheduler), + handle: Handle { + spawner, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, + blocking_spawner, + }, + blocking_pool, + }) + } +} + +cfg_io_driver! { + impl Builder { + /// Enables the I/O driver. + /// + /// Doing this enables using net, process, signal, and some I/O types on + /// the runtime. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_io() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_io(&mut self) -> &mut Self { + self.enable_io = true; + self + } + } +} + +cfg_time! { + impl Builder { + /// Enables the time driver. + /// + /// Doing this enables using `tokio::time` on the runtime. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_time() + /// .build() + /// .unwrap(); + /// ``` + pub fn enable_time(&mut self) -> &mut Self { + self.enable_time = true; + self + } + } +} + +cfg_test_util! { + impl Builder { + /// Controls if the runtime's clock starts paused or advancing. + /// + /// Pausing time requires the current-thread runtime; construction of + /// the runtime will panic otherwise. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_current_thread() + /// .enable_time() + /// .start_paused(true) + /// .build() + /// .unwrap(); + /// ``` + pub fn start_paused(&mut self, start_paused: bool) -> &mut Self { + self.start_paused = start_paused; + self + } + } +} + +cfg_rt_multi_thread! { + impl Builder { + fn build_threaded_runtime(&mut self) -> io::Result<Runtime> { + use crate::loom::sys::num_cpus; + use crate::runtime::{Kind, ThreadPool}; + use crate::runtime::park::Parker; + + let core_threads = self.worker_threads.unwrap_or_else(num_cpus); + + let (driver, resources) = driver::Driver::new(self.get_cfg())?; + + let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver), self.before_park.clone(), self.after_unpark.clone()); + let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); + + // Create the blocking pool + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads + core_threads); + let blocking_spawner = blocking_pool.spawner().clone(); + + // Create the runtime handle + let handle = Handle { + spawner, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, + blocking_spawner, + }; + + // Spawn the thread pool workers + let _enter = crate::runtime::context::enter(handle.clone()); + launch.launch(); + + Ok(Runtime { + kind: Kind::ThreadPool(scheduler), + handle, + blocking_pool, + }) + } + } +} + +impl fmt::Debug for Builder { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Builder") + .field("worker_threads", &self.worker_threads) + .field("max_blocking_threads", &self.max_blocking_threads) + .field( + "thread_name", + &"<dyn Fn() -> String + Send + Sync + 'static>", + ) + .field("thread_stack_size", &self.thread_stack_size) + .field("after_start", &self.after_start.as_ref().map(|_| "...")) + .field("before_stop", &self.before_stop.as_ref().map(|_| "...")) + .field("before_park", &self.before_park.as_ref().map(|_| "...")) + .field("after_unpark", &self.after_unpark.as_ref().map(|_| "...")) + .finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/context.rs b/third_party/rust/tokio/src/runtime/context.rs new file mode 100644 index 0000000000..1f44a53402 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/context.rs @@ -0,0 +1,111 @@ +//! Thread local runtime context +use crate::runtime::{Handle, TryCurrentError}; + +use std::cell::RefCell; + +thread_local! { + static CONTEXT: RefCell<Option<Handle>> = RefCell::new(None) +} + +pub(crate) fn try_current() -> Result<Handle, crate::runtime::TryCurrentError> { + match CONTEXT.try_with(|ctx| ctx.borrow().clone()) { + Ok(Some(handle)) => Ok(handle), + Ok(None) => Err(TryCurrentError::new_no_context()), + Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()), + } +} + +pub(crate) fn current() -> Handle { + match try_current() { + Ok(handle) => handle, + Err(e) => panic!("{}", e), + } +} + +cfg_io_driver! { + pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).io_handle.clone() + }) { + Ok(io_handle) => io_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +cfg_signal_internal! { + #[cfg(unix)] + pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).signal_handle.clone() + }) { + Ok(signal_handle) => signal_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +cfg_time! { + pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle { + match CONTEXT.try_with(|ctx| { + let ctx = ctx.borrow(); + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).time_handle.clone() + }) { + Ok(time_handle) => time_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } + + cfg_test_util! { + pub(crate) fn clock() -> Option<crate::runtime::driver::Clock> { + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.clock.clone())) { + Ok(clock) => clock, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } + } +} + +cfg_rt! { + pub(crate) fn spawn_handle() -> Option<crate::runtime::Spawner> { + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.spawner.clone())) { + Ok(spawner) => spawner, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +/// Sets this [`Handle`] as the current active [`Handle`]. +/// +/// [`Handle`]: Handle +pub(crate) fn enter(new: Handle) -> EnterGuard { + match try_enter(new) { + Some(guard) => guard, + None => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } +} + +/// Sets this [`Handle`] as the current active [`Handle`]. +/// +/// [`Handle`]: Handle +pub(crate) fn try_enter(new: Handle) -> Option<EnterGuard> { + CONTEXT + .try_with(|ctx| { + let old = ctx.borrow_mut().replace(new); + EnterGuard(old) + }) + .ok() +} + +#[derive(Debug)] +pub(crate) struct EnterGuard(Option<Handle>); + +impl Drop for EnterGuard { + fn drop(&mut self) { + CONTEXT.with(|ctx| { + *ctx.borrow_mut() = self.0.take(); + }); + } +} diff --git a/third_party/rust/tokio/src/runtime/driver.rs b/third_party/rust/tokio/src/runtime/driver.rs new file mode 100644 index 0000000000..7e459779bb --- /dev/null +++ b/third_party/rust/tokio/src/runtime/driver.rs @@ -0,0 +1,208 @@ +//! Abstracts out the entire chain of runtime sub-drivers into common types. +use crate::park::thread::ParkThread; +use crate::park::Park; + +use std::io; +use std::time::Duration; + +// ===== io driver ===== + +cfg_io_driver! { + type IoDriver = crate::io::driver::Driver; + type IoStack = crate::park::either::Either<ProcessDriver, ParkThread>; + pub(crate) type IoHandle = Option<crate::io::driver::Handle>; + + fn create_io_stack(enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + use crate::park::either::Either; + + #[cfg(loom)] + assert!(!enabled); + + let ret = if enabled { + let io_driver = crate::io::driver::Driver::new()?; + let io_handle = io_driver.handle(); + + let (signal_driver, signal_handle) = create_signal_driver(io_driver)?; + let process_driver = create_process_driver(signal_driver); + + (Either::A(process_driver), Some(io_handle), signal_handle) + } else { + (Either::B(ParkThread::new()), Default::default(), Default::default()) + }; + + Ok(ret) + } +} + +cfg_not_io_driver! { + pub(crate) type IoHandle = (); + type IoStack = ParkThread; + + fn create_io_stack(_enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + Ok((ParkThread::new(), Default::default(), Default::default())) + } +} + +// ===== signal driver ===== + +macro_rules! cfg_signal_internal_and_unix { + ($($item:item)*) => { + #[cfg(unix)] + cfg_signal_internal! { $($item)* } + } +} + +cfg_signal_internal_and_unix! { + type SignalDriver = crate::signal::unix::driver::Driver; + pub(crate) type SignalHandle = Option<crate::signal::unix::driver::Handle>; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + let driver = crate::signal::unix::driver::Driver::new(io_driver)?; + let handle = driver.handle(); + Ok((driver, Some(handle))) + } +} + +cfg_not_signal_internal! { + pub(crate) type SignalHandle = (); + + cfg_io_driver! { + type SignalDriver = IoDriver; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + Ok((io_driver, ())) + } + } +} + +// ===== process driver ===== + +cfg_process_driver! { + type ProcessDriver = crate::process::unix::driver::Driver; + + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { + crate::process::unix::driver::Driver::new(signal_driver) + } +} + +cfg_not_process_driver! { + cfg_io_driver! { + type ProcessDriver = SignalDriver; + + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { + signal_driver + } + } +} + +// ===== time driver ===== + +cfg_time! { + type TimeDriver = crate::park::either::Either<crate::time::driver::Driver<IoStack>, IoStack>; + + pub(crate) type Clock = crate::time::Clock; + pub(crate) type TimeHandle = Option<crate::time::driver::Handle>; + + fn create_clock(enable_pausing: bool, start_paused: bool) -> Clock { + crate::time::Clock::new(enable_pausing, start_paused) + } + + fn create_time_driver( + enable: bool, + io_stack: IoStack, + clock: Clock, + ) -> (TimeDriver, TimeHandle) { + use crate::park::either::Either; + + if enable { + let driver = crate::time::driver::Driver::new(io_stack, clock); + let handle = driver.handle(); + + (Either::A(driver), Some(handle)) + } else { + (Either::B(io_stack), None) + } + } +} + +cfg_not_time! { + type TimeDriver = IoStack; + + pub(crate) type Clock = (); + pub(crate) type TimeHandle = (); + + fn create_clock(_enable_pausing: bool, _start_paused: bool) -> Clock { + () + } + + fn create_time_driver( + _enable: bool, + io_stack: IoStack, + _clock: Clock, + ) -> (TimeDriver, TimeHandle) { + (io_stack, ()) + } +} + +// ===== runtime driver ===== + +#[derive(Debug)] +pub(crate) struct Driver { + inner: TimeDriver, +} + +pub(crate) struct Resources { + pub(crate) io_handle: IoHandle, + pub(crate) signal_handle: SignalHandle, + pub(crate) time_handle: TimeHandle, + pub(crate) clock: Clock, +} + +pub(crate) struct Cfg { + pub(crate) enable_io: bool, + pub(crate) enable_time: bool, + pub(crate) enable_pause_time: bool, + pub(crate) start_paused: bool, +} + +impl Driver { + pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Resources)> { + let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io)?; + + let clock = create_clock(cfg.enable_pause_time, cfg.start_paused); + + let (time_driver, time_handle) = + create_time_driver(cfg.enable_time, io_stack, clock.clone()); + + Ok(( + Self { inner: time_driver }, + Resources { + io_handle, + signal_handle, + time_handle, + clock, + }, + )) + } +} + +impl Park for Driver { + type Unpark = <TimeDriver as Park>::Unpark; + type Error = <TimeDriver as Park>::Error; + + fn unpark(&self) -> Self::Unpark { + self.inner.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park() + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.inner.park_timeout(duration) + } + + fn shutdown(&mut self) { + self.inner.shutdown() + } +} diff --git a/third_party/rust/tokio/src/runtime/enter.rs b/third_party/rust/tokio/src/runtime/enter.rs new file mode 100644 index 0000000000..3f14cb5878 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/enter.rs @@ -0,0 +1,205 @@ +use std::cell::{Cell, RefCell}; +use std::fmt; +use std::marker::PhantomData; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum EnterContext { + #[cfg_attr(not(feature = "rt"), allow(dead_code))] + Entered { + allow_blocking: bool, + }, + NotEntered, +} + +impl EnterContext { + pub(crate) fn is_entered(self) -> bool { + matches!(self, EnterContext::Entered { .. }) + } +} + +thread_local!(static ENTERED: Cell<EnterContext> = Cell::new(EnterContext::NotEntered)); + +/// Represents an executor context. +pub(crate) struct Enter { + _p: PhantomData<RefCell<()>>, +} + +cfg_rt! { + use crate::park::thread::ParkError; + + use std::time::Duration; + + /// Marks the current thread as being within the dynamic extent of an + /// executor. + pub(crate) fn enter(allow_blocking: bool) -> Enter { + if let Some(enter) = try_enter(allow_blocking) { + return enter; + } + + panic!( + "Cannot start a runtime from within a runtime. This happens \ + because a function (like `block_on`) attempted to block the \ + current thread while the thread is being used to drive \ + asynchronous tasks." + ); + } + + /// Tries to enter a runtime context, returns `None` if already in a runtime + /// context. + pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + c.set(EnterContext::Entered { allow_blocking }); + Some(Enter { _p: PhantomData }) + } + }) + } +} + +// Forces the current "entered" state to be cleared while the closure +// is executed. +// +// # Warning +// +// This is hidden for a reason. Do not use without fully understanding +// executors. Misusing can easily cause your program to deadlock. +cfg_rt_multi_thread! { + pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { + // Reset in case the closure panics + struct Reset(EnterContext); + impl Drop for Reset { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); + }); + } + } + + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterContext::NotEntered); + e + }); + + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED + f() + } +} + +cfg_rt! { + /// Disallows blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { + let reset = ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: true, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: false, + }); + true + } else { + false + } + }); + DisallowBlockingGuard(reset) + } + + pub(crate) struct DisallowBlockingGuard(bool); + impl Drop for DisallowBlockingGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: false, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: true, + }); + } + }) + } + } + } +} + +cfg_rt_multi_thread! { + /// Returns true if in a runtime context. + pub(crate) fn context() -> EnterContext { + ENTERED.with(|c| c.get()) + } +} + +cfg_rt! { + impl Enter { + /// Blocks the thread on the specified future, returning the value with + /// which that future completes. + pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError> + where + F: std::future::Future, + { + use crate::park::thread::CachedParkThread; + + let mut park = CachedParkThread::new(); + park.block_on(f) + } + + /// Blocks the thread on the specified future for **at most** `timeout` + /// + /// If the future completes before `timeout`, the result is returned. If + /// `timeout` elapses, then `Err` is returned. + pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ParkError> + where + F: std::future::Future, + { + use crate::park::Park; + use crate::park::thread::CachedParkThread; + use std::task::Context; + use std::task::Poll::Ready; + use std::time::Instant; + + let mut park = CachedParkThread::new(); + let waker = park.get_unpark()?.into_waker(); + let mut cx = Context::from_waker(&waker); + + pin!(f); + let when = Instant::now() + timeout; + + loop { + if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { + return Ok(v); + } + + let now = Instant::now(); + + if now >= when { + return Err(()); + } + + park.park_timeout(when - now)?; + } + } + } +} + +impl fmt::Debug for Enter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Enter").finish() + } +} + +impl Drop for Enter { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(c.get().is_entered()); + c.set(EnterContext::NotEntered); + }); + } +} diff --git a/third_party/rust/tokio/src/runtime/handle.rs b/third_party/rust/tokio/src/runtime/handle.rs new file mode 100644 index 0000000000..9dbe6774dd --- /dev/null +++ b/third_party/rust/tokio/src/runtime/handle.rs @@ -0,0 +1,435 @@ +use crate::runtime::blocking::{BlockingTask, NoopSchedule}; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{blocking, context, driver, Spawner}; +use crate::util::error::{CONTEXT_MISSING_ERROR, THREAD_LOCAL_DESTROYED_ERROR}; + +use std::future::Future; +use std::marker::PhantomData; +use std::{error, fmt}; + +/// Handle to the runtime. +/// +/// The handle is internally reference-counted and can be freely cloned. A handle can be +/// obtained using the [`Runtime::handle`] method. +/// +/// [`Runtime::handle`]: crate::runtime::Runtime::handle() +#[derive(Debug, Clone)] +pub struct Handle { + pub(super) spawner: Spawner, + + /// Handles to the I/O drivers + #[cfg_attr( + not(any(feature = "net", feature = "process", all(unix, feature = "signal"))), + allow(dead_code) + )] + pub(super) io_handle: driver::IoHandle, + + /// Handles to the signal drivers + #[cfg_attr( + any( + loom, + not(all(unix, feature = "signal")), + not(all(unix, feature = "process")), + ), + allow(dead_code) + )] + pub(super) signal_handle: driver::SignalHandle, + + /// Handles to the time drivers + #[cfg_attr(not(feature = "time"), allow(dead_code))] + pub(super) time_handle: driver::TimeHandle, + + /// Source of `Instant::now()` + #[cfg_attr(not(all(feature = "time", feature = "test-util")), allow(dead_code))] + pub(super) clock: driver::Clock, + + /// Blocking pool spawner + pub(super) blocking_spawner: blocking::Spawner, +} + +/// Runtime context guard. +/// +/// Returned by [`Runtime::enter`] and [`Handle::enter`], the context guard exits +/// the runtime context on drop. +/// +/// [`Runtime::enter`]: fn@crate::runtime::Runtime::enter +#[derive(Debug)] +#[must_use = "Creating and dropping a guard does nothing"] +pub struct EnterGuard<'a> { + _guard: context::EnterGuard, + _handle_lifetime: PhantomData<&'a Handle>, +} + +impl Handle { + /// Enters the runtime context. This allows you to construct types that must + /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. + /// It will also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + pub fn enter(&self) -> EnterGuard<'_> { + EnterGuard { + _guard: context::enter(self.clone()), + _handle_lifetime: PhantomData, + } + } + + /// Returns a `Handle` view over the currently running `Runtime`. + /// + /// # Panic + /// + /// This will panic if called outside the context of a Tokio runtime. That means that you must + /// call this on one of the threads **being run by the runtime**. Calling this from within a + /// thread created by `std::thread::spawn` (for example) will cause a panic. + /// + /// # Examples + /// + /// This can be used to obtain the handle of the surrounding runtime from an async + /// block or function running on that runtime. + /// + /// ``` + /// # use std::thread; + /// # use tokio::runtime::Runtime; + /// # fn dox() { + /// # let rt = Runtime::new().unwrap(); + /// # rt.spawn(async { + /// use tokio::runtime::Handle; + /// + /// // Inside an async block or function. + /// let handle = Handle::current(); + /// handle.spawn(async { + /// println!("now running in the existing Runtime"); + /// }); + /// + /// # let handle = + /// thread::spawn(move || { + /// // Notice that the handle is created outside of this thread and then moved in + /// handle.spawn(async { /* ... */ }) + /// // This next line would cause a panic + /// // let handle2 = Handle::current(); + /// }); + /// # handle.join().unwrap(); + /// # }); + /// # } + /// ``` + pub fn current() -> Self { + context::current() + } + + /// Returns a Handle view over the currently running Runtime + /// + /// Returns an error if no Runtime has been started + /// + /// Contrary to `current`, this never panics + pub fn try_current() -> Result<Self, TryCurrentError> { + context::try_current() + } + + /// Spawns a future onto the Tokio runtime. + /// + /// This spawns the given future onto the runtime's executor, usually a + /// thread pool. The thread pool is then responsible for polling the future + /// until it completes. + /// + /// See [module level][mod] documentation for more details. + /// + /// [mod]: index.html + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Spawn a future onto the runtime using the handle + /// handle.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "task", None); + self.spawner.spawn(future) + } + + /// Runs the provided function on an executor dedicated to blocking. + /// operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Spawn a blocking function onto the runtime using the handle + /// handle.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + #[track_caller] + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, _was_spawned) = + if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { + self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None) + } else { + self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None) + }; + + join_handle + } + + cfg_fs! { + #[track_caller] + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + pub(crate) fn spawn_mandatory_blocking<F, R>(&self, func: F) -> Option<JoinHandle<R>> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { + self.spawn_blocking_inner( + Box::new(func), + blocking::Mandatory::Mandatory, + None + ) + } else { + self.spawn_blocking_inner( + func, + blocking::Mandatory::Mandatory, + None + ) + }; + + if was_spawned { + Some(join_handle) + } else { + None + } + } + } + + #[track_caller] + pub(crate) fn spawn_blocking_inner<F, R>( + &self, + func: F, + is_mandatory: blocking::Mandatory, + name: Option<&str>, + ) -> (JoinHandle<R>, bool) + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let fut = BlockingTask::new(func); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let fut = { + use tracing::Instrument; + let location = std::panic::Location::caller(); + let span = tracing::trace_span!( + target: "tokio::task::blocking", + "runtime.spawn", + kind = %"blocking", + task.name = %name.unwrap_or_default(), + "fn" = %std::any::type_name::<F>(), + spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), + ); + fut.instrument(span) + }; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let _ = name; + + let (task, handle) = task::unowned(fut, NoopSchedule); + let spawned = self + .blocking_spawner + .spawn(blocking::Task::new(task, is_mandatory), self); + (handle, spawned.is_ok()) + } + + /// Runs a future to completion on this `Handle`'s associated `Runtime`. + /// + /// This runs the given future on the current thread, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers which + /// the future spawns internally will be executed on the runtime. + /// + /// When this is used on a `current_thread` runtime, only the + /// [`Runtime::block_on`] method can drive the IO and timer drivers, but the + /// `Handle::block_on` method cannot drive them. This means that, when using + /// this method on a current_thread runtime, anything that relies on IO or + /// timers will not work unless there is another thread currently calling + /// [`Runtime::block_on`] on the same runtime. + /// + /// # If the runtime has been shut down + /// + /// If the `Handle`'s associated `Runtime` has been shut down (through + /// [`Runtime::shutdown_background`], [`Runtime::shutdown_timeout`], or by + /// dropping it) and `Handle::block_on` is used it might return an error or + /// panic. Specifically IO resources will return an error and timers will + /// panic. Runtime independent futures will run as normal. + /// + /// # Panics + /// + /// This function panics if the provided future panics, if called within an + /// asynchronous execution context, or if a timer future is executed on a + /// runtime that has been shut down. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Get a handle from this runtime + /// let handle = rt.handle(); + /// + /// // Execute the future, blocking the current thread until completion + /// handle.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// Or using `Handle::current`: + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main () { + /// let handle = Handle::current(); + /// std::thread::spawn(move || { + /// // Using Handle::block_on to run async code in the new thread. + /// handle.block_on(async { + /// println!("hello"); + /// }); + /// }); + /// } + /// ``` + /// + /// [`JoinError`]: struct@crate::task::JoinError + /// [`JoinHandle`]: struct@crate::task::JoinHandle + /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on + /// [`Runtime::shutdown_background`]: fn@crate::runtime::Runtime::shutdown_background + /// [`Runtime::shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// [`spawn_blocking`]: crate::task::spawn_blocking + /// [`tokio::fs`]: crate::fs + /// [`tokio::net`]: crate::net + /// [`tokio::time`]: crate::time + #[track_caller] + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + + // Enter the **runtime** context. This configures spawning, the current I/O driver, ... + let _rt_enter = self.enter(); + + // Enter a **blocking** context. This prevents blocking from a runtime. + let mut blocking_enter = crate::runtime::enter(true); + + // Block on the future + blocking_enter + .block_on(future) + .expect("failed to park thread") + } + + pub(crate) fn shutdown(mut self) { + self.spawner.shutdown(); + } +} + +cfg_metrics! { + use crate::runtime::RuntimeMetrics; + + impl Handle { + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn metrics(&self) -> RuntimeMetrics { + RuntimeMetrics::new(self.clone()) + } + } +} + +/// Error returned by `try_current` when no Runtime has been started +#[derive(Debug)] +pub struct TryCurrentError { + kind: TryCurrentErrorKind, +} + +impl TryCurrentError { + pub(crate) fn new_no_context() -> Self { + Self { + kind: TryCurrentErrorKind::NoContext, + } + } + + pub(crate) fn new_thread_local_destroyed() -> Self { + Self { + kind: TryCurrentErrorKind::ThreadLocalDestroyed, + } + } + + /// Returns true if the call failed because there is currently no runtime in + /// the Tokio context. + pub fn is_missing_context(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::NoContext) + } + + /// Returns true if the call failed because the Tokio context thread-local + /// had been destroyed. This can usually only happen if in the destructor of + /// other thread-locals. + pub fn is_thread_local_destroyed(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::ThreadLocalDestroyed) + } +} + +enum TryCurrentErrorKind { + NoContext, + ThreadLocalDestroyed, +} + +impl fmt::Debug for TryCurrentErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TryCurrentErrorKind::*; + match self { + NoContext => f.write_str("NoContext"), + ThreadLocalDestroyed => f.write_str("ThreadLocalDestroyed"), + } + } +} + +impl fmt::Display for TryCurrentError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TryCurrentErrorKind::*; + match self.kind { + NoContext => f.write_str(CONTEXT_MISSING_ERROR), + ThreadLocalDestroyed => f.write_str(THREAD_LOCAL_DESTROYED_ERROR), + } + } +} + +impl error::Error for TryCurrentError {} diff --git a/third_party/rust/tokio/src/runtime/metrics/batch.rs b/third_party/rust/tokio/src/runtime/metrics/batch.rs new file mode 100644 index 0000000000..f1c3fa6b74 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/batch.rs @@ -0,0 +1,105 @@ +use crate::runtime::WorkerMetrics; + +use std::convert::TryFrom; +use std::sync::atomic::Ordering::Relaxed; +use std::time::Instant; + +pub(crate) struct MetricsBatch { + /// Number of times the worker parked. + park_count: u64, + + /// Number of times the worker woke w/o doing work. + noop_count: u64, + + /// Number of times stolen. + steal_count: u64, + + /// Number of tasks that were polled by the worker. + poll_count: u64, + + /// Number of tasks polled when the worker entered park. This is used to + /// track the noop count. + poll_count_on_last_park: u64, + + /// Number of tasks that were scheduled locally on this worker. + local_schedule_count: u64, + + /// Number of tasks moved to the global queue to make space in the local + /// queue + overflow_count: u64, + + /// The total busy duration in nanoseconds. + busy_duration_total: u64, + last_resume_time: Instant, +} + +impl MetricsBatch { + pub(crate) fn new() -> MetricsBatch { + MetricsBatch { + park_count: 0, + noop_count: 0, + steal_count: 0, + poll_count: 0, + poll_count_on_last_park: 0, + local_schedule_count: 0, + overflow_count: 0, + busy_duration_total: 0, + last_resume_time: Instant::now(), + } + } + + pub(crate) fn submit(&mut self, worker: &WorkerMetrics) { + worker.park_count.store(self.park_count, Relaxed); + worker.noop_count.store(self.noop_count, Relaxed); + worker.steal_count.store(self.steal_count, Relaxed); + worker.poll_count.store(self.poll_count, Relaxed); + + worker + .busy_duration_total + .store(self.busy_duration_total, Relaxed); + + worker + .local_schedule_count + .store(self.local_schedule_count, Relaxed); + worker.overflow_count.store(self.overflow_count, Relaxed); + } + + /// The worker is about to park. + pub(crate) fn about_to_park(&mut self) { + self.park_count += 1; + + if self.poll_count_on_last_park == self.poll_count { + self.noop_count += 1; + } else { + self.poll_count_on_last_park = self.poll_count; + } + + let busy_duration = self.last_resume_time.elapsed(); + let busy_duration = u64::try_from(busy_duration.as_nanos()).unwrap_or(u64::MAX); + self.busy_duration_total += busy_duration; + } + + pub(crate) fn returned_from_park(&mut self) { + self.last_resume_time = Instant::now(); + } + + pub(crate) fn inc_local_schedule_count(&mut self) { + self.local_schedule_count += 1; + } + + pub(crate) fn incr_poll_count(&mut self) { + self.poll_count += 1; + } +} + +cfg_rt_multi_thread! { + impl MetricsBatch { + pub(crate) fn incr_steal_count(&mut self, by: u16) { + self.steal_count += by as u64; + } + + pub(crate) fn incr_overflow_count(&mut self) { + self.overflow_count += 1; + } + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/mock.rs b/third_party/rust/tokio/src/runtime/metrics/mock.rs new file mode 100644 index 0000000000..6b9cf704f4 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/mock.rs @@ -0,0 +1,43 @@ +//! This file contains mocks of the types in src/runtime/metrics + +pub(crate) struct SchedulerMetrics {} + +pub(crate) struct WorkerMetrics {} + +pub(crate) struct MetricsBatch {} + +impl SchedulerMetrics { + pub(crate) fn new() -> Self { + Self {} + } + + /// Increment the number of tasks scheduled externally + pub(crate) fn inc_remote_schedule_count(&self) {} +} + +impl WorkerMetrics { + pub(crate) fn new() -> Self { + Self {} + } + + pub(crate) fn set_queue_depth(&self, _len: usize) {} +} + +impl MetricsBatch { + pub(crate) fn new() -> Self { + Self {} + } + + pub(crate) fn submit(&mut self, _to: &WorkerMetrics) {} + pub(crate) fn about_to_park(&mut self) {} + pub(crate) fn returned_from_park(&mut self) {} + pub(crate) fn incr_poll_count(&mut self) {} + pub(crate) fn inc_local_schedule_count(&mut self) {} +} + +cfg_rt_multi_thread! { + impl MetricsBatch { + pub(crate) fn incr_steal_count(&mut self, _by: u16) {} + pub(crate) fn incr_overflow_count(&mut self) {} + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/mod.rs b/third_party/rust/tokio/src/runtime/metrics/mod.rs new file mode 100644 index 0000000000..ca643a5904 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/mod.rs @@ -0,0 +1,30 @@ +//! This module contains information need to view information about how the +//! runtime is performing. +//! +//! **Note**: This is an [unstable API][unstable]. The public API of types in +//! this module may break in 1.x releases. See [the documentation on unstable +//! features][unstable] for details. +//! +//! [unstable]: crate#unstable-features +#![allow(clippy::module_inception)] + +cfg_metrics! { + mod batch; + pub(crate) use batch::MetricsBatch; + + mod runtime; + #[allow(unreachable_pub)] // rust-lang/rust#57411 + pub use runtime::RuntimeMetrics; + + mod scheduler; + pub(crate) use scheduler::SchedulerMetrics; + + mod worker; + pub(crate) use worker::WorkerMetrics; +} + +cfg_not_metrics! { + mod mock; + + pub(crate) use mock::{SchedulerMetrics, WorkerMetrics, MetricsBatch}; +} diff --git a/third_party/rust/tokio/src/runtime/metrics/runtime.rs b/third_party/rust/tokio/src/runtime/metrics/runtime.rs new file mode 100644 index 0000000000..0f8055907f --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/runtime.rs @@ -0,0 +1,449 @@ +use crate::runtime::Handle; + +use std::sync::atomic::Ordering::Relaxed; +use std::time::Duration; + +/// Handle to the runtime's metrics. +/// +/// This handle is internally reference-counted and can be freely cloned. A +/// `RuntimeMetrics` handle is obtained using the [`Runtime::metrics`] method. +/// +/// [`Runtime::metrics`]: crate::runtime::Runtime::metrics() +#[derive(Clone, Debug)] +pub struct RuntimeMetrics { + handle: Handle, +} + +impl RuntimeMetrics { + pub(crate) fn new(handle: Handle) -> RuntimeMetrics { + RuntimeMetrics { handle } + } + + /// Returns the number of worker threads used by the runtime. + /// + /// The number of workers is set by configuring `worker_threads` on + /// `runtime::Builder`. When using the `current_thread` runtime, the return + /// value is always `1`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.num_workers(); + /// println!("Runtime is using {} workers", n); + /// } + /// ``` + pub fn num_workers(&self) -> usize { + self.handle.spawner.num_workers() + } + + /// Returns the number of tasks scheduled from **outside** of the runtime. + /// + /// The remote schedule count starts at zero when the runtime is created and + /// increases by one each time a task is woken from **outside** of the + /// runtime. This usually means that a task is spawned or notified from a + /// non-runtime thread and must be queued using the Runtime's injection + /// queue, which tends to be slower. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.remote_schedule_count(); + /// println!("{} tasks were scheduled from outside the runtime", n); + /// } + /// ``` + pub fn remote_schedule_count(&self) -> u64 { + self.handle + .spawner + .scheduler_metrics() + .remote_schedule_count + .load(Relaxed) + } + + /// Returns the total number of times the given worker thread has parked. + /// + /// The worker park count starts at zero when the runtime is created and + /// increases by one each time the worker parks the thread waiting for new + /// inbound events to process. This usually means the worker has processed + /// all pending work and is currently idle. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_park_count(0); + /// println!("worker 0 parked {} times", n); + /// } + /// ``` + pub fn worker_park_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .park_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread unparked but + /// performed no work before parking again. + /// + /// The worker no-op count starts at zero when the runtime is created and + /// increases by one each time the worker unparks the thread but finds no + /// new work and goes back to sleep. This indicates a false-positive wake up. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_noop_count(0); + /// println!("worker 0 had {} no-op unparks", n); + /// } + /// ``` + pub fn worker_noop_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .noop_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread stole tasks from + /// another worker thread. + /// + /// This metric only applies to the **multi-threaded** runtime and will always return `0` when using the current thread runtime. + /// + /// The worker steal count starts at zero when the runtime is created and + /// increases by one each time the worker has processed its scheduled queue + /// and successfully steals more pending tasks from another worker. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_noop_count(0); + /// println!("worker 0 has stolen tasks {} times", n); + /// } + /// ``` + pub fn worker_steal_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .steal_count + .load(Relaxed) + } + + /// Returns the number of tasks the given worker thread has polled. + /// + /// The worker poll count starts at zero when the runtime is created and + /// increases by one each time the worker polls a scheduled task. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_poll_count(0); + /// println!("worker 0 has polled {} tasks", n); + /// } + /// ``` + pub fn worker_poll_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .poll_count + .load(Relaxed) + } + + /// Returns the amount of time the given worker thread has been busy. + /// + /// The worker busy duration starts at zero when the runtime is created and + /// increases whenever the worker is spending time processing work. Using + /// this value can indicate the load of the given worker. If a lot of time + /// is spent busy, then the worker is under load and will check for inbound + /// events less often. + /// + /// The timer is monotonically increasing. It is never decremented or reset + /// to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_poll_count(0); + /// println!("worker 0 has polled {} tasks", n); + /// } + /// ``` + pub fn worker_total_busy_duration(&self, worker: usize) -> Duration { + let nanos = self + .handle + .spawner + .worker_metrics(worker) + .busy_duration_total + .load(Relaxed); + Duration::from_nanos(nanos) + } + + /// Returns the number of tasks scheduled from **within** the runtime on the + /// given worker's local queue. + /// + /// The local schedule count starts at zero when the runtime is created and + /// increases by one each time a task is woken from **inside** of the + /// runtime on the given worker. This usually means that a task is spawned + /// or notified from within a runtime thread and will be queued on the + /// worker-local queue. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_local_schedule_count(0); + /// println!("{} tasks were scheduled on the worker's local queue", n); + /// } + /// ``` + pub fn worker_local_schedule_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .local_schedule_count + .load(Relaxed) + } + + /// Returns the number of times the given worker thread saturated its local + /// queue. + /// + /// This metric only applies to the **multi-threaded** scheduler. + /// + /// The worker steal count starts at zero when the runtime is created and + /// increases by one each time the worker attempts to schedule a task + /// locally, but its local queue is full. When this happens, half of the + /// local queue is moved to the injection queue. + /// + /// The counter is monotonically increasing. It is never decremented or + /// reset to zero. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_overflow_count(0); + /// println!("worker 0 has overflowed its queue {} times", n); + /// } + /// ``` + pub fn worker_overflow_count(&self, worker: usize) -> u64 { + self.handle + .spawner + .worker_metrics(worker) + .overflow_count + .load(Relaxed) + } + + /// Returns the number of tasks currently scheduled in the runtime's + /// injection queue. + /// + /// Tasks that are spanwed or notified from a non-runtime thread are + /// scheduled using the runtime's injection queue. This metric returns the + /// **current** number of tasks pending in the injection queue. As such, the + /// returned value may increase or decrease as new tasks are scheduled and + /// processed. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.injection_queue_depth(); + /// println!("{} tasks currently pending in the runtime's injection queue", n); + /// } + /// ``` + pub fn injection_queue_depth(&self) -> usize { + self.handle.spawner.injection_queue_depth() + } + + /// Returns the number of tasks currently scheduled in the given worker's + /// local queue. + /// + /// Tasks that are spawned or notified from within a runtime thread are + /// scheduled using that worker's local queue. This metric returns the + /// **current** number of tasks pending in the worker's local queue. As + /// such, the returned value may increase or decrease as new tasks are + /// scheduled and processed. + /// + /// # Arguments + /// + /// `worker` is the index of the worker being queried. The given value must + /// be between 0 and `num_workers()`. The index uniquely identifies a single + /// worker and will continue to indentify the worker throughout the lifetime + /// of the runtime instance. + /// + /// # Panics + /// + /// The method panics when `worker` represents an invalid worker, i.e. is + /// greater than or equal to `num_workers()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.worker_local_queue_depth(0); + /// println!("{} tasks currently pending in worker 0's local queue", n); + /// } + /// ``` + pub fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.handle.spawner.worker_local_queue_depth(worker) + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/scheduler.rs b/third_party/rust/tokio/src/runtime/metrics/scheduler.rs new file mode 100644 index 0000000000..d1ba3b6442 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/scheduler.rs @@ -0,0 +1,27 @@ +use crate::loom::sync::atomic::{AtomicU64, Ordering::Relaxed}; + +/// Retrieves metrics from the Tokio runtime. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[derive(Debug)] +pub(crate) struct SchedulerMetrics { + /// Number of tasks that are scheduled from outside the runtime. + pub(super) remote_schedule_count: AtomicU64, +} + +impl SchedulerMetrics { + pub(crate) fn new() -> SchedulerMetrics { + SchedulerMetrics { + remote_schedule_count: AtomicU64::new(0), + } + } + + /// Increment the number of tasks scheduled externally + pub(crate) fn inc_remote_schedule_count(&self) { + self.remote_schedule_count.fetch_add(1, Relaxed); + } +} diff --git a/third_party/rust/tokio/src/runtime/metrics/worker.rs b/third_party/rust/tokio/src/runtime/metrics/worker.rs new file mode 100644 index 0000000000..c9b85e48e4 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/metrics/worker.rs @@ -0,0 +1,61 @@ +use crate::loom::sync::atomic::Ordering::Relaxed; +use crate::loom::sync::atomic::{AtomicU64, AtomicUsize}; + +/// Retreive runtime worker metrics. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[derive(Debug)] +#[repr(align(128))] +pub(crate) struct WorkerMetrics { + /// Number of times the worker parked. + pub(crate) park_count: AtomicU64, + + /// Number of times the worker woke then parked again without doing work. + pub(crate) noop_count: AtomicU64, + + /// Number of times the worker attempted to steal. + pub(crate) steal_count: AtomicU64, + + /// Number of tasks the worker polled. + pub(crate) poll_count: AtomicU64, + + /// Amount of time the worker spent doing work vs. parking. + pub(crate) busy_duration_total: AtomicU64, + + /// Number of tasks scheduled for execution on the worker's local queue. + pub(crate) local_schedule_count: AtomicU64, + + /// Number of tasks moved from the local queue to the global queue to free space. + pub(crate) overflow_count: AtomicU64, + + /// Number of tasks currently in the local queue. Used only by the + /// current-thread scheduler. + pub(crate) queue_depth: AtomicUsize, +} + +impl WorkerMetrics { + pub(crate) fn new() -> WorkerMetrics { + WorkerMetrics { + park_count: AtomicU64::new(0), + noop_count: AtomicU64::new(0), + steal_count: AtomicU64::new(0), + poll_count: AtomicU64::new(0), + overflow_count: AtomicU64::new(0), + busy_duration_total: AtomicU64::new(0), + local_schedule_count: AtomicU64::new(0), + queue_depth: AtomicUsize::new(0), + } + } + + pub(crate) fn queue_depth(&self) -> usize { + self.queue_depth.load(Relaxed) + } + + pub(crate) fn set_queue_depth(&self, len: usize) { + self.queue_depth.store(len, Relaxed); + } +} diff --git a/third_party/rust/tokio/src/runtime/mod.rs b/third_party/rust/tokio/src/runtime/mod.rs new file mode 100644 index 0000000000..7c381b0bbd --- /dev/null +++ b/third_party/rust/tokio/src/runtime/mod.rs @@ -0,0 +1,623 @@ +//! The Tokio runtime. +//! +//! Unlike other Rust programs, asynchronous applications require runtime +//! support. In particular, the following runtime services are necessary: +//! +//! * An **I/O event loop**, called the driver, which drives I/O resources and +//! dispatches I/O events to tasks that depend on them. +//! * A **scheduler** to execute [tasks] that use these I/O resources. +//! * A **timer** for scheduling work to run after a set period of time. +//! +//! Tokio's [`Runtime`] bundles all of these services as a single type, allowing +//! them to be started, shut down, and configured together. However, often it is +//! not required to configure a [`Runtime`] manually, and a user may just use the +//! [`tokio::main`] attribute macro, which creates a [`Runtime`] under the hood. +//! +//! # Usage +//! +//! When no fine tuning is required, the [`tokio::main`] attribute macro can be +//! used. +//! +//! ```no_run +//! use tokio::net::TcpListener; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; +//! +//! loop { +//! let (mut socket, _) = listener.accept().await?; +//! +//! tokio::spawn(async move { +//! let mut buf = [0; 1024]; +//! +//! // In a loop, read data from the socket and write the data back. +//! loop { +//! let n = match socket.read(&mut buf).await { +//! // socket closed +//! Ok(n) if n == 0 => return, +//! Ok(n) => n, +//! Err(e) => { +//! println!("failed to read from socket; err = {:?}", e); +//! return; +//! } +//! }; +//! +//! // Write the data back +//! if let Err(e) = socket.write_all(&buf[0..n]).await { +//! println!("failed to write to socket; err = {:?}", e); +//! return; +//! } +//! } +//! }); +//! } +//! } +//! ``` +//! +//! From within the context of the runtime, additional tasks are spawned using +//! the [`tokio::spawn`] function. Futures spawned using this function will be +//! executed on the same thread pool used by the [`Runtime`]. +//! +//! A [`Runtime`] instance can also be used directly. +//! +//! ```no_run +//! use tokio::net::TcpListener; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; +//! use tokio::runtime::Runtime; +//! +//! fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // Create the runtime +//! let rt = Runtime::new()?; +//! +//! // Spawn the root task +//! rt.block_on(async { +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; +//! +//! loop { +//! let (mut socket, _) = listener.accept().await?; +//! +//! tokio::spawn(async move { +//! let mut buf = [0; 1024]; +//! +//! // In a loop, read data from the socket and write the data back. +//! loop { +//! let n = match socket.read(&mut buf).await { +//! // socket closed +//! Ok(n) if n == 0 => return, +//! Ok(n) => n, +//! Err(e) => { +//! println!("failed to read from socket; err = {:?}", e); +//! return; +//! } +//! }; +//! +//! // Write the data back +//! if let Err(e) = socket.write_all(&buf[0..n]).await { +//! println!("failed to write to socket; err = {:?}", e); +//! return; +//! } +//! } +//! }); +//! } +//! }) +//! } +//! ``` +//! +//! ## Runtime Configurations +//! +//! Tokio provides multiple task scheduling strategies, suitable for different +//! applications. The [runtime builder] or `#[tokio::main]` attribute may be +//! used to select which scheduler to use. +//! +//! #### Multi-Thread Scheduler +//! +//! The multi-thread scheduler executes futures on a _thread pool_, using a +//! work-stealing strategy. By default, it will start a worker thread for each +//! CPU core available on the system. This tends to be the ideal configuration +//! for most applications. The multi-thread scheduler requires the `rt-multi-thread` +//! feature flag, and is selected by default: +//! ``` +//! use tokio::runtime; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let threaded_rt = runtime::Runtime::new()?; +//! # Ok(()) } +//! ``` +//! +//! Most applications should use the multi-thread scheduler, except in some +//! niche use-cases, such as when running only a single thread is required. +//! +//! #### Current-Thread Scheduler +//! +//! The current-thread scheduler provides a _single-threaded_ future executor. +//! All tasks will be created and executed on the current thread. This requires +//! the `rt` feature flag. +//! ``` +//! use tokio::runtime; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let basic_rt = runtime::Builder::new_current_thread() +//! .build()?; +//! # Ok(()) } +//! ``` +//! +//! #### Resource drivers +//! +//! When configuring a runtime by hand, no resource drivers are enabled by +//! default. In this case, attempting to use networking types or time types will +//! fail. In order to enable these types, the resource drivers must be enabled. +//! This is done with [`Builder::enable_io`] and [`Builder::enable_time`]. As a +//! shorthand, [`Builder::enable_all`] enables both resource drivers. +//! +//! ## Lifetime of spawned threads +//! +//! The runtime may spawn threads depending on its configuration and usage. The +//! multi-thread scheduler spawns threads to schedule tasks and for `spawn_blocking` +//! calls. +//! +//! While the `Runtime` is active, threads may shutdown after periods of being +//! idle. Once `Runtime` is dropped, all runtime threads are forcibly shutdown. +//! Any tasks that have not yet completed will be dropped. +//! +//! [tasks]: crate::task +//! [`Runtime`]: Runtime +//! [`tokio::spawn`]: crate::spawn +//! [`tokio::main`]: ../attr.main.html +//! [runtime builder]: crate::runtime::Builder +//! [`Runtime::new`]: crate::runtime::Runtime::new +//! [`Builder::basic_scheduler`]: crate::runtime::Builder::basic_scheduler +//! [`Builder::threaded_scheduler`]: crate::runtime::Builder::threaded_scheduler +//! [`Builder::enable_io`]: crate::runtime::Builder::enable_io +//! [`Builder::enable_time`]: crate::runtime::Builder::enable_time +//! [`Builder::enable_all`]: crate::runtime::Builder::enable_all + +// At the top due to macros +#[cfg(test)] +#[cfg(not(target_arch = "wasm32"))] +#[macro_use] +mod tests; + +pub(crate) mod enter; + +pub(crate) mod task; + +cfg_metrics! { + mod metrics; + pub use metrics::RuntimeMetrics; + + pub(crate) use metrics::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; +} + +cfg_not_metrics! { + pub(crate) mod metrics; + pub(crate) use metrics::{SchedulerMetrics, WorkerMetrics, MetricsBatch}; +} + +cfg_rt! { + mod basic_scheduler; + use basic_scheduler::BasicScheduler; + + mod blocking; + use blocking::BlockingPool; + pub(crate) use blocking::spawn_blocking; + + cfg_trace! { + pub(crate) use blocking::Mandatory; + } + + cfg_fs! { + pub(crate) use blocking::spawn_mandatory_blocking; + } + + mod builder; + pub use self::builder::Builder; + + pub(crate) mod context; + pub(crate) mod driver; + + use self::enter::enter; + + mod handle; + pub use handle::{EnterGuard, Handle, TryCurrentError}; + + mod spawner; + use self::spawner::Spawner; +} + +cfg_rt_multi_thread! { + mod park; + use park::Parker; +} + +cfg_rt_multi_thread! { + mod queue; + + pub(crate) mod thread_pool; + use self::thread_pool::ThreadPool; +} + +cfg_rt! { + use crate::task::JoinHandle; + + use std::future::Future; + use std::time::Duration; + + /// The Tokio runtime. + /// + /// The runtime provides an I/O driver, task scheduler, [timer], and + /// blocking pool, necessary for running asynchronous tasks. + /// + /// Instances of `Runtime` can be created using [`new`], or [`Builder`]. + /// However, most users will use the `#[tokio::main]` annotation on their + /// entry point instead. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Shutdown + /// + /// Shutting down the runtime is done by dropping the value. The current + /// thread will block until the shut down operation has completed. + /// + /// * Drain any scheduled work queues. + /// * Drop any futures that have not yet completed. + /// * Drop the reactor. + /// + /// Once the reactor has dropped, any outstanding I/O resources bound to + /// that reactor will no longer function. Calling any method on them will + /// result in an error. + /// + /// # Sharing + /// + /// The Tokio runtime implements `Sync` and `Send` to allow you to wrap it + /// in a `Arc`. Most fn take `&self` to allow you to call them concurrently + /// across multiple threads. + /// + /// Calls to `shutdown` and `shutdown_timeout` require exclusive ownership of + /// the runtime type and this can be achieved via `Arc::try_unwrap` when only + /// one strong count reference is left over. + /// + /// [timer]: crate::time + /// [mod]: index.html + /// [`new`]: method@Self::new + /// [`Builder`]: struct@Builder + #[derive(Debug)] + pub struct Runtime { + /// Task executor + kind: Kind, + + /// Handle to runtime, also contains driver handles + handle: Handle, + + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, + } + + /// The runtime executor is either a thread-pool or a current-thread executor. + #[derive(Debug)] + enum Kind { + /// Execute all tasks on the current-thread. + CurrentThread(BasicScheduler), + + /// Execute tasks across multiple threads. + #[cfg(feature = "rt-multi-thread")] + ThreadPool(ThreadPool), + } + + /// After thread starts / before thread stops + type Callback = std::sync::Arc<dyn Fn() + Send + Sync>; + + impl Runtime { + /// Creates a new runtime instance with default configuration values. + /// + /// This results in the multi threaded scheduler, I/O driver, and time driver being + /// initialized. + /// + /// Most applications will not need to call this function directly. Instead, + /// they will use the [`#[tokio::main]` attribute][main]. When a more complex + /// configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `Runtime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: index.html + /// [main]: ../attr.main.html + /// [threaded scheduler]: index.html#threaded-scheduler + /// [basic scheduler]: index.html#basic-scheduler + /// [runtime builder]: crate::runtime::Builder + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new() -> std::io::Result<Runtime> { + Builder::new_multi_thread().enable_all().build() + } + + /// Returns a handle to the runtime's spawner. + /// + /// The returned handle can be used to spawn tasks that run on this runtime, and can + /// be cloned to allow moving the `Handle` to other threads. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new() + /// .unwrap(); + /// + /// let handle = rt.handle(); + /// + /// // Use the handle... + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Spawns a future onto the Tokio runtime. + /// + /// This spawns the given future onto the runtime's executor, usually a + /// thread pool. The thread pool is then responsible for polling the future + /// until it completes. + /// + /// See [module level][mod] documentation for more details. + /// + /// [mod]: index.html + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + #[track_caller] + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.handle.spawn(future) + } + + /// Runs the provided function on an executor dedicated to blocking operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + #[track_caller] + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Runs a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// This runs the given future on the current thread, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers + /// which the future spawns internally will be executed on the runtime. + /// + /// # Multi thread scheduler + /// + /// When the multi thread scheduler is used this will allow futures + /// to run within the io driver and timer context of the overall runtime. + /// + /// # Current thread scheduler + /// + /// When the current thread scheduler is enabled `block_on` + /// can be called concurrently from multiple threads. The first call + /// will take ownership of the io and timer drivers. This means + /// other threads which do not own the drivers will hook into that one. + /// When the first `block_on` completes, other threads will be able to + /// "steal" the driver to allow continued execution of their futures. + /// + /// # Panics + /// + /// This function panics if the provided future panics, or if called within an + /// asynchronous execution context. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// [handle]: fn@Handle::block_on + #[track_caller] + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + + let _enter = self.enter(); + + match &self.kind { + Kind::CurrentThread(exec) => exec.block_on(future), + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(exec) => exec.block_on(future), + } + } + + /// Enters the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn function_that_spawns(msg: String) { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }); + /// } + /// + /// fn main() { + /// let rt = Runtime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// function_that_spawns(s); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + self.handle.enter() + } + + /// Shuts down the runtime, waiting for at most `duration` for all spawned + /// task to shutdown. + /// + /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to + /// shutdown in a timely fashion. However, dropping a `Runtime` will wait + /// indefinitely for all tasks to terminate, and there are cases where a long + /// blocking task has been spawned, which can block dropping `Runtime`. + /// + /// In this case, calling `shutdown_timeout` with an explicit wait timeout + /// can work. The `shutdown_timeout` will signal all tasks to shutdown and + /// will wait for at most `duration` for all spawned tasks to terminate. If + /// `timeout` elapses before all tasks are dropped, the function returns and + /// outstanding tasks are potentially leaked. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.clone().shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shuts down the runtime, without waiting for any spawned tasks to shutdown. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = Runtime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)) + } + } + + #[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let + impl Drop for Runtime { + fn drop(&mut self) { + match &mut self.kind { + Kind::CurrentThread(basic) => { + // This ensures that tasks spawned on the basic runtime are dropped inside the + // runtime's context. + match self::context::try_enter(self.handle.clone()) { + Some(guard) => basic.set_context_guard(guard), + None => { + // The context thread-local has already been destroyed. + // + // We don't set the guard in this case. Calls to tokio::spawn in task + // destructors would fail regardless if this happens. + }, + } + }, + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(_) => { + // The threaded scheduler drops its tasks on its worker threads, which is + // already in the runtime's context. + }, + } + } + } + + cfg_metrics! { + impl Runtime { + /// TODO + pub fn metrics(&self) -> RuntimeMetrics { + self.handle.metrics() + } + } + } +} diff --git a/third_party/rust/tokio/src/runtime/park.rs b/third_party/rust/tokio/src/runtime/park.rs new file mode 100644 index 0000000000..033b9f20be --- /dev/null +++ b/third_party/rust/tokio/src/runtime/park.rs @@ -0,0 +1,257 @@ +//! Parks the runtime. +//! +//! A combination of the various resource driver park handles. + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::loom::thread; +use crate::park::{Park, Unpark}; +use crate::runtime::driver::Driver; +use crate::util::TryLock; + +use std::sync::atomic::Ordering::SeqCst; +use std::time::Duration; + +pub(crate) struct Parker { + inner: Arc<Inner>, +} + +pub(crate) struct Unparker { + inner: Arc<Inner>, +} + +struct Inner { + /// Avoids entering the park if possible + state: AtomicUsize, + + /// Used to coordinate access to the driver / condvar + mutex: Mutex<()>, + + /// Condvar to block on if the driver is unavailable. + condvar: Condvar, + + /// Resource (I/O, time, ...) driver + shared: Arc<Shared>, +} + +const EMPTY: usize = 0; +const PARKED_CONDVAR: usize = 1; +const PARKED_DRIVER: usize = 2; +const NOTIFIED: usize = 3; + +/// Shared across multiple Parker handles +struct Shared { + /// Shared driver. Only one thread at a time can use this + driver: TryLock<Driver>, + + /// Unpark handle + handle: <Driver as Park>::Unpark, +} + +impl Parker { + pub(crate) fn new(driver: Driver) -> Parker { + let handle = driver.unpark(); + + Parker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + mutex: Mutex::new(()), + condvar: Condvar::new(), + shared: Arc::new(Shared { + driver: TryLock::new(driver), + handle, + }), + }), + } + } +} + +impl Clone for Parker { + fn clone(&self) -> Parker { + Parker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + mutex: Mutex::new(()), + condvar: Condvar::new(), + shared: self.inner.shared.clone(), + }), + } + } +} + +impl Park for Parker { + type Unpark = Unparker; + type Error = (); + + fn unpark(&self) -> Unparker { + Unparker { + inner: self.inner.clone(), + } + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + // Only parking with zero is supported... + assert_eq!(duration, Duration::from_millis(0)); + + if let Some(mut driver) = self.inner.shared.driver.try_lock() { + driver.park_timeout(duration).map_err(|_| ()) + } else { + Ok(()) + } + } + + fn shutdown(&mut self) { + self.inner.shutdown(); + } +} + +impl Unpark for Unparker { + fn unpark(&self) { + self.inner.unpark(); + } +} + +impl Inner { + /// Parks the current thread for at most `dur`. + fn park(&self) { + for _ in 0..3 { + // If we were previously notified then we consume this notification and + // return quickly. + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + + thread::yield_now(); + } + + if let Some(mut driver) = self.shared.driver.try_lock() { + self.park_driver(&mut driver); + } else { + self.park_condvar(); + } + } + + fn park_condvar(&self) { + // Otherwise we need to coordinate going to sleep + let mut m = self.mutex.lock(); + + match self + .state + .compare_exchange(EMPTY, PARKED_CONDVAR, SeqCst, SeqCst) + { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park state; actual = {}", actual), + } + + loop { + m = self.condvar.wait(m).unwrap(); + + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + // got a notification + return; + } + + // spurious wakeup, go back to sleep + } + } + + fn park_driver(&self, driver: &mut Driver) { + match self + .state + .compare_exchange(EMPTY, PARKED_DRIVER, SeqCst, SeqCst) + { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + + return; + } + Err(actual) => panic!("inconsistent park state; actual = {}", actual), + } + + // TODO: don't unwrap + driver.park().unwrap(); + + match self.state.swap(EMPTY, SeqCst) { + NOTIFIED => {} // got a notification, hurray! + PARKED_DRIVER => {} // no notification, alas + n => panic!("inconsistent park_timeout state: {}", n), + } + } + + fn unpark(&self) { + // To ensure the unparked thread will observe any writes we made before + // this call, we must perform a release operation that `park` can + // synchronize with. To do that we must write `NOTIFIED` even if `state` + // is already `NOTIFIED`. That is why this must be a swap rather than a + // compare-and-swap that returns if it reads `NOTIFIED` on failure. + match self.state.swap(NOTIFIED, SeqCst) { + EMPTY => {} // no one was waiting + NOTIFIED => {} // already unparked + PARKED_CONDVAR => self.unpark_condvar(), + PARKED_DRIVER => self.unpark_driver(), + actual => panic!("inconsistent state in unpark; actual = {}", actual), + } + } + + fn unpark_condvar(&self) { + // There is a period between when the parked thread sets `state` to + // `PARKED` (or last checked `state` in the case of a spurious wake + // up) and when it actually waits on `cvar`. If we were to notify + // during this period it would be ignored and then when the parked + // thread went to sleep it would never wake up. Fortunately, it has + // `lock` locked at this stage so we can acquire `lock` to wait until + // it is ready to receive the notification. + // + // Releasing `lock` before the call to `notify_one` means that when the + // parked thread wakes it doesn't get woken only to have to wait for us + // to release `lock`. + drop(self.mutex.lock()); + + self.condvar.notify_one() + } + + fn unpark_driver(&self) { + self.shared.handle.unpark(); + } + + fn shutdown(&self) { + if let Some(mut driver) = self.shared.driver.try_lock() { + driver.shutdown(); + } + + self.condvar.notify_all(); + } +} diff --git a/third_party/rust/tokio/src/runtime/queue.rs b/third_party/rust/tokio/src/runtime/queue.rs new file mode 100644 index 0000000000..ad9085a654 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/queue.rs @@ -0,0 +1,511 @@ +//! Run-queue structures to support a work-stealing scheduler + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicU16, AtomicU32}; +use crate::loom::sync::Arc; +use crate::runtime::task::{self, Inject}; +use crate::runtime::MetricsBatch; + +use std::mem::MaybeUninit; +use std::ptr; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; + +/// Producer handle. May only be used from a single thread. +pub(super) struct Local<T: 'static> { + inner: Arc<Inner<T>>, +} + +/// Consumer handle. May be used from many threads. +pub(super) struct Steal<T: 'static>(Arc<Inner<T>>); + +pub(super) struct Inner<T: 'static> { + /// Concurrently updated by many threads. + /// + /// Contains two `u16` values. The LSB byte is the "real" head of the queue. + /// The `u16` in the MSB is set by a stealer in process of stealing values. + /// It represents the first value being stolen in the batch. `u16` is used + /// in order to distinguish between `head == tail` and `head == tail - + /// capacity`. + /// + /// When both `u16` values are the same, there is no active stealer. + /// + /// Tracking an in-progress stealer prevents a wrapping scenario. + head: AtomicU32, + + /// Only updated by producer thread but read by many threads. + tail: AtomicU16, + + /// Elements + buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>, +} + +unsafe impl<T> Send for Inner<T> {} +unsafe impl<T> Sync for Inner<T> {} + +#[cfg(not(loom))] +const LOCAL_QUEUE_CAPACITY: usize = 256; + +// Shrink the size of the local queue when using loom. This shouldn't impact +// logic, but allows loom to test more edge cases in a reasonable a mount of +// time. +#[cfg(loom)] +const LOCAL_QUEUE_CAPACITY: usize = 4; + +const MASK: usize = LOCAL_QUEUE_CAPACITY - 1; + +// Constructing the fixed size array directly is very awkward. The only way to +// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as +// the contents are not Copy. The trick with defining a const doesn't work for +// generic types. +fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> { + assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY); + + // safety: We check that the length is correct. + unsafe { Box::from_raw(Box::into_raw(buffer).cast()) } +} + +/// Create a new local run-queue +pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { + let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY); + + for _ in 0..LOCAL_QUEUE_CAPACITY { + buffer.push(UnsafeCell::new(MaybeUninit::uninit())); + } + + let inner = Arc::new(Inner { + head: AtomicU32::new(0), + tail: AtomicU16::new(0), + buffer: make_fixed_size(buffer.into_boxed_slice()), + }); + + let local = Local { + inner: inner.clone(), + }; + + let remote = Steal(inner); + + (remote, local) +} + +impl<T> Local<T> { + /// Returns true if the queue has entries that can be stealed. + pub(super) fn is_stealable(&self) -> bool { + !self.inner.is_empty() + } + + /// Returns false if there are any entries in the queue + /// + /// Separate to is_stealable so that refactors of is_stealable to "protect" + /// some tasks from stealing won't affect this + pub(super) fn has_tasks(&self) -> bool { + !self.inner.is_empty() + } + + /// Pushes a task to the back of the local queue, skipping the LIFO slot. + pub(super) fn push_back( + &mut self, + mut task: task::Notified<T>, + inject: &Inject<T>, + metrics: &mut MetricsBatch, + ) { + let tail = loop { + let head = self.inner.head.load(Acquire); + let (steal, real) = unpack(head); + + // safety: this is the **only** thread that updates this cell. + let tail = unsafe { self.inner.tail.unsync_load() }; + + if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 { + // There is capacity for the task + break tail; + } else if steal != real { + // Concurrently stealing, this will free up capacity, so only + // push the task onto the inject queue + inject.push(task); + return; + } else { + // Push the current task and half of the queue into the + // inject queue. + match self.push_overflow(task, real, tail, inject, metrics) { + Ok(_) => return, + // Lost the race, try again + Err(v) => { + task = v; + } + } + } + }; + + // Map the position to a slot index. + let idx = tail as usize & MASK; + + self.inner.buffer[idx].with_mut(|ptr| { + // Write the task to the slot + // + // Safety: There is only one producer and the above `if` + // condition ensures we don't touch a cell if there is a + // value, thus no consumer. + unsafe { + ptr::write((*ptr).as_mut_ptr(), task); + } + }); + + // Make the task available. Synchronizes with a load in + // `steal_into2`. + self.inner.tail.store(tail.wrapping_add(1), Release); + } + + /// Moves a batch of tasks into the inject queue. + /// + /// This will temporarily make some of the tasks unavailable to stealers. + /// Once `push_overflow` is done, a notification is sent out, so if other + /// workers "missed" some of the tasks during a steal, they will get + /// another opportunity. + #[inline(never)] + fn push_overflow( + &mut self, + task: task::Notified<T>, + head: u16, + tail: u16, + inject: &Inject<T>, + metrics: &mut MetricsBatch, + ) -> Result<(), task::Notified<T>> { + /// How many elements are we taking from the local queue. + /// + /// This is one less than the number of tasks pushed to the inject + /// queue as we are also inserting the `task` argument. + const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16; + + assert_eq!( + tail.wrapping_sub(head) as usize, + LOCAL_QUEUE_CAPACITY, + "queue is not full; tail = {}; head = {}", + tail, + head + ); + + let prev = pack(head, head); + + // Claim a bunch of tasks + // + // We are claiming the tasks **before** reading them out of the buffer. + // This is safe because only the **current** thread is able to push new + // tasks. + // + // There isn't really any need for memory ordering... Relaxed would + // work. This is because all tasks are pushed into the queue from the + // current thread (or memory has been acquired if the local queue handle + // moved). + if self + .inner + .head + .compare_exchange( + prev, + pack( + head.wrapping_add(NUM_TASKS_TAKEN), + head.wrapping_add(NUM_TASKS_TAKEN), + ), + Release, + Relaxed, + ) + .is_err() + { + // We failed to claim the tasks, losing the race. Return out of + // this function and try the full `push` routine again. The queue + // may not be full anymore. + return Err(task); + } + + /// An iterator that takes elements out of the run queue. + struct BatchTaskIter<'a, T: 'static> { + buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY], + head: u32, + i: u32, + } + impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> { + type Item = task::Notified<T>; + + #[inline] + fn next(&mut self) -> Option<task::Notified<T>> { + if self.i == u32::from(NUM_TASKS_TAKEN) { + None + } else { + let i_idx = self.i.wrapping_add(self.head) as usize & MASK; + let slot = &self.buffer[i_idx]; + + // safety: Our CAS from before has assumed exclusive ownership + // of the task pointers in this range. + let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + self.i += 1; + Some(task) + } + } + } + + // safety: The CAS above ensures that no consumer will look at these + // values again, and we are the only producer. + let batch_iter = BatchTaskIter { + buffer: &*self.inner.buffer, + head: head as u32, + i: 0, + }; + inject.push_batch(batch_iter.chain(std::iter::once(task))); + + // Add 1 to factor in the task currently being scheduled. + metrics.incr_overflow_count(); + + Ok(()) + } + + /// Pops a task from the local queue. + pub(super) fn pop(&mut self) -> Option<task::Notified<T>> { + let mut head = self.inner.head.load(Acquire); + + let idx = loop { + let (steal, real) = unpack(head); + + // safety: this is the **only** thread that updates this cell. + let tail = unsafe { self.inner.tail.unsync_load() }; + + if real == tail { + // queue is empty + return None; + } + + let next_real = real.wrapping_add(1); + + // If `steal == real` there are no concurrent stealers. Both `steal` + // and `real` are updated. + let next = if steal == real { + pack(next_real, next_real) + } else { + assert_ne!(steal, next_real); + pack(steal, next_real) + }; + + // Attempt to claim a task. + let res = self + .inner + .head + .compare_exchange(head, next, AcqRel, Acquire); + + match res { + Ok(_) => break real as usize & MASK, + Err(actual) => head = actual, + } + }; + + Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() })) + } +} + +impl<T> Steal<T> { + pub(super) fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Steals half the tasks from self and place them into `dst`. + pub(super) fn steal_into( + &self, + dst: &mut Local<T>, + dst_metrics: &mut MetricsBatch, + ) -> Option<task::Notified<T>> { + // Safety: the caller is the only thread that mutates `dst.tail` and + // holds a mutable reference. + let dst_tail = unsafe { dst.inner.tail.unsync_load() }; + + // To the caller, `dst` may **look** empty but still have values + // contained in the buffer. If another thread is concurrently stealing + // from `dst` there may not be enough capacity to steal. + let (steal, _) = unpack(dst.inner.head.load(Acquire)); + + if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 { + // we *could* try to steal less here, but for simplicity, we're just + // going to abort. + return None; + } + + // Steal the tasks into `dst`'s buffer. This does not yet expose the + // tasks in `dst`. + let mut n = self.steal_into2(dst, dst_tail); + + if n == 0 { + // No tasks were stolen + return None; + } + + dst_metrics.incr_steal_count(n); + + // We are returning a task here + n -= 1; + + let ret_pos = dst_tail.wrapping_add(n); + let ret_idx = ret_pos as usize & MASK; + + // safety: the value was written as part of `steal_into2` and not + // exposed to stealers, so no other thread can access it. + let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + if n == 0 { + // The `dst` queue is empty, but a single task was stolen + return Some(ret); + } + + // Make the stolen items available to consumers + dst.inner.tail.store(dst_tail.wrapping_add(n), Release); + + Some(ret) + } + + // Steal tasks from `self`, placing them into `dst`. Returns the number of + // tasks that were stolen. + fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 { + let mut prev_packed = self.0.head.load(Acquire); + let mut next_packed; + + let n = loop { + let (src_head_steal, src_head_real) = unpack(prev_packed); + let src_tail = self.0.tail.load(Acquire); + + // If these two do not match, another thread is concurrently + // stealing from the queue. + if src_head_steal != src_head_real { + return 0; + } + + // Number of available tasks to steal + let n = src_tail.wrapping_sub(src_head_real); + let n = n - n / 2; + + if n == 0 { + // No tasks available to steal + return 0; + } + + // Update the real head index to acquire the tasks. + let steal_to = src_head_real.wrapping_add(n); + assert_ne!(src_head_steal, steal_to); + next_packed = pack(src_head_steal, steal_to); + + // Claim all those tasks. This is done by incrementing the "real" + // head but not the steal. By doing this, no other thread is able to + // steal from this queue until the current thread completes. + let res = self + .0 + .head + .compare_exchange(prev_packed, next_packed, AcqRel, Acquire); + + match res { + Ok(_) => break n, + Err(actual) => prev_packed = actual, + } + }; + + assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n); + + let (first, _) = unpack(next_packed); + + // Take all the tasks + for i in 0..n { + // Compute the positions + let src_pos = first.wrapping_add(i); + let dst_pos = dst_tail.wrapping_add(i); + + // Map to slots + let src_idx = src_pos as usize & MASK; + let dst_idx = dst_pos as usize & MASK; + + // Read the task + // + // safety: We acquired the task with the atomic exchange above. + let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + // Write the task to the new slot + // + // safety: `dst` queue is empty and we are the only producer to + // this queue. + dst.inner.buffer[dst_idx] + .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) }); + } + + let mut prev_packed = next_packed; + + // Update `src_head_steal` to match `src_head_real` signalling that the + // stealing routine is complete. + loop { + let head = unpack(prev_packed).1; + next_packed = pack(head, head); + + let res = self + .0 + .head + .compare_exchange(prev_packed, next_packed, AcqRel, Acquire); + + match res { + Ok(_) => return n, + Err(actual) => { + let (actual_steal, actual_real) = unpack(actual); + + assert_ne!(actual_steal, actual_real); + + prev_packed = actual; + } + } + } + } +} + +cfg_metrics! { + impl<T> Steal<T> { + pub(crate) fn len(&self) -> usize { + self.0.len() as _ + } + } +} + +impl<T> Clone for Steal<T> { + fn clone(&self) -> Steal<T> { + Steal(self.0.clone()) + } +} + +impl<T> Drop for Local<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +impl<T> Inner<T> { + fn len(&self) -> u16 { + let (_, head) = unpack(self.head.load(Acquire)); + let tail = self.tail.load(Acquire); + + tail.wrapping_sub(head) + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// Split the head value into the real head and the index a stealer is working +/// on. +fn unpack(n: u32) -> (u16, u16) { + let real = n & u16::MAX as u32; + let steal = n >> 16; + + (steal as u16, real as u16) +} + +/// Join the two head values +fn pack(steal: u16, real: u16) -> u32 { + (real as u32) | ((steal as u32) << 16) +} + +#[test] +fn test_local_queue_capacity() { + assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize); +} diff --git a/third_party/rust/tokio/src/runtime/spawner.rs b/third_party/rust/tokio/src/runtime/spawner.rs new file mode 100644 index 0000000000..d81a806cb5 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/spawner.rs @@ -0,0 +1,83 @@ +use crate::future::Future; +use crate::runtime::basic_scheduler; +use crate::task::JoinHandle; + +cfg_rt_multi_thread! { + use crate::runtime::thread_pool; +} + +#[derive(Debug, Clone)] +pub(crate) enum Spawner { + Basic(basic_scheduler::Spawner), + #[cfg(feature = "rt-multi-thread")] + ThreadPool(thread_pool::Spawner), +} + +impl Spawner { + pub(crate) fn shutdown(&mut self) { + #[cfg(feature = "rt-multi-thread")] + { + if let Spawner::ThreadPool(spawner) = self { + spawner.shutdown(); + } + } + } + + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match self { + Spawner::Basic(spawner) => spawner.spawn(future), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.spawn(future), + } + } +} + +cfg_metrics! { + use crate::runtime::{SchedulerMetrics, WorkerMetrics}; + + impl Spawner { + pub(crate) fn num_workers(&self) -> usize { + match self { + Spawner::Basic(_) => 1, + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.num_workers(), + } + } + + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + match self { + Spawner::Basic(spawner) => spawner.scheduler_metrics(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.scheduler_metrics(), + } + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + match self { + Spawner::Basic(spawner) => spawner.worker_metrics(worker), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.worker_metrics(worker), + } + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + match self { + Spawner::Basic(spawner) => spawner.injection_queue_depth(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.injection_queue_depth(), + } + } + + pub(crate) fn worker_local_queue_depth(&self, worker: usize) -> usize { + match self { + Spawner::Basic(spawner) => spawner.worker_metrics(worker).queue_depth(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.worker_local_queue_depth(worker), + } + } + } +} diff --git a/third_party/rust/tokio/src/runtime/task/core.rs b/third_party/rust/tokio/src/runtime/task/core.rs new file mode 100644 index 0000000000..776e8341f3 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/core.rs @@ -0,0 +1,267 @@ +//! Core task module. +//! +//! # Safety +//! +//! The functions in this module are private to the `task` module. All of them +//! should be considered `unsafe` to use, but are not marked as such since it +//! would be too noisy. +//! +//! Make sure to consult the relevant safety section of each function before +//! use. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::runtime::task::raw::{self, Vtable}; +use crate::runtime::task::state::State; +use crate::runtime::task::Schedule; +use crate::util::linked_list; + +use std::pin::Pin; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// The task cell. Contains the components of the task. +/// +/// It is critical for `Header` to be the first field as the task structure will +/// be referenced by both *mut Cell and *mut Header. +#[repr(C)] +pub(super) struct Cell<T: Future, S> { + /// Hot task state data + pub(super) header: Header, + + /// Either the future or output, depending on the execution stage. + pub(super) core: Core<T, S>, + + /// Cold data + pub(super) trailer: Trailer, +} + +pub(super) struct CoreStage<T: Future> { + stage: UnsafeCell<Stage<T>>, +} + +/// The core of the task. +/// +/// Holds the future or output, depending on the stage of execution. +pub(super) struct Core<T: Future, S> { + /// Scheduler used to drive this future. + pub(super) scheduler: S, + + /// Either the future or the output. + pub(super) stage: CoreStage<T>, +} + +/// Crate public as this is also needed by the pool. +#[repr(C)] +pub(crate) struct Header { + /// Task state. + pub(super) state: State, + + pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, + + /// Pointer to next task, used with the injection queue. + pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, + + /// Table of function pointers for executing actions on the task. + pub(super) vtable: &'static Vtable, + + /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that + /// this task is stored in. If the task is not in any list, should be the + /// id of the list that it was previously in, or zero if it has never been + /// in any list. + /// + /// Once a task has been bound to a list, it can never be bound to another + /// list, even if removed from the first list. + /// + /// The id is not unset when removed from a list because we want to be able + /// to read the id without synchronization, even if it is concurrently being + /// removed from the list. + pub(super) owner_id: UnsafeCell<u64>, + + /// The tracing ID for this instrumented task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) id: Option<tracing::Id>, +} + +unsafe impl Send for Header {} +unsafe impl Sync for Header {} + +/// Cold data is stored after the future. +pub(super) struct Trailer { + /// Consumer task waiting on completion of this task. + pub(super) waker: UnsafeCell<Option<Waker>>, +} + +/// Either the future or the output. +pub(super) enum Stage<T: Future> { + Running(T), + Finished(super::Result<T::Output>), + Consumed, +} + +impl<T: Future, S: Schedule> Cell<T, S> { + /// Allocates a new task cell, containing the header, trailer, and core + /// structures. + pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let id = future.id(); + Box::new(Cell { + header: Header { + state, + owned: UnsafeCell::new(linked_list::Pointers::new()), + queue_next: UnsafeCell::new(None), + vtable: raw::vtable::<T, S>(), + owner_id: UnsafeCell::new(0), + #[cfg(all(tokio_unstable, feature = "tracing"))] + id, + }, + core: Core { + scheduler, + stage: CoreStage { + stage: UnsafeCell::new(Stage::Running(future)), + }, + }, + trailer: Trailer { + waker: UnsafeCell::new(None), + }, + }) + } +} + +impl<T: Future> CoreStage<T> { + pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { + self.stage.with_mut(f) + } + + /// Polls the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `state` field. This + /// requires ensuring mutual exclusion between any concurrent thread that + /// might modify the future or output field. + /// + /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` + /// component of the task state. + /// + /// `self` must also be pinned. This is handled by storing the task on the + /// heap. + pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> { + let res = { + self.stage.with_mut(|ptr| { + // Safety: The caller ensures mutual exclusion to the field. + let future = match unsafe { &mut *ptr } { + Stage::Running(future) => future, + _ => unreachable!("unexpected stage"), + }; + + // Safety: The caller ensures the future is pinned. + let future = unsafe { Pin::new_unchecked(future) }; + + future.poll(&mut cx) + }) + }; + + if res.is_ready() { + self.drop_future_or_output(); + } + + res + } + + /// Drops the future. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn drop_future_or_output(&self) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Consumed); + } + } + + /// Stores the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn store_output(&self, output: super::Result<T::Output>) { + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Finished(output)); + } + } + + /// Takes the task output. + /// + /// # Safety + /// + /// The caller must ensure it is safe to mutate the `stage` field. + pub(super) fn take_output(&self) -> super::Result<T::Output> { + use std::mem; + + self.stage.with_mut(|ptr| { + // Safety:: the caller ensures mutual exclusion to the field. + match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { + Stage::Finished(output) => output, + _ => panic!("JoinHandle polled after completion"), + } + }) + } + + unsafe fn set_stage(&self, stage: Stage<T>) { + self.stage.with_mut(|ptr| *ptr = stage) + } +} + +cfg_rt_multi_thread! { + impl Header { + pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { + self.queue_next.with_mut(|ptr| *ptr = next); + } + } +} + +impl Header { + // safety: The caller must guarantee exclusive access to this field, and + // must ensure that the id is either 0 or the id of the OwnedTasks + // containing this task. + pub(super) unsafe fn set_owner_id(&self, owner: u64) { + self.owner_id.with_mut(|ptr| *ptr = owner); + } + + pub(super) fn get_owner_id(&self) -> u64 { + // safety: If there are concurrent writes, then that write has violated + // the safety requirements on `set_owner_id`. + unsafe { self.owner_id.with(|ptr| *ptr) } + } +} + +impl Trailer { + pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { + self.waker.with_mut(|ptr| { + *ptr = waker; + }); + } + + pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { + self.waker + .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) + } + + pub(super) fn wake_join(&self) { + self.waker.with(|ptr| match unsafe { &*ptr } { + Some(waker) => waker.wake_by_ref(), + None => panic!("waker missing"), + }); + } +} + +#[test] +#[cfg(not(loom))] +fn header_lte_cache_line() { + use std::mem::size_of; + + assert!(size_of::<Header>() <= 8 * size_of::<*const ()>()); +} diff --git a/third_party/rust/tokio/src/runtime/task/error.rs b/third_party/rust/tokio/src/runtime/task/error.rs new file mode 100644 index 0000000000..1a8129b2b6 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/error.rs @@ -0,0 +1,146 @@ +use std::any::Any; +use std::fmt; +use std::io; + +use crate::util::SyncWrapper; + +cfg_rt! { + /// Task failed to execute to completion. + pub struct JoinError { + repr: Repr, + } +} + +enum Repr { + Cancelled, + Panic(SyncWrapper<Box<dyn Any + Send + 'static>>), +} + +impl JoinError { + pub(crate) fn cancelled() -> JoinError { + JoinError { + repr: Repr::Cancelled, + } + } + + pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { + JoinError { + repr: Repr::Panic(SyncWrapper::new(err)), + } + } + + /// Returns true if the error was caused by the task being cancelled. + pub fn is_cancelled(&self) -> bool { + matches!(&self.repr, Repr::Cancelled) + } + + /// Returns true if the error was caused by the task panicking. + /// + /// # Examples + /// + /// ``` + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// assert!(err.is_panic()); + /// } + /// ``` + pub fn is_panic(&self) -> bool { + matches!(&self.repr, Repr::Panic(_)) + } + + /// Consumes the join error, returning the object with which the task panicked. + /// + /// # Panics + /// + /// `into_panic()` panics if the `Error` does not represent the underlying + /// task terminating with a panic. Use `is_panic` to check the error reason + /// or `try_into_panic` for a variant that does not panic. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if err.is_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(err.into_panic()); + /// } + /// } + /// ``` + pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { + self.try_into_panic() + .expect("`JoinError` reason is not a panic.") + } + + /// Consumes the join error, returning the object with which the task + /// panicked if the task terminated due to a panic. Otherwise, `self` is + /// returned. + /// + /// # Examples + /// + /// ```should_panic + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let err = tokio::spawn(async { + /// panic!("boom"); + /// }).await.unwrap_err(); + /// + /// if let Ok(reason) = err.try_into_panic() { + /// // Resume the panic on the main task + /// panic::resume_unwind(reason); + /// } + /// } + /// ``` + pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> { + match self.repr { + Repr::Panic(p) => Ok(p.into_inner()), + _ => Err(self), + } + } +} + +impl fmt::Display for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "cancelled"), + Repr::Panic(_) => write!(fmt, "panic"), + } + } +} + +impl fmt::Debug for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), + Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + } + } +} + +impl std::error::Error for JoinError {} + +impl From<JoinError> for io::Error { + fn from(src: JoinError) -> io::Error { + io::Error::new( + io::ErrorKind::Other, + match src.repr { + Repr::Cancelled => "task was cancelled", + Repr::Panic(_) => "task panicked", + }, + ) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/harness.rs b/third_party/rust/tokio/src/runtime/task/harness.rs new file mode 100644 index 0000000000..261dccea41 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/harness.rs @@ -0,0 +1,485 @@ +use crate::future::Future; +use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; +use crate::runtime::task::state::Snapshot; +use crate::runtime::task::waker::waker_ref; +use crate::runtime::task::{JoinError, Notified, Schedule, Task}; + +use std::mem; +use std::mem::ManuallyDrop; +use std::panic; +use std::ptr::NonNull; +use std::task::{Context, Poll, Waker}; + +/// Typed raw task handle. +pub(super) struct Harness<T: Future, S: 'static> { + cell: NonNull<Cell<T, S>>, +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: 'static, +{ + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> { + Harness { + cell: ptr.cast::<Cell<T, S>>(), + } + } + + fn header_ptr(&self) -> NonNull<Header> { + self.cell.cast() + } + + fn header(&self) -> &Header { + unsafe { &self.cell.as_ref().header } + } + + fn trailer(&self) -> &Trailer { + unsafe { &self.cell.as_ref().trailer } + } + + fn core(&self) -> &Core<T, S> { + unsafe { &self.cell.as_ref().core } + } +} + +impl<T, S> Harness<T, S> +where + T: Future, + S: Schedule, +{ + /// Polls the inner future. A ref-count is consumed. + /// + /// All necessary state checks and transitions are performed. + /// Panics raised while polling the future are handled. + pub(super) fn poll(self) { + // We pass our ref-count to `poll_inner`. + match self.poll_inner() { + PollFuture::Notified => { + // The `poll_inner` call has given us two ref-counts back. + // We give one of them to a new task and call `yield_now`. + self.core() + .scheduler + .yield_now(Notified(self.get_new_task())); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. + self.drop_reference(); + } + PollFuture::Complete => { + self.complete(); + } + PollFuture::Dealloc => { + self.dealloc(); + } + PollFuture::Done => (), + } + } + + /// Polls the task and cancel it if necessary. This takes ownership of a + /// ref-count. + /// + /// If the return value is Notified, the caller is given ownership of two + /// ref-counts. + /// + /// If the return value is Complete, the caller is given ownership of a + /// single ref-count, which should be passed on to `complete`. + /// + /// If the return value is Dealloc, then this call consumed the last + /// ref-count and the caller should call `dealloc`. + /// + /// Otherwise the ref-count is consumed and the caller should not access + /// `self` again. + fn poll_inner(&self) -> PollFuture { + use super::state::{TransitionToIdle, TransitionToRunning}; + + match self.header().state.transition_to_running() { + TransitionToRunning::Success => { + let header_ptr = self.header_ptr(); + let waker_ref = waker_ref::<T, S>(&header_ptr); + let cx = Context::from_waker(&*waker_ref); + let res = poll_future(&self.core().stage, cx); + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. + return PollFuture::Complete; + } + + match self.header().state.transition_to_idle() { + TransitionToIdle::Ok => PollFuture::Done, + TransitionToIdle::OkNotified => PollFuture::Notified, + TransitionToIdle::OkDealloc => PollFuture::Dealloc, + TransitionToIdle::Cancelled => { + // The transition to idle failed because the task was + // cancelled during the poll. + + cancel_task(&self.core().stage); + PollFuture::Complete + } + } + } + TransitionToRunning::Cancelled => { + cancel_task(&self.core().stage); + PollFuture::Complete + } + TransitionToRunning::Failed => PollFuture::Done, + TransitionToRunning::Dealloc => PollFuture::Dealloc, + } + } + + /// Forcibly shuts down the task. + /// + /// Attempt to transition to `Running` in order to forcibly shutdown the + /// task. If the task is currently running or in a state of completion, then + /// there is nothing further to do. When the task completes running, it will + /// notice the `CANCELLED` bit and finalize the task. + pub(super) fn shutdown(self) { + if !self.header().state.transition_to_shutdown() { + // The task is concurrently running. No further work needed. + self.drop_reference(); + return; + } + + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. + cancel_task(&self.core().stage); + self.complete(); + } + + pub(super) fn dealloc(self) { + // Release the join waker, if there is one. + self.trailer().waker.with_mut(drop); + + // Check causality + self.core().stage.with_mut(drop); + + unsafe { + drop(Box::from_raw(self.cell.as_ptr())); + } + } + + // ===== join handle ===== + + /// Read the task output into `dst`. + pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) { + if can_read_output(self.header(), self.trailer(), waker) { + *dst = Poll::Ready(self.core().stage.take_output()); + } + } + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + can_read_output(self.header(), self.trailer(), waker) + } + + pub(super) fn drop_join_handle_slow(self) { + // Try to unset `JOIN_INTEREST`. This must be done as a first step in + // case the task concurrently completed. + if self.header().state.unset_join_interested().is_err() { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + } + + // Drop the `JoinHandle` reference, possibly deallocating the task + self.drop_reference(); + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(self) { + if self.header().state.transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + } + + // ===== waker behavior ===== + + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. + pub(super) fn wake_by_val(self) { + use super::state::TransitionToNotifiedByVal; + + match self.header().state.transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } + } + + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. + pub(super) fn wake_by_ref(&self) { + use super::state::TransitionToNotifiedByRef; + + match self.header().state.transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + TransitionToNotifiedByRef::DoNothing => {} + } + } + + pub(super) fn drop_reference(self) { + if self.header().state.ref_dec() { + self.dealloc(); + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) fn id(&self) -> Option<&tracing::Id> { + self.header().id.as_ref() + } + + // ====== internal ====== + + /// Completes the task. This method assumes that the state is RUNNING. + fn complete(self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + + let snapshot = self.header().state.transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. + self.core().stage.drop_future_or_output(); + } else if snapshot.has_join_waker() { + // Notify the join handle. The previous transition obtains the + // lock on the waker cell. + self.trailer().wake_join(); + } + })); + + // The task has completed execution and will no longer be scheduled. + let num_release = self.release(); + + if self.header().state.transition_to_terminal(num_release) { + self.dealloc(); + } + } + + /// Releases the task from the scheduler. Returns the number of ref-counts + /// that should be decremented. + fn release(&self) -> usize { + // We don't actually increment the ref-count here, but the new task is + // never destroyed, so that's ok. + let me = ManuallyDrop::new(self.get_new_task()); + + if let Some(task) = self.core().scheduler.release(&me) { + mem::forget(task); + 2 + } else { + 1 + } + } + + /// Creates a new task that holds its own ref-count. + /// + /// # Safety + /// + /// Any use of `self` after this call must ensure that a ref-count to the + /// task holds the task alive until after the use of `self`. Passing the + /// returned Task to any method on `self` is unsound if dropping the Task + /// could drop `self` before the call on `self` returned. + fn get_new_task(&self) -> Task<S> { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { Task::from_raw(self.cell.cast()) } + } +} + +fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { + // Load a snapshot of the current task state + let snapshot = header.state.load(); + + debug_assert!(snapshot.is_join_interested()); + + if !snapshot.is_complete() { + // The waker must be stored in the task struct. + let res = if snapshot.has_join_waker() { + // There already is a waker stored in the struct. If it matches + // the provided waker, then there is no further work to do. + // Otherwise, the waker must be swapped. + let will_wake = unsafe { + // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE` + // may mutate the `waker` field. + trailer.will_wake(waker) + }; + + if will_wake { + // The task is not complete **and** the waker is up to date, + // there is nothing further that needs to be done. + return false; + } + + // Unset the `JOIN_WAKER` to gain mutable access to the `waker` + // field then update the field with the new join worker. + // + // This requires two atomic operations, unsetting the bit and + // then resetting it. If the task transitions to complete + // concurrently to either one of those operations, then setting + // the join waker fails and we proceed to reading the task + // output. + header + .state + .unset_waker() + .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot)) + } else { + set_join_waker(header, trailer, waker.clone(), snapshot) + }; + + match res { + Ok(_) => return false, + Err(snapshot) => { + assert!(snapshot.is_complete()); + } + } + } + true +} + +fn set_join_waker( + header: &Header, + trailer: &Trailer, + waker: Waker, + snapshot: Snapshot, +) -> Result<Snapshot, Snapshot> { + assert!(snapshot.is_join_interested()); + assert!(!snapshot.has_join_waker()); + + // Safety: Only the `JoinHandle` may set the `waker` field. When + // `JOIN_INTEREST` is **not** set, nothing else will touch the field. + unsafe { + trailer.set_waker(Some(waker)); + } + + // Update the `JoinWaker` state accordingly + let res = header.state.set_join_waker(); + + // If the state could not be updated, then clear the join waker + if res.is_err() { + unsafe { + trailer.set_waker(None); + } + } + + res +} + +enum PollFuture { + Complete, + Notified, + Done, + Dealloc, +} + +/// Cancels the task and store the appropriate error in the stage field. +fn cancel_task<T: Future>(stage: &CoreStage<T>) { + // Drop the future from a panic guard. + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + stage.drop_future_or_output(); + })); + + match res { + Ok(()) => { + stage.store_output(Err(JoinError::cancelled())); + } + Err(panic) => { + stage.store_output(Err(JoinError::panic(panic))); + } + } +} + +/// Polls the future. If the future completes, the output is written to the +/// stage field. +fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { + // Poll the future. + let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage<T>, + } + impl<'a, T: Future> Drop for Guard<'a, T> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + self.core.drop_future_or_output(); + } + } + let guard = Guard { core }; + let res = guard.core.poll(cx); + mem::forget(guard); + res + })); + + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(JoinError::panic(panic)), + }; + + // Catch and ignore panics if the future panics on drop. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + core.store_output(output); + })); + + Poll::Ready(()) +} diff --git a/third_party/rust/tokio/src/runtime/task/inject.rs b/third_party/rust/tokio/src/runtime/task/inject.rs new file mode 100644 index 0000000000..1585e13a01 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/inject.rs @@ -0,0 +1,220 @@ +//! Inject queue used to send wakeups to a work-stealing scheduler + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::runtime::task; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{Acquire, Release}; + +/// Growable, MPMC queue used to inject new tasks into the scheduler and as an +/// overflow queue when the local, fixed-size, array queue overflows. +pub(crate) struct Inject<T: 'static> { + /// Pointers to the head and tail of the queue. + pointers: Mutex<Pointers>, + + /// Number of pending tasks in the queue. This helps prevent unnecessary + /// locking in the hot path. + len: AtomicUsize, + + _p: PhantomData<T>, +} + +struct Pointers { + /// True if the queue is closed. + is_closed: bool, + + /// Linked-list head. + head: Option<NonNull<task::Header>>, + + /// Linked-list tail. + tail: Option<NonNull<task::Header>>, +} + +unsafe impl<T> Send for Inject<T> {} +unsafe impl<T> Sync for Inject<T> {} + +impl<T: 'static> Inject<T> { + pub(crate) fn new() -> Inject<T> { + Inject { + pointers: Mutex::new(Pointers { + is_closed: false, + head: None, + tail: None, + }), + len: AtomicUsize::new(0), + _p: PhantomData, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Closes the injection queue, returns `true` if the queue is open when the + /// transition is made. + pub(crate) fn close(&self) -> bool { + let mut p = self.pointers.lock(); + + if p.is_closed { + return false; + } + + p.is_closed = true; + true + } + + pub(crate) fn is_closed(&self) -> bool { + self.pointers.lock().is_closed + } + + pub(crate) fn len(&self) -> usize { + self.len.load(Acquire) + } + + /// Pushes a value into the queue. + /// + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified<T>) { + // Acquire queue lock + let mut p = self.pointers.lock(); + + if p.is_closed { + return; + } + + // safety: only mutated with the lock held + let len = unsafe { self.len.unsync_load() }; + let task = task.into_raw(); + + // The next pointer should already be null + debug_assert!(get_next(task).is_none()); + + if let Some(tail) = p.tail { + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(tail, Some(task)); + } else { + p.head = Some(task); + } + + p.tail = Some(task); + + self.len.store(len + 1, Release); + } + + /// Pushes several values into the queue. + #[inline] + pub(crate) fn push_batch<I>(&self, mut iter: I) + where + I: Iterator<Item = task::Notified<T>>, + { + let first = match iter.next() { + Some(first) => first.into_raw(), + None => return, + }; + + // Link up all the tasks. + let mut prev = first; + let mut counter = 1; + + // We are going to be called with an `std::iter::Chain`, and that + // iterator overrides `for_each` to something that is easier for the + // compiler to optimize than a loop. + iter.for_each(|next| { + let next = next.into_raw(); + + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(prev, Some(next)); + prev = next; + counter += 1; + }); + + // Now that the tasks are linked together, insert them into the + // linked list. + self.push_batch_inner(first, prev, counter); + } + + /// Inserts several tasks that have been linked together into the queue. + /// + /// The provided head and tail may be be the same task. In this case, a + /// single task is inserted. + #[inline] + fn push_batch_inner( + &self, + batch_head: NonNull<task::Header>, + batch_tail: NonNull<task::Header>, + num: usize, + ) { + debug_assert!(get_next(batch_tail).is_none()); + + let mut p = self.pointers.lock(); + + if let Some(tail) = p.tail { + set_next(tail, Some(batch_head)); + } else { + p.head = Some(batch_head); + } + + p.tail = Some(batch_tail); + + // Increment the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + let len = unsafe { self.len.unsync_load() }; + + self.len.store(len + num, Release); + } + + pub(crate) fn pop(&self) -> Option<task::Notified<T>> { + // Fast path, if len == 0, then there are no values + if self.is_empty() { + return None; + } + + let mut p = self.pointers.lock(); + + // It is possible to hit null here if another thread popped the last + // task between us checking `len` and acquiring the lock. + let task = p.head?; + + p.head = get_next(task); + + if p.head.is_none() { + p.tail = None; + } + + set_next(task, None); + + // Decrement the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + self.len + .store(unsafe { self.len.unsync_load() } - 1, Release); + + // safety: a `Notified` is pushed into the queue and now it is popped! + Some(unsafe { task::Notified::from_raw(task) }) + } +} + +impl<T: 'static> Drop for Inject<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { + unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } +} + +fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { + unsafe { + header.as_ref().set_next(val); + } +} diff --git a/third_party/rust/tokio/src/runtime/task/join.rs b/third_party/rust/tokio/src/runtime/task/join.rs new file mode 100644 index 0000000000..8beed2eaac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/join.rs @@ -0,0 +1,275 @@ +use crate::runtime::task::RawTask; + +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; + +cfg_rt! { + /// An owned permission to join on a task (await its termination). + /// + /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for + /// a task rather than a thread. + /// + /// A `JoinHandle` *detaches* the associated task when it is dropped, which + /// means that there is no longer any handle to the task, and no way to `join` + /// on it. + /// + /// This `struct` is created by the [`task::spawn`] and [`task::spawn_blocking`] + /// functions. + /// + /// # Examples + /// + /// Creation from [`task::spawn`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn(async { + /// // some work here + /// }); + /// # } + /// ``` + /// + /// Creation from [`task::spawn_blocking`]: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<_> = task::spawn_blocking(|| { + /// // some blocking work here + /// }); + /// # } + /// ``` + /// + /// The generic parameter `T` in `JoinHandle<T>` is the return type of the spawned task. + /// If the return value is an i32, the join handle has type `JoinHandle<i32>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<i32> = task::spawn(async { + /// 5 + 3 + /// }); + /// # } + /// + /// ``` + /// + /// If the task does not have a return value, the join handle has type `JoinHandle<()>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<()> = task::spawn(async { + /// println!("I return nothing."); + /// }); + /// # } + /// ``` + /// + /// Note that `handle.await` doesn't give you the return type directly. It is wrapped in a + /// `Result` because panics in the spawned task are caught by Tokio. The `?` operator has + /// to be double chained to extract the returned value: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// Ok(5 + 3) + /// }); + /// + /// let result = join_handle.await??; + /// assert_eq!(result, 8); + /// Ok(()) + /// } + /// ``` + /// + /// If the task panics, the error is a [`JoinError`] that contains the panic: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// panic!("boom"); + /// }); + /// + /// let err = join_handle.await.unwrap_err(); + /// assert!(err.is_panic()); + /// Ok(()) + /// } + /// + /// ``` + /// Child being detached and outliving its parent: + /// + /// ```no_run + /// use tokio::task; + /// use tokio::time; + /// use std::time::Duration; + /// + /// # #[tokio::main] async fn main() { + /// let original_task = task::spawn(async { + /// let _detached_task = task::spawn(async { + /// // Here we sleep to make sure that the first task returns before. + /// time::sleep(Duration::from_millis(10)).await; + /// // This will be called, even though the JoinHandle is dropped. + /// println!("♫ Still alive ♫"); + /// }); + /// }); + /// + /// original_task.await.expect("The task being joined has panicked"); + /// println!("Original task is joined."); + /// + /// // We make sure that the new task has time to run, before the main + /// // task returns. + /// + /// time::sleep(Duration::from_millis(1000)).await; + /// # } + /// ``` + /// + /// [`task::spawn`]: crate::task::spawn() + /// [`task::spawn_blocking`]: crate::task::spawn_blocking + /// [`std::thread::JoinHandle`]: std::thread::JoinHandle + /// [`JoinError`]: crate::task::JoinError + pub struct JoinHandle<T> { + raw: Option<RawTask>, + _p: PhantomData<T>, + } +} + +unsafe impl<T: Send> Send for JoinHandle<T> {} +unsafe impl<T: Send> Sync for JoinHandle<T> {} + +impl<T> UnwindSafe for JoinHandle<T> {} +impl<T> RefUnwindSafe for JoinHandle<T> {} + +impl<T> JoinHandle<T> { + pub(super) fn new(raw: RawTask) -> JoinHandle<T> { + JoinHandle { + raw: Some(raw), + _p: PhantomData, + } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// true + /// })); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// false + /// })); + /// + /// for handle in &handles { + /// handle.abort(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + /// [cancelled]: method@super::error::JoinError::is_cancelled + pub fn abort(&self) { + if let Some(raw) = self.raw { + raw.remote_abort(); + } + } + + /// Set the waker that is notified when the task completes. + pub(crate) fn set_join_waker(&mut self, waker: &Waker) { + if let Some(raw) = self.raw { + if raw.try_set_join_waker(waker) { + // In this case the task has already completed. We wake the waker immediately. + waker.wake_by_ref(); + } + } + } +} + +impl<T> Unpin for JoinHandle<T> {} + +impl<T> Future for JoinHandle<T> { + type Output = super::Result<T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut ret = Poll::Pending; + + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + // Raw should always be set. If it is not, this is due to polling after + // completion + let raw = self + .raw + .as_ref() + .expect("polling after `JoinHandle` already completed"); + + // Try to read the task output. If the task is not yet complete, the + // waker is stored and is notified once the task does complete. + // + // The function must go via the vtable, which requires erasing generic + // types. To do this, the function "return" is placed on the stack + // **before** calling the function and is passed into the function using + // `*mut ()`. + // + // Safety: + // + // The type of `T` must match the task's output type. + unsafe { + raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); + } + + if ret.is_ready() { + coop.made_progress(); + } + + ret + } +} + +impl<T> Drop for JoinHandle<T> { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + if raw.header().state.drop_join_handle_fast().is_ok() { + return; + } + + raw.drop_join_handle_slow(); + } + } +} + +impl<T> fmt::Debug for JoinHandle<T> +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("JoinHandle").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/list.rs b/third_party/rust/tokio/src/runtime/task/list.rs new file mode 100644 index 0000000000..7758f8db7a --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/list.rs @@ -0,0 +1,297 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::Mutex; +use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; +use crate::util::linked_list::{Link, LinkedList}; + +use std::marker::PhantomData; + +// The id from the module below is used to verify whether a given task is stored +// in this OwnedTasks, or some other task. The counter starts at one so we can +// use zero for tasks not owned by any list. +// +// The safety checks in this file can technically be violated if the counter is +// overflown, but the checks are not supposed to ever fail unless there is a +// bug in Tokio, so we accept that certain bugs would not be caught if the two +// mixed up runtimes happen to have the same id. + +cfg_has_atomic_u64! { + use std::sync::atomic::{AtomicU64, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return id; + } + } + } +} + +cfg_not_has_atomic_u64! { + use std::sync::atomic::{AtomicU32, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return u64::from(id); + } + } + } +} + +pub(crate) struct OwnedTasks<S: 'static> { + inner: Mutex<OwnedTasksInner<S>>, + id: u64, +} +pub(crate) struct LocalOwnedTasks<S: 'static> { + inner: UnsafeCell<OwnedTasksInner<S>>, + id: u64, + _not_send_or_sync: PhantomData<*const ()>, +} +struct OwnedTasksInner<S: 'static> { + list: LinkedList<Task<S>, <Task<S> as Link>::Target>, + closed: bool, +} + +impl<S: 'static> OwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + } + } + + /// Binds the provided task to this OwnedTasks instance. This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(notified); + task.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } + } + + /// Asserts that the given task is owned by this OwnedTasks and convert it to + /// a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: All tasks bound to this OwnedTasks are Send, so it is safe + // to poll it on this thread no matter what thread we are on. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + // The first iteration of the loop was unrolled so it can set the + // closed bool. + let first_task = { + let mut lock = self.inner.lock(); + lock.closed = true; + lock.list.pop_back() + }; + match first_task { + Some(task) => task.shutdown(), + None => return, + } + + loop { + let task = match self.inner.lock().list.pop_back() { + Some(task) => task, + None => return, + }; + + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + // safety: We just checked that the provided task is not in some other + // linked list. + unsafe { self.inner.lock().list.remove(task.header().into()) } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } +} + +impl<S: 'static> LocalOwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + _not_send_or_sync: PhantomData, + } + } + + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + if self.is_closed() { + drop(notified); + task.shutdown(); + (join, None) + } else { + self.with_inner(|inner| { + inner.list.push_front(task); + }); + (join, Some(notified)) + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + self.with_inner(|inner| inner.closed = true); + + while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + self.with_inner(|inner| + // safety: We just checked that the provided task is not in some + // other linked list. + unsafe { inner.list.remove(task.header().into()) }) + } + + /// Asserts that the given task is owned by this LocalOwnedTasks and convert + /// it to a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: The task was bound to this LocalOwnedTasks, and the + // LocalOwnedTasks is not Send or Sync, so we are on the right thread + // for polling this task. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + #[inline] + fn with_inner<F, T>(&self, f: F) -> T + where + F: FnOnce(&mut OwnedTasksInner<S>) -> T, + { + // safety: This type is not Sync, so concurrent calls of this method + // can't happen. Furthermore, all uses of this method in this file make + // sure that they don't call `with_inner` recursively. + self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) + } + + pub(crate) fn is_closed(&self) -> bool { + self.with_inner(|inner| inner.closed) + } + + pub(crate) fn is_empty(&self) -> bool { + self.with_inner(|inner| inner.list.is_empty()) + } +} + +#[cfg(all(test))] +mod tests { + use super::*; + + // This test may run in parallel with other tests, so we only test that ids + // come in increasing order. + #[test] + fn test_id_not_broken() { + let mut last_id = get_next_id(); + assert_ne!(last_id, 0); + + for _ in 0..1000 { + let next_id = get_next_id(); + assert_ne!(next_id, 0); + assert!(last_id < next_id); + last_id = next_id; + } + } +} diff --git a/third_party/rust/tokio/src/runtime/task/mod.rs b/third_party/rust/tokio/src/runtime/task/mod.rs new file mode 100644 index 0000000000..2a492dc985 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/mod.rs @@ -0,0 +1,445 @@ +//! The task module. +//! +//! The task module contains the code that manages spawned tasks and provides a +//! safe API for the rest of the runtime to use. Each task in a runtime is +//! stored in an OwnedTasks or LocalOwnedTasks object. +//! +//! # Task reference types +//! +//! A task is usually referenced by multiple handles, and there are several +//! types of handles. +//! +//! * OwnedTask - tasks stored in an OwnedTasks or LocalOwnedTasks are of this +//! reference type. +//! +//! * JoinHandle - each task has a JoinHandle that allows access to the output +//! of the task. +//! +//! * Waker - every waker for a task has this reference type. There can be any +//! number of waker references. +//! +//! * Notified - tracks whether the task is notified. +//! +//! * Unowned - this task reference type is used for tasks not stored in any +//! runtime. Mainly used for blocking tasks, but also in tests. +//! +//! The task uses a reference count to keep track of how many active references +//! exist. The Unowned reference type takes up two ref-counts. All other +//! reference types take up a single ref-count. +//! +//! Besides the waker type, each task has at most one of each reference type. +//! +//! # State +//! +//! The task stores its state in an atomic usize with various bitfields for the +//! necessary information. The state has the following bitfields: +//! +//! * RUNNING - Tracks whether the task is currently being polled or cancelled. +//! This bit functions as a lock around the task. +//! +//! * COMPLETE - Is one once the future has fully completed and has been +//! dropped. Never unset once set. Never set together with RUNNING. +//! +//! * NOTIFIED - Tracks whether a Notified object currently exists. +//! +//! * CANCELLED - Is set to one for tasks that should be cancelled as soon as +//! possible. May take any value for completed tasks. +//! +//! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. +//! +//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! +//! The rest of the bits are used for the ref-count. +//! +//! # Fields in the task +//! +//! The task has various fields. This section describes how and when it is safe +//! to access a field. +//! +//! * The state field is accessed with atomic instructions. +//! +//! * The OwnedTask reference has exclusive access to the `owned` field. +//! +//! * The Notified reference has exclusive access to the `queue_next` field. +//! +//! * The `owner_id` field can be set as part of construction of the task, but +//! is otherwise immutable and anyone can access the field immutably without +//! synchronization. +//! +//! * If COMPLETE is one, then the JoinHandle has exclusive access to the +//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as +//! a lock for the stage field, and it can be accessed only by the thread +//! that set RUNNING to one. +//! +//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the +//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the +//! thread that set COMPLETE to one has exclusive access to the join handle +//! waker. +//! +//! All other fields are immutable and can be accessed immutably without +//! synchronization by anyone. +//! +//! # Safety +//! +//! This section goes through various situations and explains why the API is +//! safe in that situation. +//! +//! ## Polling or dropping the future +//! +//! Any mutable access to the future happens after obtaining a lock by modifying +//! the RUNNING field, so exclusive access is ensured. +//! +//! When the task completes, exclusive access to the output is transferred to +//! the JoinHandle. If the JoinHandle is already dropped when the transition to +//! complete happens, the thread performing that transition retains exclusive +//! access to the output and should immediately drop it. +//! +//! ## Non-Send futures +//! +//! If a future is not Send, then it is bound to a LocalOwnedTasks. The future +//! will only ever be polled or dropped given a LocalNotified or inside a call +//! to LocalOwnedTasks::shutdown_all. In either case, it is guaranteed that the +//! future is on the right thread. +//! +//! If the task is never removed from the LocalOwnedTasks, then it is leaked, so +//! there is no risk that the task is dropped on some other thread when the last +//! ref-count drops. +//! +//! ## Non-Send output +//! +//! When a task completes, the output is placed in the stage of the task. Then, +//! a transition that sets COMPLETE to true is performed, and the value of +//! JOIN_INTEREST when this transition happens is read. +//! +//! If JOIN_INTEREST is zero when the transition to COMPLETE happens, then the +//! output is immediately dropped. +//! +//! If JOIN_INTEREST is one when the transition to COMPLETE happens, then the +//! JoinHandle is responsible for cleaning up the output. If the output is not +//! Send, then this happens: +//! +//! 1. The output is created on the thread that the future was polled on. Since +//! only non-Send futures can have non-Send output, the future was polled on +//! the thread that the future was spawned from. +//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! JoinHandle is also on the thread that the future was spawned from. +//! 3. Thus, the JoinHandle will not move the output across threads when it +//! takes or drops the output. +//! +//! ## Recursive poll/shutdown +//! +//! Calling poll from inside a shutdown call or vice-versa is not prevented by +//! the API exposed by the task module, so this has to be safe. In either case, +//! the lock in the RUNNING bitfield makes the inner call return immediately. If +//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the +//! poll call will notice it when the poll finishes, and the task is cancelled +//! at that point. + +// Some task infrastructure is here to support `JoinSet`, which is currently +// unstable. This should be removed once `JoinSet` is stabilized. +#![cfg_attr(not(tokio_unstable), allow(dead_code))] + +mod core; +use self::core::Cell; +use self::core::Header; + +mod error; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::error::JoinError; + +mod harness; +use self::harness::Harness; + +cfg_rt_multi_thread! { + mod inject; + pub(super) use self::inject::Inject; +} + +mod join; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::join::JoinHandle; + +mod list; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; + +mod raw; +use self::raw::RawTask; + +mod state; +use self::state::State; + +mod waker; + +use crate::future::Future; +use crate::util::linked_list; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::{fmt, mem}; + +/// An owned handle to the task, tracked by ref count. +#[repr(transparent)] +pub(crate) struct Task<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +unsafe impl<S> Send for Task<S> {} +unsafe impl<S> Sync for Task<S> {} + +/// A task was notified. +#[repr(transparent)] +pub(crate) struct Notified<S: 'static>(Task<S>); + +// safety: This type cannot be used to touch the task without first verifying +// that the value is on a thread where it is safe to poll the task. +unsafe impl<S: Schedule> Send for Notified<S> {} +unsafe impl<S: Schedule> Sync for Notified<S> {} + +/// A non-Send variant of Notified with the invariant that it is on a thread +/// where it is safe to poll it. +#[repr(transparent)] +pub(crate) struct LocalNotified<S: 'static> { + task: Task<S>, + _not_send: PhantomData<*const ()>, +} + +/// A task that is not owned by any OwnedTasks. Used for blocking tasks. +/// This type holds two ref-counts. +pub(crate) struct UnownedTask<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +// safety: This type can only be created given a Send task. +unsafe impl<S> Send for UnownedTask<S> {} +unsafe impl<S> Sync for UnownedTask<S> {} + +/// Task result sent back. +pub(crate) type Result<T> = std::result::Result<T, JoinError>; + +pub(crate) trait Schedule: Sync + Sized + 'static { + /// The task has completed work and is ready to be released. The scheduler + /// should release it immediately and return it. The task module will batch + /// the ref-dec with setting other options. + /// + /// If the scheduler has already released the task, then None is returned. + fn release(&self, task: &Task<Self>) -> Option<Task<Self>>; + + /// Schedule the task + fn schedule(&self, task: Notified<Self>); + + /// Schedule the task to run in the near future, yielding the thread to + /// other tasks. + fn yield_now(&self, task: Notified<Self>) { + self.schedule(task); + } +} + +cfg_rt! { + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task<T, S>( + task: T, + scheduler: S + ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let raw = RawTask::new::<T, S>(task, scheduler); + let task = Task { + raw, + _p: PhantomData, + }; + let notified = Notified(Task { + raw, + _p: PhantomData, + }); + let join = JoinHandle::new(raw); + + (task, notified, join) + } + + /// Creates a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks use this method. + pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) + where + S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = new_task(task, scheduler); + + // This transfers the ref-count of task and notified into an UnownedTask. + // This is valid because an UnownedTask holds two ref-counts. + let unowned = UnownedTask { + raw: task.raw, + _p: PhantomData, + }; + std::mem::forget(task); + std::mem::forget(notified); + + (unowned, join) + } +} + +impl<S: 'static> Task<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task { + raw: RawTask::from_raw(ptr), + _p: PhantomData, + } + } + + fn header(&self) -> &Header { + self.raw.header() + } +} + +impl<S: 'static> Notified<S> { + fn header(&self) -> &Header { + self.0.header() + } +} + +cfg_rt_multi_thread! { + impl<S: 'static> Notified<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { + Notified(Task::from_raw(ptr)) + } + } + + impl<S: 'static> Task<S> { + fn into_raw(self) -> NonNull<Header> { + let ret = self.raw.header_ptr(); + mem::forget(self); + ret + } + } + + impl<S: 'static> Notified<S> { + fn into_raw(self) -> NonNull<Header> { + self.0.into_raw() + } + } +} + +impl<S: Schedule> Task<S> { + /// Pre-emptively cancels the task as part of the shutdown process. + pub(crate) fn shutdown(self) { + let raw = self.raw; + mem::forget(self); + raw.shutdown(); + } +} + +impl<S: Schedule> LocalNotified<S> { + /// Runs the task. + pub(crate) fn run(self) { + let raw = self.task.raw; + mem::forget(self); + raw.poll(); + } +} + +impl<S: Schedule> UnownedTask<S> { + // Used in test of the inject queue. + #[cfg(test)] + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + pub(super) fn into_notified(self) -> Notified<S> { + Notified(self.into_task()) + } + + fn into_task(self) -> Task<S> { + // Convert into a task. + let task = Task { + raw: self.raw, + _p: PhantomData, + }; + mem::forget(self); + + // Drop a ref-count since an UnownedTask holds two. + task.header().state.ref_dec(); + + task + } + + pub(crate) fn run(self) { + let raw = self.raw; + mem::forget(self); + + // Transfer one ref-count to a Task object. + let task = Task::<S> { + raw, + _p: PhantomData, + }; + + // Use the other ref-count to poll the task. + raw.poll(); + // Decrement our extra ref-count + drop(task); + } + + pub(crate) fn shutdown(self) { + self.into_task().shutdown() + } +} + +impl<S: 'static> Drop for Task<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.header().state.ref_dec() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S: 'static> Drop for UnownedTask<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.raw.header().state.ref_dec_twice() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + +impl<S> fmt::Debug for Task<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "Task({:p})", self.header()) + } +} + +impl<S> fmt::Debug for Notified<S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "task::Notified({:p})", self.0.header()) + } +} + +/// # Safety +/// +/// Tasks are pinned. +unsafe impl<S> linked_list::Link for Task<S> { + type Handle = Task<S>; + type Target = Header; + + fn as_raw(handle: &Task<S>) -> NonNull<Header> { + handle.raw.header_ptr() + } + + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + Task::from_raw(ptr) + } + + unsafe fn pointers(target: NonNull<Header>) -> NonNull<linked_list::Pointers<Header>> { + // Not super great as it avoids some of looms checking... + NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) + } +} diff --git a/third_party/rust/tokio/src/runtime/task/raw.rs b/third_party/rust/tokio/src/runtime/task/raw.rs new file mode 100644 index 0000000000..2e4420b5c1 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/raw.rs @@ -0,0 +1,165 @@ +use crate::future::Future; +use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; + +use std::ptr::NonNull; +use std::task::{Poll, Waker}; + +/// Raw task handle +pub(super) struct RawTask { + ptr: NonNull<Header>, +} + +pub(super) struct Vtable { + /// Polls the future. + pub(super) poll: unsafe fn(NonNull<Header>), + + /// Deallocates the memory. + pub(super) dealloc: unsafe fn(NonNull<Header>), + + /// Reads the task output, if complete. + pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker), + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool, + + /// The join handle has been dropped. + pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), + + /// The task is remotely aborted. + pub(super) remote_abort: unsafe fn(NonNull<Header>), + + /// Scheduler is being shutdown. + pub(super) shutdown: unsafe fn(NonNull<Header>), +} + +/// Get the vtable for the requested `T` and `S` generics. +pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { + &Vtable { + poll: poll::<T, S>, + dealloc: dealloc::<T, S>, + try_read_output: try_read_output::<T, S>, + try_set_join_waker: try_set_join_waker::<T, S>, + drop_join_handle_slow: drop_join_handle_slow::<T, S>, + remote_abort: remote_abort::<T, S>, + shutdown: shutdown::<T, S>, + } +} + +impl RawTask { + pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask + where + T: Future, + S: Schedule, + { + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; + + RawTask { ptr } + } + + pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask { + RawTask { ptr } + } + + pub(super) fn header_ptr(&self) -> NonNull<Header> { + self.ptr + } + + /// Returns a reference to the task's meta structure. + /// + /// Safe as `Header` is `Sync`. + pub(super) fn header(&self) -> &Header { + unsafe { self.ptr.as_ref() } + } + + /// Safety: mutual exclusion is required to call this function. + pub(super) fn poll(self) { + let vtable = self.header().vtable; + unsafe { (vtable.poll)(self.ptr) } + } + + pub(super) fn dealloc(self) { + let vtable = self.header().vtable; + unsafe { + (vtable.dealloc)(self.ptr); + } + } + + /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T` + /// is the future stored by the task. + pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) { + let vtable = self.header().vtable; + (vtable.try_read_output)(self.ptr, dst, waker); + } + + pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool { + let vtable = self.header().vtable; + unsafe { (vtable.try_set_join_waker)(self.ptr, waker) } + } + + pub(super) fn drop_join_handle_slow(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_join_handle_slow)(self.ptr) } + } + + pub(super) fn shutdown(self) { + let vtable = self.header().vtable; + unsafe { (vtable.shutdown)(self.ptr) } + } + + pub(super) fn remote_abort(self) { + let vtable = self.header().vtable; + unsafe { (vtable.remote_abort)(self.ptr) } + } +} + +impl Clone for RawTask { + fn clone(&self) -> Self { + RawTask { ptr: self.ptr } + } +} + +impl Copy for RawTask {} + +unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.poll(); +} + +unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.dealloc(); +} + +unsafe fn try_read_output<T: Future, S: Schedule>( + ptr: NonNull<Header>, + dst: *mut (), + waker: &Waker, +) { + let out = &mut *(dst as *mut Poll<super::Result<T::Output>>); + + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_read_output(out, waker); +} + +unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool { + let harness = Harness::<T, S>::from_raw(ptr); + harness.try_set_join_waker(waker) +} + +unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.drop_join_handle_slow() +} + +unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.remote_abort() +} + +unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.shutdown() +} diff --git a/third_party/rust/tokio/src/runtime/task/state.rs b/third_party/rust/tokio/src/runtime/task/state.rs new file mode 100644 index 0000000000..c2d5b28eac --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/state.rs @@ -0,0 +1,595 @@ +use crate::loom::sync::atomic::AtomicUsize; + +use std::fmt; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::usize; + +pub(super) struct State { + val: AtomicUsize, +} + +/// Current state value. +#[derive(Copy, Clone)] +pub(super) struct Snapshot(usize); + +type UpdateResult = Result<Snapshot, Snapshot>; + +/// The task is currently being run. +const RUNNING: usize = 0b0001; + +/// The task is complete. +/// +/// Once this bit is set, it is never unset. +const COMPLETE: usize = 0b0010; + +/// Extracts the task's lifecycle value from the state. +const LIFECYCLE_MASK: usize = 0b11; + +/// Flag tracking if the task has been pushed into a run queue. +const NOTIFIED: usize = 0b100; + +/// The join handle is still around. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_INTEREST: usize = 0b1_000; + +/// A join handle waker has been set. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const JOIN_WAKER: usize = 0b10_000; + +/// The task has been forcibly cancelled. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 +const CANCELLED: usize = 0b100_000; + +/// All bits. +const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; + +/// Bits used by the ref count portion of the state. +const REF_COUNT_MASK: usize = !STATE_MASK; + +/// Number of positions to shift the ref count. +const REF_COUNT_SHIFT: usize = REF_COUNT_MASK.count_zeros() as usize; + +/// One ref count. +const REF_ONE: usize = 1 << REF_COUNT_SHIFT; + +/// State a task is initialized with. +/// +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; + +#[must_use] +pub(super) enum TransitionToRunning { + Success, + Cancelled, + Failed, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToIdle { + Ok, + OkNotified, + OkDealloc, + Cancelled, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByVal { + DoNothing, + Submit, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByRef { + DoNothing, + Submit, +} + +/// All transitions are performed via RMW operations. This establishes an +/// unambiguous modification order. +impl State { + /// Returns a task's initial state. + pub(super) fn new() -> State { + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. + State { + val: AtomicUsize::new(INITIAL_STATE), + } + } + + /// Loads the current state, establishes `Acquire` ordering. + pub(super) fn load(&self) -> Snapshot { + Snapshot(self.val.load(Acquire)) + } + + /// Attempts to transition the lifecycle to `Running`. This sets the + /// notified bit to false so notifications during the poll can be detected. + pub(super) fn transition_to_running(&self) -> TransitionToRunning { + self.fetch_update_action(|mut next| { + let action; + assert!(next.is_notified()); + + if !next.is_idle() { + // This happens if the task is either currently running or if it + // has already completed, e.g. if it was cancelled during + // shutdown. Consume the ref-count and return. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToRunning::Dealloc; + } else { + action = TransitionToRunning::Failed; + } + } else { + // We are able to lock the RUNNING bit. + next.set_running(); + next.unset_notified(); + + if next.is_cancelled() { + action = TransitionToRunning::Cancelled; + } else { + action = TransitionToRunning::Success; + } + } + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Idle`. + /// + /// Returns `true` if the transition to `Idle` is successful, `false` otherwise. + /// The transition to `Idle` fails if the task has been flagged to be + /// cancelled. + pub(super) fn transition_to_idle(&self) -> TransitionToIdle { + self.fetch_update_action(|curr| { + assert!(curr.is_running()); + + if curr.is_cancelled() { + return (TransitionToIdle::Cancelled, None); + } + + let mut next = curr; + let action; + next.unset_running(); + + if !next.is_notified() { + // Polling the future consumes the ref-count of the Notified. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToIdle::OkDealloc; + } else { + action = TransitionToIdle::Ok; + } + } else { + // The caller will schedule a new notification, so we create a + // new ref-count for the notification. Our own ref-count is kept + // for now, and the caller will drop it shortly. + next.ref_inc(); + action = TransitionToIdle::OkNotified; + } + + (action, Some(next)) + }) + } + + /// Transitions the task from `Running` -> `Complete`. + pub(super) fn transition_to_complete(&self) -> Snapshot { + const DELTA: usize = RUNNING | COMPLETE; + + let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel)); + assert!(prev.is_running()); + assert!(!prev.is_complete()); + + Snapshot(prev.0 ^ DELTA) + } + + /// Transitions from `Complete` -> `Terminal`, decrementing the reference + /// count the specified number of times. + /// + /// Returns true if the task should be deallocated. + pub(super) fn transition_to_terminal(&self, count: usize) -> bool { + let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); + assert!( + prev.ref_count() >= count, + "current: {}, sub: {}", + prev.ref_count(), + count + ); + prev.ref_count() == count + } + + /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. + pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { + self.fetch_update_action(|mut snapshot| { + let action; + + if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit anything as the thread currently running the + // future is responsible for that. + snapshot.set_notified(); + snapshot.ref_dec(); + + // The thread that set the running bit also holds a ref-count. + assert!(snapshot.ref_count() > 0); + + action = TransitionToNotifiedByVal::DoNothing; + } else if snapshot.is_complete() || snapshot.is_notified() { + // We do not need to submit any notifications, but we have to + // decrement the ref-count. + snapshot.ref_dec(); + + if snapshot.ref_count() == 0 { + action = TransitionToNotifiedByVal::Dealloc; + } else { + action = TransitionToNotifiedByVal::DoNothing; + } + } else { + // We create a new notified that we can submit. The caller + // retains ownership of the ref-count they passed in. + snapshot.set_notified(); + snapshot.ref_inc(); + action = TransitionToNotifiedByVal::Submit; + } + + (action, Some(snapshot)) + }) + } + + /// Transitions the state to `NOTIFIED`. + pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_complete() || snapshot.is_notified() { + // There is nothing to do in this case. + (TransitionToNotifiedByRef::DoNothing, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit as the thread currently running the future is + // responsible for that. + snapshot.set_notified(); + (TransitionToNotifiedByRef::DoNothing, Some(snapshot)) + } else { + // The task is idle and not notified. We should submit a + // notification. + snapshot.set_notified(); + snapshot.ref_inc(); + (TransitionToNotifiedByRef::Submit, Some(snapshot)) + } + }) + } + + /// Sets the cancelled bit and transitions the state to `NOTIFIED` if idle. + /// + /// Returns `true` if the task needs to be submitted to the pool for + /// execution. + pub(super) fn transition_to_notified_and_cancel(&self) -> bool { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_cancelled() || snapshot.is_complete() { + // Aborts to completed or cancelled tasks are no-ops. + (false, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as cancelled. The thread + // running the task will notice the cancelled bit when it + // stops polling and it will kill the task. + // + // The set_notified() call is not strictly necessary but it will + // in some cases let a wake_by_ref call return without having + // to perform a compare_exchange. + snapshot.set_notified(); + snapshot.set_cancelled(); + (false, Some(snapshot)) + } else { + // The task is idle. We set the cancelled and notified bits and + // submit a notification if the notified bit was not already + // set. + snapshot.set_cancelled(); + if !snapshot.is_notified() { + snapshot.set_notified(); + snapshot.ref_inc(); + (true, Some(snapshot)) + } else { + (false, Some(snapshot)) + } + } + }) + } + + /// Sets the `CANCELLED` bit and attempts to transition to `Running`. + /// + /// Returns `true` if the transition to `Running` succeeded. + pub(super) fn transition_to_shutdown(&self) -> bool { + let mut prev = Snapshot(0); + + let _ = self.fetch_update(|mut snapshot| { + prev = snapshot; + + if snapshot.is_idle() { + snapshot.set_running(); + } + + // If the task was not idle, the thread currently running the task + // will notice the cancelled bit and cancel it once the poll + // completes. + snapshot.set_cancelled(); + Some(snapshot) + }); + + prev.is_idle() + } + + /// Optimistically tries to swap the state assuming the join handle is + /// __immediately__ dropped on spawn. + pub(super) fn drop_join_handle_fast(&self) -> Result<(), ()> { + use std::sync::atomic::Ordering::Relaxed; + + // Relaxed is acceptable as if this function is called and succeeds, + // then nothing has been done w/ the join handle. + // + // The moment the join handle is used (polled), the `JOIN_WAKER` flag is + // set, at which point the CAS will fail. + // + // Given this, there is no risk if this operation is reordered. + self.val + .compare_exchange_weak( + INITIAL_STATE, + (INITIAL_STATE - REF_ONE) & !JOIN_INTEREST, + Release, + Relaxed, + ) + .map(|_| ()) + .map_err(|_| ()) + } + + /// Tries to unset the JOIN_INTEREST flag. + /// + /// Returns `Ok` if the operation happens before the task transitions to a + /// completed state, `Err` otherwise. + pub(super) fn unset_join_interested(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_interested(); + + Some(next) + }) + } + + /// Sets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn set_join_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(!curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.set_join_waker(); + + Some(next) + }) + } + + /// Unsets the `JOIN_WAKER` bit. + /// + /// Returns `Ok` has been unset, `Err` otherwise. This operation fails if + /// the task has completed. + pub(super) fn unset_waker(&self) -> UpdateResult { + self.fetch_update(|curr| { + assert!(curr.is_join_interested()); + assert!(curr.has_join_waker()); + + if curr.is_complete() { + return None; + } + + let mut next = curr; + next.unset_join_waker(); + + Some(next) + }) + } + + pub(super) fn ref_inc(&self) { + use std::process; + use std::sync::atomic::Ordering::Relaxed; + + // Using a relaxed ordering is alright here, as knowledge of the + // original reference prevents other threads from erroneously deleting + // the object. + // + // As explained in the [Boost documentation][1], Increasing the + // reference counter can always be done with memory_order_relaxed: New + // references to an object can only be formed from an existing + // reference, and passing an existing reference from one thread to + // another must already provide any required synchronization. + // + // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html) + let prev = self.val.fetch_add(REF_ONE, Relaxed); + + // If the reference count overflowed, abort. + if prev > isize::MAX as usize { + process::abort(); + } + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 1); + prev.ref_count() == 1 + } + + /// Returns `true` if the task should be released. + pub(super) fn ref_dec_twice(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 2); + prev.ref_count() == 2 + } + + fn fetch_update_action<F, T>(&self, mut f: F) -> T + where + F: FnMut(Snapshot) -> (T, Option<Snapshot>), + { + let mut curr = self.load(); + + loop { + let (output, next) = f(curr); + let next = match next { + Some(next) => next, + None => return output, + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return output, + Err(actual) => curr = Snapshot(actual), + } + } + } + + fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot> + where + F: FnMut(Snapshot) -> Option<Snapshot>, + { + let mut curr = self.load(); + + loop { + let next = match f(curr) { + Some(next) => next, + None => return Err(curr), + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return Ok(next), + Err(actual) => curr = Snapshot(actual), + } + } + } +} + +// ===== impl Snapshot ===== + +impl Snapshot { + /// Returns `true` if the task is in an idle state. + pub(super) fn is_idle(self) -> bool { + self.0 & (RUNNING | COMPLETE) == 0 + } + + /// Returns `true` if the task has been flagged as notified. + pub(super) fn is_notified(self) -> bool { + self.0 & NOTIFIED == NOTIFIED + } + + fn unset_notified(&mut self) { + self.0 &= !NOTIFIED + } + + fn set_notified(&mut self) { + self.0 |= NOTIFIED + } + + pub(super) fn is_running(self) -> bool { + self.0 & RUNNING == RUNNING + } + + fn set_running(&mut self) { + self.0 |= RUNNING; + } + + fn unset_running(&mut self) { + self.0 &= !RUNNING; + } + + pub(super) fn is_cancelled(self) -> bool { + self.0 & CANCELLED == CANCELLED + } + + fn set_cancelled(&mut self) { + self.0 |= CANCELLED; + } + + /// Returns `true` if the task's future has completed execution. + pub(super) fn is_complete(self) -> bool { + self.0 & COMPLETE == COMPLETE + } + + pub(super) fn is_join_interested(self) -> bool { + self.0 & JOIN_INTEREST == JOIN_INTEREST + } + + fn unset_join_interested(&mut self) { + self.0 &= !JOIN_INTEREST + } + + pub(super) fn has_join_waker(self) -> bool { + self.0 & JOIN_WAKER == JOIN_WAKER + } + + fn set_join_waker(&mut self) { + self.0 |= JOIN_WAKER; + } + + fn unset_join_waker(&mut self) { + self.0 &= !JOIN_WAKER + } + + pub(super) fn ref_count(self) -> usize { + (self.0 & REF_COUNT_MASK) >> REF_COUNT_SHIFT + } + + fn ref_inc(&mut self) { + assert!(self.0 <= isize::MAX as usize); + self.0 += REF_ONE; + } + + pub(super) fn ref_dec(&mut self) { + assert!(self.ref_count() > 0); + self.0 -= REF_ONE + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let snapshot = self.load(); + snapshot.fmt(fmt) + } +} + +impl fmt::Debug for Snapshot { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Snapshot") + .field("is_running", &self.is_running()) + .field("is_complete", &self.is_complete()) + .field("is_notified", &self.is_notified()) + .field("is_cancelled", &self.is_cancelled()) + .field("is_join_interested", &self.is_join_interested()) + .field("has_join_waker", &self.has_join_waker()) + .field("ref_count", &self.ref_count()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/task/waker.rs b/third_party/rust/tokio/src/runtime/task/waker.rs new file mode 100644 index 0000000000..74a29f4a84 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/task/waker.rs @@ -0,0 +1,130 @@ +use crate::future::Future; +use crate::runtime::task::harness::Harness; +use crate::runtime::task::{Header, Schedule}; + +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops; +use std::ptr::NonNull; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +pub(super) struct WakerRef<'a, S: 'static> { + waker: ManuallyDrop<Waker>, + _p: PhantomData<(&'a Header, S)>, +} + +/// Returns a `WakerRef` which avoids having to pre-emptively increase the +/// refcount if there is no need to do so. +pub(super) fn waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S> +where + T: Future, + S: Schedule, +{ + // `Waker::will_wake` uses the VTABLE pointer as part of the check. This + // means that `will_wake` will always return false when using the current + // task's waker. (discussion at rust-lang/rust#66281). + // + // To fix this, we use a single vtable. Since we pass in a reference at this + // point and not an *owned* waker, we must ensure that `drop` is never + // called on this waker instance. This is done by wrapping it with + // `ManuallyDrop` and then never calling drop. + let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T, S>(*header))) }; + + WakerRef { + waker, + _p: PhantomData, + } +} + +impl<S> ops::Deref for WakerRef<'_, S> { + type Target = Waker; + + fn deref(&self) -> &Waker { + &self.waker + } +} + +cfg_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + if let Some(id) = $harness.id() { + tracing::trace!( + target: "tokio::task::waker", + op = $op, + task.id = id.into_u64(), + ); + } + } + } +} + +cfg_not_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + // noop + let _ = &$harness; + } + } +} + +unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker +where + T: Future, + S: Schedule, +{ + let header = ptr as *const Header; + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.clone"); + (*header).state.ref_inc(); + raw_waker::<T, S>(ptr) +} + +unsafe fn drop_waker<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.drop"); + harness.drop_reference(); +} + +unsafe fn wake_by_val<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake"); + harness.wake_by_val(); +} + +// Wake without consuming the waker +unsafe fn wake_by_ref<T, S>(ptr: *const ()) +where + T: Future, + S: Schedule, +{ + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake_by_ref"); + harness.wake_by_ref(); +} + +fn raw_waker<T, S>(header: NonNull<Header>) -> RawWaker +where + T: Future, + S: Schedule, +{ + let ptr = header.as_ptr() as *const (); + let vtable = &RawWakerVTable::new( + clone_waker::<T, S>, + wake_by_val::<T, S>, + wake_by_ref::<T, S>, + drop_waker::<T, S>, + ); + RawWaker::new(ptr, vtable) +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs new file mode 100644 index 0000000000..a772603f71 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_basic_scheduler.rs @@ -0,0 +1,142 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::loom::thread; +use crate::runtime::{Builder, Runtime}; +use crate::sync::oneshot::{self, Receiver}; +use crate::task; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::task::{Context, Poll}; + +fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) { + let (tx, rx) = oneshot::channel(); + let num_polls = Arc::new(AtomicUsize::new(0)); + rt.spawn(async move { + for _ in 0..12 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + rt.block_on(async { + BlockedFuture { + rx, + num_polls: num_polls.clone(), + } + .await; + }); + + let polls = num_polls.load(Acquire); + assert!(polls <= at_most_polls); +} + +#[test] +fn block_on_num_polls() { + loom::model(|| { + // we expect at most 4 number of polls because there are three points at + // which we poll the future and an opportunity for a false-positive.. At + // any of these points it can be ready: + // + // - when we fail to steal the parker and we block on a notification + // that it is available. + // + // - when we steal the parker and we schedule the future + // + // - when the future is woken up and we have ran the max number of tasks + // for the current tick or there are no more tasks to run. + // + // - a thread is notified that the parker is available but a third + // thread acquires it before the notified thread can. + // + let at_most = 4; + + let rt1 = Arc::new(Builder::new_current_thread().build().unwrap()); + let rt2 = rt1.clone(); + let rt3 = rt1.clone(); + + let th1 = thread::spawn(move || assert_at_most_num_polls(rt1, at_most)); + let th2 = thread::spawn(move || assert_at_most_num_polls(rt2, at_most)); + let th3 = thread::spawn(move || assert_at_most_num_polls(rt3, at_most)); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); + }); +} + +#[test] +fn assert_no_unnecessary_polls() { + loom::model(|| { + // // After we poll outer future, woken should reset to false + let rt = Builder::new_current_thread().build().unwrap(); + let (tx, rx) = oneshot::channel(); + let pending_cnt = Arc::new(AtomicUsize::new(0)); + + rt.spawn(async move { + for _ in 0..24 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + let pending_cnt_clone = pending_cnt.clone(); + rt.block_on(async move { + // use task::yield_now() to ensure woken set to true + // ResetFuture will be polled at most once + // Here comes two cases + // 1. recv no message from channel, ResetFuture will be polled + // but get Pending and we record ResetFuture.pending_cnt ++. + // Then when message arrive, ResetFuture returns Ready. So we + // expect ResetFuture.pending_cnt = 1 + // 2. recv message from channel, ResetFuture returns Ready immediately. + // We expect ResetFuture.pending_cnt = 0 + task::yield_now().await; + ResetFuture { + rx, + pending_cnt: pending_cnt_clone, + } + .await; + }); + + let pending_cnt = pending_cnt.load(Acquire); + assert!(pending_cnt <= 1); + }); +} + +struct BlockedFuture { + rx: Receiver<()>, + num_polls: Arc<AtomicUsize>, +} + +impl Future for BlockedFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.num_polls.fetch_add(1, Release); + + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => Poll::Pending, + _ => Poll::Ready(()), + } + } +} + +struct ResetFuture { + rx: Receiver<()>, + pending_cnt: Arc<AtomicUsize>, +} + +impl Future for ResetFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => { + self.pending_cnt.fetch_add(1, Release); + Poll::Pending + } + _ => Poll::Ready(()), + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs new file mode 100644 index 0000000000..89de85e436 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_blocking.rs @@ -0,0 +1,81 @@ +use crate::runtime::{self, Runtime}; + +use std::sync::Arc; + +#[test] +fn blocking_shutdown() { + loom::model(|| { + let v = Arc::new(()); + + let rt = mk_runtime(1); + { + let _enter = rt.enter(); + for _ in 0..2 { + let v = v.clone(); + crate::task::spawn_blocking(move || { + assert!(1 < Arc::strong_count(&v)); + }); + } + } + + drop(rt); + assert_eq!(1, Arc::strong_count(&v)); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_always_run() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + + let (tx, rx) = loom_oneshot::channel(); + let _enter = rt.enter(); + runtime::spawn_blocking(|| {}); + runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }) + .unwrap(); + + drop(rt); + + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + }); +} + +#[test] +fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() { + use crate::runtime::tests::loom_oneshot; + loom::model(|| { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + let handle = rt.handle().clone(); + + // Drop the runtime in a different thread + { + loom::thread::spawn(move || { + drop(rt); + }); + } + + let _enter = handle.enter(); + let (tx, rx) = loom_oneshot::channel(); + let handle = runtime::spawn_mandatory_blocking(move || { + let _ = tx.send(()); + }); + + // handle.is_some() means that `spawn_mandatory_blocking` + // promised us to run the blocking task + if handle.is_some() { + // This call will deadlock if `spawn_mandatory_blocking` doesn't run. + let () = rx.recv(); + } + }); +} + +fn mk_runtime(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs new file mode 100644 index 0000000000..e87ddb0140 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_join_set.rs @@ -0,0 +1,82 @@ +use crate::runtime::Builder; +use crate::task::JoinSet; + +#[test] +fn test_join_set() { + loom::model(|| { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + let mut set = JoinSet::new(); + + rt.block_on(async { + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + set.spawn(async { () }); + assert_eq!(set.len(), 2); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 1); + let () = set.join_one().await.unwrap().unwrap(); + assert_eq!(set.len(), 0); + set.spawn(async { () }); + assert_eq!(set.len(), 1); + }); + + drop(set); + drop(rt); + }); +} + +#[test] +fn abort_all_during_completion() { + use std::sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }; + + // These booleans assert that at least one execution had the task complete first, and that at + // least one execution had the task be cancelled before it completed. + let complete_happened = Arc::new(AtomicBool::new(false)); + let cancel_happened = Arc::new(AtomicBool::new(false)); + + { + let complete_happened = complete_happened.clone(); + let cancel_happened = cancel_happened.clone(); + loom::model(move || { + let rt = Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(); + + let mut set = JoinSet::new(); + + rt.block_on(async { + set.spawn(async { () }); + set.abort_all(); + + match set.join_one().await { + Ok(Some(())) => complete_happened.store(true, SeqCst), + Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst), + Err(err) => panic!("fail: {}", err), + Ok(None) => { + unreachable!("Aborting the task does not remove it from the JoinSet.") + } + } + + assert!(matches!(set.join_one().await, Ok(None))); + }); + + drop(set); + drop(rt); + }); + } + + assert!(complete_happened.load(SeqCst)); + assert!(cancel_happened.load(SeqCst)); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_local.rs b/third_party/rust/tokio/src/runtime/tests/loom_local.rs new file mode 100644 index 0000000000..d9a07a45f0 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. +#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs new file mode 100644 index 0000000000..87eb638642 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_oneshot.rs @@ -0,0 +1,48 @@ +use crate::loom::sync::{Arc, Mutex}; +use loom::sync::Notify; + +pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) { + let inner = Arc::new(Inner { + notify: Notify::new(), + value: Mutex::new(None), + }); + + let tx = Sender { + inner: inner.clone(), + }; + let rx = Receiver { inner }; + + (tx, rx) +} + +pub(crate) struct Sender<T> { + inner: Arc<Inner<T>>, +} + +pub(crate) struct Receiver<T> { + inner: Arc<Inner<T>>, +} + +struct Inner<T> { + notify: Notify, + value: Mutex<Option<T>>, +} + +impl<T> Sender<T> { + pub(crate) fn send(self, value: T) { + *self.inner.value.lock() = Some(value); + self.inner.notify.notify(); + } +} + +impl<T> Receiver<T> { + pub(crate) fn recv(self) -> T { + loop { + if let Some(v) = self.inner.value.lock().take() { + return v; + } + + self.inner.notify.wait(); + } + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_pool.rs b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs new file mode 100644 index 0000000000..b3ecd43124 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_pool.rs @@ -0,0 +1,430 @@ +/// Full runtime loom tests. These are heavy tests and take significant time to +/// run on CI. +/// +/// Use `LOOM_MAX_PREEMPTIONS=1` to do a "quick" run as a smoke test. +/// +/// In order to speed up the C +use crate::future::poll_fn; +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::{self, Runtime}; +use crate::{spawn, task}; +use tokio_test::assert_ok; + +use loom::sync::atomic::{AtomicBool, AtomicUsize}; +use loom::sync::Arc; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::{Relaxed, SeqCst}; +use std::task::{Context, Poll}; + +mod atomic_take { + use loom::sync::atomic::AtomicBool; + use std::mem::MaybeUninit; + use std::sync::atomic::Ordering::SeqCst; + + pub(super) struct AtomicTake<T> { + inner: MaybeUninit<T>, + taken: AtomicBool, + } + + impl<T> AtomicTake<T> { + pub(super) fn new(value: T) -> Self { + Self { + inner: MaybeUninit::new(value), + taken: AtomicBool::new(false), + } + } + + pub(super) fn take(&self) -> Option<T> { + // safety: Only one thread will see the boolean change from false + // to true, so that thread is able to take the value. + match self.taken.fetch_or(true, SeqCst) { + false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) }, + true => None, + } + } + } + + impl<T> Drop for AtomicTake<T> { + fn drop(&mut self) { + drop(self.take()); + } + } +} + +#[derive(Clone)] +struct AtomicOneshot<T> { + value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>, +} +impl<T> AtomicOneshot<T> { + fn new(sender: oneshot::Sender<T>) -> Self { + Self { + value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)), + } + } + + fn assert_send(&self, value: T) { + self.value.take().unwrap().send(value); + } +} + +/// Tests are divided into groups to make the runs faster on CI. +mod group_a { + use super::*; + + #[test] + fn racy_shutdown() { + loom::model(|| { + let pool = mk_pool(1); + + // here's the case we want to exercise: + // + // a worker that still has tasks in its local queue gets sent to the blocking pool (due to + // block_in_place). the blocking pool is shut down, so drops the worker. the worker's + // shutdown method never gets run. + // + // we do this by spawning two tasks on one worker, the first of which does block_in_place, + // and then immediately drop the pool. + + pool.spawn(track(async { + crate::task::block_in_place(|| {}); + })); + pool.spawn(track(async {})); + drop(pool); + }); + } + + #[test] + fn pool_multi_spawn() { + loom::model(|| { + let pool = mk_pool(2); + let c1 = Arc::new(AtomicUsize::new(0)); + + let (tx, rx) = oneshot::channel(); + let tx1 = AtomicOneshot::new(tx); + + // Spawn a task + let c2 = c1.clone(); + let tx2 = tx1.clone(); + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c1.fetch_add(1, Relaxed) { + tx1.assert_send(()); + } + })); + })); + + // Spawn a second task + pool.spawn(track(async move { + spawn(track(async move { + if 1 == c2.fetch_add(1, Relaxed) { + tx2.assert_send(()); + } + })); + })); + + rx.recv(); + }); + } + + fn only_blocking_inner(first_pending: bool) { + loom::model(move || { + let pool = mk_pool(1); + let (block_tx, block_rx) = oneshot::channel(); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + block_rx.recv(); + drop(pool); + }); + } + + #[test] + fn only_blocking_without_pending() { + only_blocking_inner(false) + } + + #[test] + fn only_blocking_with_pending() { + only_blocking_inner(true) + } +} + +mod group_b { + use super::*; + + fn blocking_and_regular_inner(first_pending: bool) { + const NUM: usize = 3; + loom::model(move || { + let pool = mk_pool(1); + let cnt = Arc::new(AtomicUsize::new(0)); + + let (block_tx, block_rx) = oneshot::channel(); + let (done_tx, done_rx) = oneshot::channel(); + let done_tx = AtomicOneshot::new(done_tx); + + pool.spawn(track(async move { + crate::task::block_in_place(move || { + block_tx.send(()); + }); + if first_pending { + task::yield_now().await + } + })); + + for _ in 0..NUM { + let cnt = cnt.clone(); + let done_tx = done_tx.clone(); + + pool.spawn(track(async move { + if NUM == cnt.fetch_add(1, Relaxed) + 1 { + done_tx.assert_send(()); + } + })); + } + + done_rx.recv(); + block_rx.recv(); + + drop(pool); + }); + } + + #[test] + fn blocking_and_regular() { + blocking_and_regular_inner(false); + } + + #[test] + fn blocking_and_regular_with_pending() { + blocking_and_regular_inner(true); + } + + #[test] + fn join_output() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async { + let t = crate::spawn(track(async { "hello" })); + + let out = assert_ok!(t.await); + assert_eq!("hello", out.into_inner()); + }); + }); + } + + #[test] + fn poll_drop_handle_then_drop() { + loom::model(|| { + let rt = mk_pool(1); + + rt.block_on(async move { + let mut t = crate::spawn(track(async { "hello" })); + + poll_fn(|cx| { + let _ = Pin::new(&mut t).poll(cx); + Poll::Ready(()) + }) + .await; + }); + }) + } + + #[test] + fn complete_block_on_under_load() { + loom::model(|| { + let pool = mk_pool(1); + + pool.block_on(async { + // Trigger a re-schedule + crate::spawn(track(async { + for _ in 0..2 { + task::yield_now().await; + } + })); + + gated2(true).await + }); + }); + } + + #[test] + fn shutdown_with_notification() { + use crate::sync::oneshot; + + loom::model(|| { + let rt = mk_pool(2); + let (done_tx, done_rx) = oneshot::channel::<()>(); + + rt.spawn(track(async move { + let (tx, rx) = oneshot::channel::<()>(); + + crate::spawn(async move { + crate::task::spawn_blocking(move || { + let _ = tx.send(()); + }); + + let _ = done_rx.await; + }); + + let _ = rx.await; + + let _ = done_tx.send(()); + })); + }); + } +} + +mod group_c { + use super::*; + + #[test] + fn pool_shutdown() { + loom::model(|| { + let pool = mk_pool(2); + + pool.spawn(track(async move { + gated2(true).await; + })); + + pool.spawn(track(async move { + gated2(false).await; + })); + + drop(pool); + }); + } +} + +mod group_d { + use super::*; + + #[test] + fn pool_multi_notify() { + loom::model(|| { + let pool = mk_pool(2); + + let c1 = Arc::new(AtomicUsize::new(0)); + + let (done_tx, done_rx) = oneshot::channel(); + let done_tx1 = AtomicOneshot::new(done_tx); + let done_tx2 = done_tx1.clone(); + + // Spawn a task + let c2 = c1.clone(); + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c1.fetch_add(1, Relaxed) { + done_tx1.assert_send(()); + } + })); + + // Spawn a second task + pool.spawn(track(async move { + gated().await; + gated().await; + + if 1 == c2.fetch_add(1, Relaxed) { + done_tx2.assert_send(()); + } + })); + + done_rx.recv(); + }); + } +} + +fn mk_pool(num_threads: usize) -> Runtime { + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build() + .unwrap() +} + +fn gated() -> impl Future<Output = &'static str> { + gated2(false) +} + +fn gated2(thread: bool) -> impl Future<Output = &'static str> { + use loom::thread; + use std::sync::Arc; + + let gate = Arc::new(AtomicBool::new(false)); + let mut fired = false; + + poll_fn(move |cx| { + if !fired { + let gate = gate.clone(); + let waker = cx.waker().clone(); + + if thread { + thread::spawn(move || { + gate.store(true, SeqCst); + waker.wake_by_ref(); + }); + } else { + spawn(track(async move { + gate.store(true, SeqCst); + waker.wake_by_ref(); + })); + } + + fired = true; + + return Poll::Pending; + } + + if gate.load(SeqCst) { + Poll::Ready("hello world") + } else { + Poll::Pending + } + }) +} + +fn track<T: Future>(f: T) -> Track<T> { + Track { + inner: f, + arc: Arc::new(()), + } +} + +pin_project! { + struct Track<T> { + #[pin] + inner: T, + // Arc is used to hook into loom's leak tracking. + arc: Arc<()>, + } +} + +impl<T> Track<T> { + fn into_inner(self) -> T { + self.inner + } +} + +impl<T: Future> Future for Track<T> { + type Output = Track<T::Output>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + Poll::Ready(Track { + inner: ready!(me.inner.poll(cx)), + arc: me.arc.clone(), + }) + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_queue.rs b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs new file mode 100644 index 0000000000..b5f78d7ebe --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_queue.rs @@ -0,0 +1,208 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::Inject; +use crate::runtime::{queue, MetricsBatch}; + +use loom::thread; + +#[test] +fn basic() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..3 { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + } + + n + }); + + let mut n = 0; + + for _ in 0..2 { + for _ in 0..2 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + if local.pop().is_some() { + n += 1; + } + + // Push another task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + while local.pop().is_some() { + n += 1; + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(6, n); + }); +} + +#[test] +fn steal_overflow() { + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + n + }); + + let mut n = 0; + + // push a task, pop a task + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if local.pop().is_some() { + n += 1; + } + + for _ in 0..6 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + n += th.join().unwrap(); + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + assert_eq!(7, n); + }); +} + +#[test] +fn multi_stealer() { + const NUM_TASKS: usize = 5; + + fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + + if steal.steal_into(&mut local, &mut metrics).is_none() { + return 0; + } + + let mut n = 1; + + while local.pop().is_some() { + n += 1; + } + + n + } + + loom::model(|| { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + // Push work + for _ in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + let th1 = { + let steal = steal.clone(); + thread::spawn(move || steal_tasks(steal)) + }; + + let th2 = thread::spawn(move || steal_tasks(steal)); + + let mut n = 0; + + while local.pop().is_some() { + n += 1; + } + + while inject.pop().is_some() { + n += 1; + } + + n += th1.join().unwrap(); + n += th2.join().unwrap(); + + assert_eq!(n, NUM_TASKS); + }); +} + +#[test] +fn chained_steal() { + loom::model(|| { + let mut metrics = MetricsBatch::new(); + let (s1, mut l1) = queue::local(); + let (s2, mut l2) = queue::local(); + let inject = Inject::new(); + + // Load up some tasks + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + l1.push_back(task, &inject, &mut metrics); + + let (task, _) = super::unowned(async {}); + l2.push_back(task, &inject, &mut metrics); + } + + // Spawn a task to steal from **our** queue + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + s1.steal_into(&mut local, &mut metrics); + + while local.pop().is_some() {} + }); + + // Drain our tasks, then attempt to steal + while l1.pop().is_some() {} + + s2.steal_into(&mut l1, &mut metrics); + + th.join().unwrap(); + + while l1.pop().is_some() {} + while l2.pop().is_some() {} + while inject.pop().is_some() {} + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs new file mode 100644 index 0000000000..6fbc4bfded --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/loom_shutdown_join.rs @@ -0,0 +1,28 @@ +use crate::runtime::{Builder, Handle}; + +#[test] +fn join_handle_cancel_on_shutdown() { + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + builder.check(|| { + use futures::future::FutureExt; + + let rt = Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(); + + let handle = rt.block_on(async move { Handle::current() }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + }); +} diff --git a/third_party/rust/tokio/src/runtime/tests/mod.rs b/third_party/rust/tokio/src/runtime/tests/mod.rs new file mode 100644 index 0000000000..4b49698a86 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/mod.rs @@ -0,0 +1,50 @@ +use self::unowned_wrapper::unowned; + +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + use tracing::Instrument; + let span = tracing::trace_span!("test_span"); + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } +} + +cfg_loom! { + mod loom_basic_scheduler; + mod loom_blocking; + mod loom_local; + mod loom_oneshot; + mod loom_pool; + mod loom_queue; + mod loom_shutdown_join; + mod loom_join_set; +} + +cfg_not_loom! { + mod queue; + + #[cfg(not(miri))] + mod task_combinations; + + #[cfg(miri)] + mod task; +} diff --git a/third_party/rust/tokio/src/runtime/tests/queue.rs b/third_party/rust/tokio/src/runtime/tests/queue.rs new file mode 100644 index 0000000000..0fd1e0c6d9 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/queue.rs @@ -0,0 +1,248 @@ +use crate::runtime::queue; +use crate::runtime::task::{self, Inject, Schedule, Task}; +use crate::runtime::MetricsBatch; + +use std::thread; +use std::time::Duration; + +#[allow(unused)] +macro_rules! assert_metrics { + ($metrics:ident, $field:ident == $v:expr) => {{ + use crate::runtime::WorkerMetrics; + use std::sync::atomic::Ordering::Relaxed; + + let worker = WorkerMetrics::new(); + $metrics.submit(&worker); + + let expect = $v; + let actual = worker.$field.load(Relaxed); + + assert!(actual == expect, "expect = {}; actual = {}", expect, actual) + }}; +} + +#[test] +fn fits_256() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..256 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 0); + } + + assert!(inject.pop().is_none()); + + while local.pop().is_some() {} +} + +#[test] +fn overflow() { + let (_, mut local) = queue::local(); + let inject = Inject::new(); + let mut metrics = MetricsBatch::new(); + + for _ in 0..257 { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + cfg_metrics! { + assert_metrics!(metrics, overflow_count == 1); + } + + let mut n = 0; + + while inject.pop().is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + assert_eq!(n, 257); +} + +#[test] +fn steal_batch() { + let mut metrics = MetricsBatch::new(); + + let (steal1, mut local1) = queue::local(); + let (_, mut local2) = queue::local(); + let inject = Inject::new(); + + for _ in 0..4 { + let (task, _) = super::unowned(async {}); + local1.push_back(task, &inject, &mut metrics); + } + + assert!(steal1.steal_into(&mut local2, &mut metrics).is_some()); + + cfg_metrics! { + assert_metrics!(metrics, steal_count == 2); + } + + for _ in 0..1 { + assert!(local2.pop().is_some()); + } + + assert!(local2.pop().is_none()); + + for _ in 0..2 { + assert!(local1.pop().is_some()); + } + + assert!(local1.pop().is_none()); +} + +const fn normal_or_miri(normal: usize, miri: usize) -> usize { + if cfg!(miri) { + miri + } else { + normal + } +} + +#[test] +fn stress1() { + const NUM_ITER: usize = 1; + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + const NUM_LOCAL: usize = normal_or_miri(1_000, 10); + const NUM_PUSH: usize = normal_or_miri(500, 10); + const NUM_POP: usize = normal_or_miri(250, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut metrics = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut metrics).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::yield_now(); + } + + cfg_metrics! { + assert_metrics!(metrics, steal_count == n as _); + } + + n + }); + + let mut n = 0; + + for _ in 0..NUM_LOCAL { + for _ in 0..NUM_PUSH { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + } + + for _ in 0..NUM_POP { + if local.pop().is_some() { + n += 1; + } else { + break; + } + } + } + + while inject.pop().is_some() { + n += 1; + } + + n += th.join().unwrap(); + + assert_eq!(n, NUM_LOCAL * NUM_PUSH); + } +} + +#[test] +fn stress2() { + const NUM_ITER: usize = 1; + const NUM_TASKS: usize = normal_or_miri(1_000_000, 50); + const NUM_STEAL: usize = normal_or_miri(1_000, 10); + + let mut metrics = MetricsBatch::new(); + + for _ in 0..NUM_ITER { + let (steal, mut local) = queue::local(); + let inject = Inject::new(); + + let th = thread::spawn(move || { + let mut stats = MetricsBatch::new(); + let (_, mut local) = queue::local(); + let mut n = 0; + + for _ in 0..NUM_STEAL { + if steal.steal_into(&mut local, &mut stats).is_some() { + n += 1; + } + + while local.pop().is_some() { + n += 1; + } + + thread::sleep(Duration::from_micros(10)); + } + + n + }); + + let mut num_pop = 0; + + for i in 0..NUM_TASKS { + let (task, _) = super::unowned(async {}); + local.push_back(task, &inject, &mut metrics); + + if i % 128 == 0 && local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + } + + num_pop += th.join().unwrap(); + + while local.pop().is_some() { + num_pop += 1; + } + + while inject.pop().is_some() { + num_pop += 1; + } + + assert_eq!(num_pop, NUM_TASKS); + } +} + +struct Runtime; + +impl Schedule for Runtime { + fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { + None + } + + fn schedule(&self, _task: task::Notified<Self>) { + unreachable!(); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task.rs b/third_party/rust/tokio/src/runtime/tests/task.rs new file mode 100644 index 0000000000..04e1b56e77 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task.rs @@ -0,0 +1,288 @@ +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::util::TryLock; + +use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +struct AssertDropHandle { + is_dropped: Arc<AtomicBool>, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc<AtomicBool>, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + +#[test] +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works +#[test] +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); +} + +#[test] +fn unowned_poll() { + let (task, _) = unowned(async {}, NoopSchedule); + task.run(); +} + +#[test] +fn schedule() { + with(|rt| { + rt.spawn(async { + crate::task::yield_now().await; + }); + + assert_eq!(2, rt.tick()); + rt.shutdown(); + }) +} + +#[test] +fn shutdown() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + + rt.shutdown(); + }) +} + +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + +fn with(f: impl FnOnce(Runtime)) { + struct Reset; + + impl Drop for Reset { + fn drop(&mut self) { + let _rt = CURRENT.try_lock().unwrap().take(); + } + } + + let _reset = Reset; + + let rt = Runtime(Arc::new(Inner { + owned: OwnedTasks::new(), + core: TryLock::new(Core { + queue: VecDeque::new(), + }), + })); + + *CURRENT.try_lock().unwrap() = Some(rt.clone()); + f(rt) +} + +#[derive(Clone)] +struct Runtime(Arc<Inner>); + +struct Inner { + core: TryLock<Core>, + owned: OwnedTasks<Runtime>, +} + +struct Core { + queue: VecDeque<task::Notified<Runtime>>, +} + +static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); + +impl Runtime { + fn spawn<T>(&self, future: T) -> JoinHandle<T::Output> + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); + } + + handle + } + + fn tick(&self) -> usize { + self.tick_max(usize::MAX) + } + + fn tick_max(&self, max: usize) -> usize { + let mut n = 0; + + while !self.is_empty() && n < max { + let task = self.next_task(); + n += 1; + let task = self.0.owned.assert_owner(task); + task.run(); + } + + n + } + + fn is_empty(&self) -> bool { + self.0.core.try_lock().unwrap().queue.is_empty() + } + + fn next_task(&self) -> task::Notified<Runtime> { + self.0.core.try_lock().unwrap().queue.pop_front().unwrap() + } + + fn shutdown(&self) { + let mut core = self.0.core.try_lock().unwrap(); + + self.0.owned.close_and_shutdown_all(); + + while let Some(task) = core.queue.pop_back() { + drop(task); + } + + drop(core); + + assert!(self.0.owned.is_empty()); + } +} + +impl Schedule for Runtime { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + self.0.owned.remove(task) + } + + fn schedule(&self, task: task::Notified<Self>) { + self.0.core.try_lock().unwrap().queue.push_back(task); + } +} diff --git a/third_party/rust/tokio/src/runtime/tests/task_combinations.rs b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs new file mode 100644 index 0000000000..76ce2330c2 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option<crate::task::LocalSet>, + } + impl Rt { + fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on<T>(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn<T>(&self, task: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: bool, + on_drop: Option<oneshot::Sender<()>>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper<F> { + inner: F, + on_drop: Option<oneshot::Sender<()>>, + panic_on_drop: bool, + } + impl<F: Future> Future for FutWrapper<F> { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl<F> Drop for FutWrapper<F> { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option<oneshot::Sender<()>>, + wait_complete: Option<oneshot::Receiver<()>>, + on_output_drop: Option<oneshot::Sender<()>>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. + crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." + ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. + output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/idle.rs b/third_party/rust/tokio/src/runtime/thread_pool/idle.rs new file mode 100644 index 0000000000..a57bf6a0b1 --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/idle.rs @@ -0,0 +1,226 @@ +//! Coordinates idling workers + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; + +use std::fmt; +use std::sync::atomic::Ordering::{self, SeqCst}; + +pub(super) struct Idle { + /// Tracks both the number of searching workers and the number of unparked + /// workers. + /// + /// Used as a fast-path to avoid acquiring the lock when needed. + state: AtomicUsize, + + /// Sleeping workers + sleepers: Mutex<Vec<usize>>, + + /// Total number of workers. + num_workers: usize, +} + +const UNPARK_SHIFT: usize = 16; +const UNPARK_MASK: usize = !SEARCH_MASK; +const SEARCH_MASK: usize = (1 << UNPARK_SHIFT) - 1; + +#[derive(Copy, Clone)] +struct State(usize); + +impl Idle { + pub(super) fn new(num_workers: usize) -> Idle { + let init = State::new(num_workers); + + Idle { + state: AtomicUsize::new(init.into()), + sleepers: Mutex::new(Vec::with_capacity(num_workers)), + num_workers, + } + } + + /// If there are no workers actively searching, returns the index of a + /// worker currently sleeping. + pub(super) fn worker_to_notify(&self) -> Option<usize> { + // If at least one worker is spinning, work being notified will + // eventually be found. A searching thread will find **some** work and + // notify another worker, eventually leading to our work being found. + // + // For this to happen, this load must happen before the thread + // transitioning `num_searching` to zero. Acquire / Release does not + // provide sufficient guarantees, so this load is done with `SeqCst` and + // will pair with the `fetch_sub(1)` when transitioning out of + // searching. + if !self.notify_should_wakeup() { + return None; + } + + // Acquire the lock + let mut sleepers = self.sleepers.lock(); + + // Check again, now that the lock is acquired + if !self.notify_should_wakeup() { + return None; + } + + // A worker should be woken up, atomically increment the number of + // searching workers as well as the number of unparked workers. + State::unpark_one(&self.state, 1); + + // Get the worker to unpark + let ret = sleepers.pop(); + debug_assert!(ret.is_some()); + + ret + } + + /// Returns `true` if the worker needs to do a final check for submitted + /// work. + pub(super) fn transition_worker_to_parked(&self, worker: usize, is_searching: bool) -> bool { + // Acquire the lock + let mut sleepers = self.sleepers.lock(); + + // Decrement the number of unparked threads + let ret = State::dec_num_unparked(&self.state, is_searching); + + // Track the sleeping worker + sleepers.push(worker); + + ret + } + + pub(super) fn transition_worker_to_searching(&self) -> bool { + let state = State::load(&self.state, SeqCst); + if 2 * state.num_searching() >= self.num_workers { + return false; + } + + // It is possible for this routine to allow more than 50% of the workers + // to search. That is OK. Limiting searchers is only an optimization to + // prevent too much contention. + State::inc_num_searching(&self.state, SeqCst); + true + } + + /// A lightweight transition from searching -> running. + /// + /// Returns `true` if this is the final searching worker. The caller + /// **must** notify a new worker. + pub(super) fn transition_worker_from_searching(&self) -> bool { + State::dec_num_searching(&self.state) + } + + /// Unpark a specific worker. This happens if tasks are submitted from + /// within the worker's park routine. + /// + /// Returns `true` if the worker was parked before calling the method. + pub(super) fn unpark_worker_by_id(&self, worker_id: usize) -> bool { + let mut sleepers = self.sleepers.lock(); + + for index in 0..sleepers.len() { + if sleepers[index] == worker_id { + sleepers.swap_remove(index); + + // Update the state accordingly while the lock is held. + State::unpark_one(&self.state, 0); + + return true; + } + } + + false + } + + /// Returns `true` if `worker_id` is contained in the sleep set. + pub(super) fn is_parked(&self, worker_id: usize) -> bool { + let sleepers = self.sleepers.lock(); + sleepers.contains(&worker_id) + } + + fn notify_should_wakeup(&self) -> bool { + let state = State(self.state.fetch_add(0, SeqCst)); + state.num_searching() == 0 && state.num_unparked() < self.num_workers + } +} + +impl State { + fn new(num_workers: usize) -> State { + // All workers start in the unparked state + let ret = State(num_workers << UNPARK_SHIFT); + debug_assert_eq!(num_workers, ret.num_unparked()); + debug_assert_eq!(0, ret.num_searching()); + ret + } + + fn load(cell: &AtomicUsize, ordering: Ordering) -> State { + State(cell.load(ordering)) + } + + fn unpark_one(cell: &AtomicUsize, num_searching: usize) { + cell.fetch_add(num_searching | (1 << UNPARK_SHIFT), SeqCst); + } + + fn inc_num_searching(cell: &AtomicUsize, ordering: Ordering) { + cell.fetch_add(1, ordering); + } + + /// Returns `true` if this is the final searching worker + fn dec_num_searching(cell: &AtomicUsize) -> bool { + let state = State(cell.fetch_sub(1, SeqCst)); + state.num_searching() == 1 + } + + /// Track a sleeping worker + /// + /// Returns `true` if this is the final searching worker. + fn dec_num_unparked(cell: &AtomicUsize, is_searching: bool) -> bool { + let mut dec = 1 << UNPARK_SHIFT; + + if is_searching { + dec += 1; + } + + let prev = State(cell.fetch_sub(dec, SeqCst)); + is_searching && prev.num_searching() == 1 + } + + /// Number of workers currently searching + fn num_searching(self) -> usize { + self.0 & SEARCH_MASK + } + + /// Number of workers currently unparked + fn num_unparked(self) -> usize { + (self.0 & UNPARK_MASK) >> UNPARK_SHIFT + } +} + +impl From<usize> for State { + fn from(src: usize) -> State { + State(src) + } +} + +impl From<State> for usize { + fn from(src: State) -> usize { + src.0 + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("worker::State") + .field("num_unparked", &self.num_unparked()) + .field("num_searching", &self.num_searching()) + .finish() + } +} + +#[test] +fn test_state() { + assert_eq!(0, UNPARK_MASK & SEARCH_MASK); + assert_eq!(0, !(UNPARK_MASK | SEARCH_MASK)); + + let state = State::new(10); + assert_eq!(10, state.num_unparked()); + assert_eq!(0, state.num_searching()); +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/mod.rs b/third_party/rust/tokio/src/runtime/thread_pool/mod.rs new file mode 100644 index 0000000000..d3f46517cb --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/mod.rs @@ -0,0 +1,136 @@ +//! Threadpool + +mod idle; +use self::idle::Idle; + +mod worker; +pub(crate) use worker::Launch; + +pub(crate) use worker::block_in_place; + +use crate::loom::sync::Arc; +use crate::runtime::task::JoinHandle; +use crate::runtime::{Callback, Parker}; + +use std::fmt; +use std::future::Future; + +/// Work-stealing based thread pool for executing futures. +pub(crate) struct ThreadPool { + spawner: Spawner, +} + +/// Submits futures to the associated thread pool for execution. +/// +/// A `Spawner` instance is a handle to a single thread pool that allows the owner +/// of the handle to spawn futures onto the thread pool. +/// +/// The `Spawner` handle is *only* used for spawning new futures. It does not +/// impact the lifecycle of the thread pool in any way. The thread pool may +/// shut down while there are outstanding `Spawner` instances. +/// +/// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. +/// +/// [`ThreadPool::spawner`]: method@ThreadPool::spawner +#[derive(Clone)] +pub(crate) struct Spawner { + shared: Arc<worker::Shared>, +} + +// ===== impl ThreadPool ===== + +impl ThreadPool { + pub(crate) fn new( + size: usize, + parker: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> (ThreadPool, Launch) { + let (shared, launch) = worker::create(size, parker, before_park, after_unpark); + let spawner = Spawner { shared }; + let thread_pool = ThreadPool { spawner }; + + (thread_pool, launch) + } + + /// Returns reference to `Spawner`. + /// + /// The `Spawner` handle can be cloned and enables spawning tasks from other + /// threads. + pub(crate) fn spawner(&self) -> &Spawner { + &self.spawner + } + + /// Blocks the current thread waiting for the future to complete. + /// + /// The future will execute on the current thread, but all spawned tasks + /// will be executed on the thread pool. + pub(crate) fn block_on<F>(&self, future: F) -> F::Output + where + F: Future, + { + let mut enter = crate::runtime::enter(true); + enter.block_on(future).expect("failed to park thread") + } +} + +impl fmt::Debug for ThreadPool { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("ThreadPool").finish() + } +} + +impl Drop for ThreadPool { + fn drop(&mut self) { + self.spawner.shutdown(); + } +} + +// ==== impl Spawner ===== + +impl Spawner { + /// Spawns a future onto the thread pool + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: crate::future::Future + Send + 'static, + F::Output: Send + 'static, + { + worker::Shared::bind_new_task(&self.shared, future) + } + + pub(crate) fn shutdown(&mut self) { + self.shared.close(); + } +} + +cfg_metrics! { + use crate::runtime::{SchedulerMetrics, WorkerMetrics}; + + impl Spawner { + pub(crate) fn num_workers(&self) -> usize { + self.shared.worker_metrics.len() + } + + pub(crate) fn scheduler_metrics(&self) -> &SchedulerMetrics { + &self.shared.scheduler_metrics + } + + pub(crate) fn worker_metrics(&self, worker: usize) -> &WorkerMetrics { + &self.shared.worker_metrics[worker] + } + + pub(crate) fn injection_queue_depth(&self) -> usize { + self.shared.injection_queue_depth() + } + + pub(crate) fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.shared.worker_local_queue_depth(worker) + } + } +} + +impl fmt::Debug for Spawner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Spawner").finish() + } +} diff --git a/third_party/rust/tokio/src/runtime/thread_pool/worker.rs b/third_party/rust/tokio/src/runtime/thread_pool/worker.rs new file mode 100644 index 0000000000..7e4989701e --- /dev/null +++ b/third_party/rust/tokio/src/runtime/thread_pool/worker.rs @@ -0,0 +1,848 @@ +//! A scheduler is initialized with a fixed number of workers. Each worker is +//! driven by a thread. Each worker has a "core" which contains data such as the +//! run queue and other state. When `block_in_place` is called, the worker's +//! "core" is handed off to a new thread allowing the scheduler to continue to +//! make progress while the originating thread blocks. +//! +//! # Shutdown +//! +//! Shutting down the runtime involves the following steps: +//! +//! 1. The Shared::close method is called. This closes the inject queue and +//! OwnedTasks instance and wakes up all worker threads. +//! +//! 2. Each worker thread observes the close signal next time it runs +//! Core::maintenance by checking whether the inject queue is closed. +//! The Core::is_shutdown flag is set to true. +//! +//! 3. The worker thread calls `pre_shutdown` in parallel. Here, the worker +//! will keep removing tasks from OwnedTasks until it is empty. No new +//! tasks can be pushed to the OwnedTasks during or after this step as it +//! was closed in step 1. +//! +//! 5. The workers call Shared::shutdown to enter the single-threaded phase of +//! shutdown. These calls will push their core to Shared::shutdown_cores, +//! and the last thread to push its core will finish the shutdown procedure. +//! +//! 6. The local run queue of each core is emptied, then the inject queue is +//! emptied. +//! +//! At this point, shutdown has completed. It is not possible for any of the +//! collections to contain any tasks at this point, as each collection was +//! closed first, then emptied afterwards. +//! +//! ## Spawns during shutdown +//! +//! When spawning tasks during shutdown, there are two cases: +//! +//! * The spawner observes the OwnedTasks being open, and the inject queue is +//! closed. +//! * The spawner observes the OwnedTasks being closed and doesn't check the +//! inject queue. +//! +//! The first case can only happen if the OwnedTasks::bind call happens before +//! or during step 1 of shutdown. In this case, the runtime will clean up the +//! task in step 3 of shutdown. +//! +//! In the latter case, the task was not spawned and the task is immediately +//! cancelled by the spawner. +//! +//! The correctness of shutdown requires both the inject queue and OwnedTasks +//! collection to have a closed bit. With a close bit on only the inject queue, +//! spawning could run in to a situation where a task is successfully bound long +//! after the runtime has shut down. With a close bit on only the OwnedTasks, +//! the first spawning situation could result in the notification being pushed +//! to the inject queue after step 6 of shutdown, which would leave a task in +//! the inject queue indefinitely. This would be a ref-count cycle and a memory +//! leak. + +use crate::coop; +use crate::future::Future; +use crate::loom::rand::seed; +use crate::loom::sync::{Arc, Mutex}; +use crate::park::{Park, Unpark}; +use crate::runtime; +use crate::runtime::enter::EnterContext; +use crate::runtime::park::{Parker, Unparker}; +use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; +use crate::runtime::thread_pool::Idle; +use crate::runtime::{queue, task, Callback, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +use crate::util::atomic_cell::AtomicCell; +use crate::util::FastRand; + +use std::cell::RefCell; +use std::time::Duration; + +/// A scheduler worker +pub(super) struct Worker { + /// Reference to shared state + shared: Arc<Shared>, + + /// Index holding this worker's remote state + index: usize, + + /// Used to hand-off a worker's core to another thread. + core: AtomicCell<Core>, +} + +/// Core data +struct Core { + /// Used to schedule bookkeeping tasks every so often. + tick: u8, + + /// When a task is scheduled from a worker, it is stored in this slot. The + /// worker will check this slot for a task **before** checking the run + /// queue. This effectively results in the **last** scheduled task to be run + /// next (LIFO). This is an optimization for message passing patterns and + /// helps to reduce latency. + lifo_slot: Option<Notified>, + + /// The worker-local run queue. + run_queue: queue::Local<Arc<Shared>>, + + /// True if the worker is currently searching for more work. Searching + /// involves attempting to steal from other workers. + is_searching: bool, + + /// True if the scheduler is being shutdown + is_shutdown: bool, + + /// Parker + /// + /// Stored in an `Option` as the parker is added / removed to make the + /// borrow checker happy. + park: Option<Parker>, + + /// Batching metrics so they can be submitted to RuntimeMetrics. + metrics: MetricsBatch, + + /// Fast random number generator. + rand: FastRand, +} + +/// State shared across all workers +pub(super) struct Shared { + /// Per-worker remote state. All other workers have access to this and is + /// how they communicate between each other. + remotes: Box<[Remote]>, + + /// Submits work to the scheduler while **not** currently on a worker thread. + inject: Inject<Arc<Shared>>, + + /// Coordinates idle workers + idle: Idle, + + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + + /// Cores that have observed the shutdown signal + /// + /// The core is **not** placed back in the worker to avoid it from being + /// stolen by a thread that was spawned as part of `block_in_place`. + #[allow(clippy::vec_box)] // we're moving an already-boxed value + shutdown_cores: Mutex<Vec<Box<Core>>>, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Collects metrics from the runtime. + pub(super) scheduler_metrics: SchedulerMetrics, + + pub(super) worker_metrics: Box<[WorkerMetrics]>, +} + +/// Used to communicate with a worker from other threads. +struct Remote { + /// Steals tasks from this worker. + steal: queue::Steal<Arc<Shared>>, + + /// Unparks the associated worker thread + unpark: Unparker, +} + +/// Thread-local context +struct Context { + /// Worker + worker: Arc<Worker>, + + /// Core data + core: RefCell<Option<Box<Core>>>, +} + +/// Starts the workers +pub(crate) struct Launch(Vec<Arc<Worker>>); + +/// Running a task may consume the core. If the core is still available when +/// running the task completes, it is returned. Otherwise, the worker will need +/// to stop processing. +type RunResult = Result<Box<Core>, ()>; + +/// A task handle +type Task = task::Task<Arc<Shared>>; + +/// A notified task handle +type Notified = task::Notified<Arc<Shared>>; + +// Tracks thread-local state +scoped_thread_local!(static CURRENT: Context); + +pub(super) fn create( + size: usize, + park: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, +) -> (Arc<Shared>, Launch) { + let mut cores = vec![]; + let mut remotes = vec![]; + let mut worker_metrics = vec![]; + + // Create the local queues + for _ in 0..size { + let (steal, run_queue) = queue::local(); + + let park = park.clone(); + let unpark = park.unpark(); + + cores.push(Box::new(Core { + tick: 0, + lifo_slot: None, + run_queue, + is_searching: false, + is_shutdown: false, + park: Some(park), + metrics: MetricsBatch::new(), + rand: FastRand::new(seed()), + })); + + remotes.push(Remote { steal, unpark }); + worker_metrics.push(WorkerMetrics::new()); + } + + let shared = Arc::new(Shared { + remotes: remotes.into_boxed_slice(), + inject: Inject::new(), + idle: Idle::new(size), + owned: OwnedTasks::new(), + shutdown_cores: Mutex::new(vec![]), + before_park, + after_unpark, + scheduler_metrics: SchedulerMetrics::new(), + worker_metrics: worker_metrics.into_boxed_slice(), + }); + + let mut launch = Launch(vec![]); + + for (index, core) in cores.drain(..).enumerate() { + launch.0.push(Arc::new(Worker { + shared: shared.clone(), + index, + core: AtomicCell::new(Some(core)), + })); + } + + (shared, launch) +} + +pub(crate) fn block_in_place<F, R>(f: F) -> R +where + F: FnOnce() -> R, +{ + // Try to steal the worker core back + struct Reset(coop::Budget); + + impl Drop for Reset { + fn drop(&mut self) { + CURRENT.with(|maybe_cx| { + if let Some(cx) = maybe_cx { + let core = cx.worker.core.take(); + let mut cx_core = cx.core.borrow_mut(); + assert!(cx_core.is_none()); + *cx_core = core; + + // Reset the task budget as we are re-entering the + // runtime. + coop::set(self.0); + } + }); + } + } + + let mut had_entered = false; + + CURRENT.with(|maybe_cx| { + match (crate::runtime::enter::context(), maybe_cx.is_some()) { + (EnterContext::Entered { .. }, true) => { + // We are on a thread pool runtime thread, so we just need to + // set up blocking. + had_entered = true; + } + (EnterContext::Entered { allow_blocking }, false) => { + // We are on an executor, but _not_ on the thread pool. That is + // _only_ okay if we are in a thread pool runtime's block_on + // method: + if allow_blocking { + had_entered = true; + return; + } else { + // This probably means we are on the basic_scheduler or in a + // LocalSet, where it is _not_ okay to block. + panic!("can call blocking only when running on the multi-threaded runtime"); + } + } + (EnterContext::NotEntered, true) => { + // This is a nested call to block_in_place (we already exited). + // All the necessary setup has already been done. + return; + } + (EnterContext::NotEntered, false) => { + // We are outside of the tokio runtime, so blocking is fine. + // We can also skip all of the thread pool blocking setup steps. + return; + } + } + + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); + + // Get the worker core. If none is set, then blocking is fine! + let core = match cx.core.borrow_mut().take() { + Some(core) => core, + None => return, + }; + + // The parker should be set here + assert!(core.park.is_some()); + + // In order to block, the core must be sent to another thread for + // execution. + // + // First, move the core back into the worker's shared core slot. + cx.worker.core.set(core); + + // Next, clone the worker handle and send it to a new thread for + // processing. + // + // Once the blocking task is done executing, we will attempt to + // steal the core back. + let worker = cx.worker.clone(); + runtime::spawn_blocking(move || run(worker)); + }); + + if had_entered { + // Unset the current task's budget. Blocking sections are not + // constrained by task budgets. + let _reset = Reset(coop::stop()); + + crate::runtime::enter::exit(f) + } else { + f() + } +} + +/// After how many ticks is the global queue polled. This helps to ensure +/// fairness. +/// +/// The number is fairly arbitrary. I believe this value was copied from golang. +const GLOBAL_POLL_INTERVAL: u8 = 61; + +impl Launch { + pub(crate) fn launch(mut self) { + for worker in self.0.drain(..) { + runtime::spawn_blocking(move || run(worker)); + } + } +} + +fn run(worker: Arc<Worker>) { + // Acquire a core. If this fails, then another thread is running this + // worker and there is nothing further to do. + let core = match worker.core.take() { + Some(core) => core, + None => return, + }; + + // Set the worker context. + let cx = Context { + worker, + core: RefCell::new(None), + }; + + let _enter = crate::runtime::enter(true); + + CURRENT.set(&cx, || { + // This should always be an error. It only returns a `Result` to support + // using `?` to short circuit. + assert!(cx.run(core).is_err()); + }); +} + +impl Context { + fn run(&self, mut core: Box<Core>) -> RunResult { + while !core.is_shutdown { + // Increment the tick + core.tick(); + + // Run maintenance, if needed + core = self.maintenance(core); + + // First, check work available to the current worker. + if let Some(task) = core.next_task(&self.worker) { + core = self.run_task(task, core)?; + continue; + } + + // There is no more **local** work to process, try to steal work + // from other workers. + if let Some(task) = core.steal_work(&self.worker) { + core = self.run_task(task, core)?; + } else { + // Wait for work + core = self.park(core); + } + } + + core.pre_shutdown(&self.worker); + + // Signal shutdown + self.worker.shared.shutdown(core); + Err(()) + } + + fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult { + let task = self.worker.shared.owned.assert_owner(task); + + // Make sure the worker is not in the **searching** state. This enables + // another idle worker to try to steal work. + core.transition_from_searching(&self.worker); + + // Make the core available to the runtime context + core.metrics.incr_poll_count(); + *self.core.borrow_mut() = Some(core); + + // Run the task + coop::budget(|| { + task.run(); + + // As long as there is budget remaining and a task exists in the + // `lifo_slot`, then keep running. + loop { + // Check if we still have the core. If not, the core was stolen + // by another worker. + let mut core = match self.core.borrow_mut().take() { + Some(core) => core, + None => return Err(()), + }; + + // Check for a task in the LIFO slot + let task = match core.lifo_slot.take() { + Some(task) => task, + None => return Ok(core), + }; + + if coop::has_budget_remaining() { + // Run the LIFO task, then loop + core.metrics.incr_poll_count(); + *self.core.borrow_mut() = Some(core); + let task = self.worker.shared.owned.assert_owner(task); + task.run(); + } else { + // Not enough budget left to run the LIFO task, push it to + // the back of the queue and return. + core.run_queue + .push_back(task, self.worker.inject(), &mut core.metrics); + return Ok(core); + } + } + }) + } + + fn maintenance(&self, mut core: Box<Core>) -> Box<Core> { + if core.tick % GLOBAL_POLL_INTERVAL == 0 { + // Call `park` with a 0 timeout. This enables the I/O driver, timer, ... + // to run without actually putting the thread to sleep. + core = self.park_timeout(core, Some(Duration::from_millis(0))); + + // Run regularly scheduled maintenance + core.maintenance(&self.worker); + } + + core + } + + fn park(&self, mut core: Box<Core>) -> Box<Core> { + if let Some(f) = &self.worker.shared.before_park { + f(); + } + + if core.transition_to_parked(&self.worker) { + while !core.is_shutdown { + core.metrics.about_to_park(); + core = self.park_timeout(core, None); + core.metrics.returned_from_park(); + + // Run regularly scheduled maintenance + core.maintenance(&self.worker); + + if core.transition_from_parked(&self.worker) { + break; + } + } + } + + if let Some(f) = &self.worker.shared.after_unpark { + f(); + } + core + } + + fn park_timeout(&self, mut core: Box<Core>, duration: Option<Duration>) -> Box<Core> { + // Take the parker out of core + let mut park = core.park.take().expect("park missing"); + + // Store `core` in context + *self.core.borrow_mut() = Some(core); + + // Park thread + if let Some(timeout) = duration { + park.park_timeout(timeout).expect("park failed"); + } else { + park.park().expect("park failed"); + } + + // Remove `core` from context + core = self.core.borrow_mut().take().expect("core missing"); + + // Place `park` back in `core` + core.park = Some(park); + + // If there are tasks available to steal, but this worker is not + // looking for tasks to steal, notify another worker. + if !core.is_searching && core.run_queue.is_stealable() { + self.worker.shared.notify_parked(); + } + + core + } +} + +impl Core { + /// Increment the tick + fn tick(&mut self) { + self.tick = self.tick.wrapping_add(1); + } + + /// Return the next notified task available to this worker. + fn next_task(&mut self, worker: &Worker) -> Option<Notified> { + if self.tick % GLOBAL_POLL_INTERVAL == 0 { + worker.inject().pop().or_else(|| self.next_local_task()) + } else { + self.next_local_task().or_else(|| worker.inject().pop()) + } + } + + fn next_local_task(&mut self) -> Option<Notified> { + self.lifo_slot.take().or_else(|| self.run_queue.pop()) + } + + fn steal_work(&mut self, worker: &Worker) -> Option<Notified> { + if !self.transition_to_searching(worker) { + return None; + } + + let num = worker.shared.remotes.len(); + // Start from a random worker + let start = self.rand.fastrand_n(num as u32) as usize; + + for i in 0..num { + let i = (start + i) % num; + + // Don't steal from ourself! We know we don't have work. + if i == worker.index { + continue; + } + + let target = &worker.shared.remotes[i]; + if let Some(task) = target + .steal + .steal_into(&mut self.run_queue, &mut self.metrics) + { + return Some(task); + } + } + + // Fallback on checking the global queue + worker.shared.inject.pop() + } + + fn transition_to_searching(&mut self, worker: &Worker) -> bool { + if !self.is_searching { + self.is_searching = worker.shared.idle.transition_worker_to_searching(); + } + + self.is_searching + } + + fn transition_from_searching(&mut self, worker: &Worker) { + if !self.is_searching { + return; + } + + self.is_searching = false; + worker.shared.transition_worker_from_searching(); + } + + /// Prepares the worker state for parking. + /// + /// Returns true if the transition happend, false if there is work to do first. + fn transition_to_parked(&mut self, worker: &Worker) -> bool { + // Workers should not park if they have work to do + if self.lifo_slot.is_some() || self.run_queue.has_tasks() { + return false; + } + + // When the final worker transitions **out** of searching to parked, it + // must check all the queues one last time in case work materialized + // between the last work scan and transitioning out of searching. + let is_last_searcher = worker + .shared + .idle + .transition_worker_to_parked(worker.index, self.is_searching); + + // The worker is no longer searching. Setting this is the local cache + // only. + self.is_searching = false; + + if is_last_searcher { + worker.shared.notify_if_work_pending(); + } + + true + } + + /// Returns `true` if the transition happened. + fn transition_from_parked(&mut self, worker: &Worker) -> bool { + // If a task is in the lifo slot, then we must unpark regardless of + // being notified + if self.lifo_slot.is_some() { + // When a worker wakes, it should only transition to the "searching" + // state when the wake originates from another worker *or* a new task + // is pushed. We do *not* want the worker to transition to "searching" + // when it wakes when the I/O driver receives new events. + self.is_searching = !worker.shared.idle.unpark_worker_by_id(worker.index); + return true; + } + + if worker.shared.idle.is_parked(worker.index) { + return false; + } + + // When unparked, the worker is in the searching state. + self.is_searching = true; + true + } + + /// Runs maintenance work such as checking the pool's state. + fn maintenance(&mut self, worker: &Worker) { + self.metrics + .submit(&worker.shared.worker_metrics[worker.index]); + + if !self.is_shutdown { + // Check if the scheduler has been shutdown + self.is_shutdown = worker.inject().is_closed(); + } + } + + /// Signals all tasks to shut down, and waits for them to complete. Must run + /// before we enter the single-threaded phase of shutdown processing. + fn pre_shutdown(&mut self, worker: &Worker) { + // Signal to all tasks to shut down. + worker.shared.owned.close_and_shutdown_all(); + + self.metrics + .submit(&worker.shared.worker_metrics[worker.index]); + } + + /// Shuts down the core. + fn shutdown(&mut self) { + // Take the core + let mut park = self.park.take().expect("park missing"); + + // Drain the queue + while self.next_local_task().is_some() {} + + park.shutdown(); + } +} + +impl Worker { + /// Returns a reference to the scheduler's injection queue. + fn inject(&self) -> &Inject<Arc<Shared>> { + &self.shared.inject + } +} + +impl task::Schedule for Arc<Shared> { + fn release(&self, task: &Task) -> Option<Task> { + self.owned.remove(task) + } + + fn schedule(&self, task: Notified) { + (**self).schedule(task, false); + } + + fn yield_now(&self, task: Notified) { + (**self).schedule(task, true); + } +} + +impl Shared { + pub(super) fn bind_new_task<T>(me: &Arc<Self>, future: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (handle, notified) = me.owned.bind(future, me.clone()); + + if let Some(notified) = notified { + me.schedule(notified, false); + } + + handle + } + + pub(super) fn schedule(&self, task: Notified, is_yield: bool) { + CURRENT.with(|maybe_cx| { + if let Some(cx) = maybe_cx { + // Make sure the task is part of the **current** scheduler. + if self.ptr_eq(&cx.worker.shared) { + // And the current thread still holds a core + if let Some(core) = cx.core.borrow_mut().as_mut() { + self.schedule_local(core, task, is_yield); + return; + } + } + } + + // Otherwise, use the inject queue. + self.inject.push(task); + self.scheduler_metrics.inc_remote_schedule_count(); + self.notify_parked(); + }) + } + + fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) { + core.metrics.inc_local_schedule_count(); + + // Spawning from the worker thread. If scheduling a "yield" then the + // task must always be pushed to the back of the queue, enabling other + // tasks to be executed. If **not** a yield, then there is more + // flexibility and the task may go to the front of the queue. + let should_notify = if is_yield { + core.run_queue + .push_back(task, &self.inject, &mut core.metrics); + true + } else { + // Push to the LIFO slot + let prev = core.lifo_slot.take(); + let ret = prev.is_some(); + + if let Some(prev) = prev { + core.run_queue + .push_back(prev, &self.inject, &mut core.metrics); + } + + core.lifo_slot = Some(task); + + ret + }; + + // Only notify if not currently parked. If `park` is `None`, then the + // scheduling is from a resource driver. As notifications often come in + // batches, the notification is delayed until the park is complete. + if should_notify && core.park.is_some() { + self.notify_parked(); + } + } + + pub(super) fn close(&self) { + if self.inject.close() { + self.notify_all(); + } + } + + fn notify_parked(&self) { + if let Some(index) = self.idle.worker_to_notify() { + self.remotes[index].unpark.unpark(); + } + } + + fn notify_all(&self) { + for remote in &self.remotes[..] { + remote.unpark.unpark(); + } + } + + fn notify_if_work_pending(&self) { + for remote in &self.remotes[..] { + if !remote.steal.is_empty() { + self.notify_parked(); + return; + } + } + + if !self.inject.is_empty() { + self.notify_parked(); + } + } + + fn transition_worker_from_searching(&self) { + if self.idle.transition_worker_from_searching() { + // We are the final searching worker. Because work was found, we + // need to notify another worker. + self.notify_parked(); + } + } + + /// Signals that a worker has observed the shutdown signal and has replaced + /// its core back into its handle. + /// + /// If all workers have reached this point, the final cleanup is performed. + fn shutdown(&self, core: Box<Core>) { + let mut cores = self.shutdown_cores.lock(); + cores.push(core); + + if cores.len() != self.remotes.len() { + return; + } + + debug_assert!(self.owned.is_empty()); + + for mut core in cores.drain(..) { + core.shutdown(); + } + + // Drain the injection queue + // + // We already shut down every task, so we can simply drop the tasks. + while let Some(task) = self.inject.pop() { + drop(task); + } + } + + fn ptr_eq(&self, other: &Shared) -> bool { + std::ptr::eq(self, other) + } +} + +cfg_metrics! { + impl Shared { + pub(super) fn injection_queue_depth(&self) -> usize { + self.inject.len() + } + + pub(super) fn worker_local_queue_depth(&self, worker: usize) -> usize { + self.remotes[worker].steal.len() + } + } +} diff --git a/third_party/rust/tokio/src/signal/ctrl_c.rs b/third_party/rust/tokio/src/signal/ctrl_c.rs new file mode 100644 index 0000000000..b26ab7ead6 --- /dev/null +++ b/third_party/rust/tokio/src/signal/ctrl_c.rs @@ -0,0 +1,62 @@ +#[cfg(unix)] +use super::unix::{self as os_impl}; +#[cfg(windows)] +use super::windows::{self as os_impl}; + +use std::io; + +/// Completes when a "ctrl-c" notification is sent to the process. +/// +/// While signals are handled very differently between Unix and Windows, both +/// platforms support receiving a signal on "ctrl-c". This function provides a +/// portable API for receiving this notification. +/// +/// Once the returned future is polled, a listener is registered. The future +/// will complete on the first received `ctrl-c` **after** the initial call to +/// either `Future::poll` or `.await`. +/// +/// # Caveats +/// +/// On Unix platforms, the first time that a `Signal` instance is registered for a +/// particular signal kind, an OS signal-handler is installed which replaces the +/// default platform behavior when that signal is received, **for the duration of +/// the entire process**. +/// +/// For example, Unix systems will terminate a process by default when it +/// receives a signal generated by "CTRL+C" on the terminal. But, when a +/// `ctrl_c` stream is created to listen for this signal, the time it arrives, +/// it will be translated to a stream event, and the process will continue to +/// execute. **Even if this `Signal` instance is dropped, subsequent SIGINT +/// deliveries will end up captured by Tokio, and the default platform behavior +/// will NOT be reset**. +/// +/// Thus, applications should take care to ensure the expected signal behavior +/// occurs as expected after listening for specific signals. +/// +/// # Examples +/// +/// ```rust,no_run +/// use tokio::signal; +/// +/// #[tokio::main] +/// async fn main() { +/// println!("waiting for ctrl-c"); +/// +/// signal::ctrl_c().await.expect("failed to listen for event"); +/// +/// println!("received ctrl-c event"); +/// } +/// ``` +/// +/// Listen in the background: +/// +/// ```rust,no_run +/// tokio::spawn(async move { +/// tokio::signal::ctrl_c().await.unwrap(); +/// // Your handler here +/// }); +/// ``` +pub async fn ctrl_c() -> io::Result<()> { + os_impl::ctrl_c()?.recv().await; + Ok(()) +} diff --git a/third_party/rust/tokio/src/signal/mod.rs b/third_party/rust/tokio/src/signal/mod.rs new file mode 100644 index 0000000000..882218a0f6 --- /dev/null +++ b/third_party/rust/tokio/src/signal/mod.rs @@ -0,0 +1,100 @@ +//! Asynchronous signal handling for Tokio. +//! +//! Note that signal handling is in general a very tricky topic and should be +//! used with great care. This crate attempts to implement 'best practice' for +//! signal handling, but it should be evaluated for your own applications' needs +//! to see if it's suitable. +//! +//! There are some fundamental limitations of this crate documented on the OS +//! specific structures, as well. +//! +//! # Examples +//! +//! Print on "ctrl-c" notification. +//! +//! ```rust,no_run +//! use tokio::signal; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! signal::ctrl_c().await?; +//! println!("ctrl-c received!"); +//! Ok(()) +//! } +//! ``` +//! +//! Wait for SIGHUP on Unix +//! +//! ```rust,no_run +//! # #[cfg(unix)] { +//! use tokio::signal::unix::{signal, SignalKind}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! // An infinite stream of hangup signals. +//! let mut stream = signal(SignalKind::hangup())?; +//! +//! // Print whenever a HUP signal is received +//! loop { +//! stream.recv().await; +//! println!("got signal HUP"); +//! } +//! } +//! # } +//! ``` +use crate::sync::watch::Receiver; +use std::task::{Context, Poll}; + +mod ctrl_c; +pub use ctrl_c::ctrl_c; + +mod registry; + +mod os { + #[cfg(unix)] + pub(crate) use super::unix::{OsExtraData, OsStorage}; + + #[cfg(windows)] + pub(crate) use super::windows::{OsExtraData, OsStorage}; +} + +pub mod unix; +pub mod windows; + +mod reusable_box; +use self::reusable_box::ReusableBoxFuture; + +#[derive(Debug)] +struct RxFuture { + inner: ReusableBoxFuture<Receiver<()>>, +} + +async fn make_future(mut rx: Receiver<()>) -> Receiver<()> { + match rx.changed().await { + Ok(()) => rx, + Err(_) => panic!("signal sender went away"), + } +} + +impl RxFuture { + fn new(rx: Receiver<()>) -> Self { + Self { + inner: ReusableBoxFuture::new(make_future(rx)), + } + } + + async fn recv(&mut self) -> Option<()> { + use crate::future::poll_fn; + poll_fn(|cx| self.poll_recv(cx)).await + } + + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + match self.inner.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(rx) => { + self.inner.set(make_future(rx)); + Poll::Ready(Some(())) + } + } + } +} diff --git a/third_party/rust/tokio/src/signal/registry.rs b/third_party/rust/tokio/src/signal/registry.rs new file mode 100644 index 0000000000..6d8eb9e748 --- /dev/null +++ b/third_party/rust/tokio/src/signal/registry.rs @@ -0,0 +1,279 @@ +#![allow(clippy::unit_arg)] + +use crate::signal::os::{OsExtraData, OsStorage}; + +use crate::sync::watch; + +use once_cell::sync::Lazy; +use std::ops; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; + +pub(crate) type EventId = usize; + +/// State for a specific event, whether a notification is pending delivery, +/// and what listeners are registered. +#[derive(Debug)] +pub(crate) struct EventInfo { + pending: AtomicBool, + tx: watch::Sender<()>, +} + +impl Default for EventInfo { + fn default() -> Self { + let (tx, _rx) = watch::channel(()); + + Self { + pending: AtomicBool::new(false), + tx, + } + } +} + +/// An interface for retrieving the `EventInfo` for a particular eventId. +pub(crate) trait Storage { + /// Gets the `EventInfo` for `id` if it exists. + fn event_info(&self, id: EventId) -> Option<&EventInfo>; + + /// Invokes `f` once for each defined `EventInfo` in this storage. + fn for_each<'a, F>(&'a self, f: F) + where + F: FnMut(&'a EventInfo); +} + +impl Storage for Vec<EventInfo> { + fn event_info(&self, id: EventId) -> Option<&EventInfo> { + self.get(id) + } + + fn for_each<'a, F>(&'a self, f: F) + where + F: FnMut(&'a EventInfo), + { + self.iter().for_each(f) + } +} + +/// An interface for initializing a type. Useful for situations where we cannot +/// inject a configured instance in the constructor of another type. +pub(crate) trait Init { + fn init() -> Self; +} + +/// Manages and distributes event notifications to any registered listeners. +/// +/// Generic over the underlying storage to allow for domain specific +/// optimizations (e.g. eventIds may or may not be contiguous). +#[derive(Debug)] +pub(crate) struct Registry<S> { + storage: S, +} + +impl<S> Registry<S> { + fn new(storage: S) -> Self { + Self { storage } + } +} + +impl<S: Storage> Registry<S> { + /// Registers a new listener for `event_id`. + fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> { + self.storage + .event_info(event_id) + .unwrap_or_else(|| panic!("invalid event_id: {}", event_id)) + .tx + .subscribe() + } + + /// Marks `event_id` as having been delivered, without broadcasting it to + /// any listeners. + fn record_event(&self, event_id: EventId) { + if let Some(event_info) = self.storage.event_info(event_id) { + event_info.pending.store(true, Ordering::SeqCst) + } + } + + /// Broadcasts all previously recorded events to their respective listeners. + /// + /// Returns `true` if an event was delivered to at least one listener. + fn broadcast(&self) -> bool { + let mut did_notify = false; + self.storage.for_each(|event_info| { + // Any signal of this kind arrived since we checked last? + if !event_info.pending.swap(false, Ordering::SeqCst) { + return; + } + + // Ignore errors if there are no listeners + if event_info.tx.send(()).is_ok() { + did_notify = true; + } + }); + + did_notify + } +} + +pub(crate) struct Globals { + extra: OsExtraData, + registry: Registry<OsStorage>, +} + +impl ops::Deref for Globals { + type Target = OsExtraData; + + fn deref(&self) -> &Self::Target { + &self.extra + } +} + +impl Globals { + /// Registers a new listener for `event_id`. + pub(crate) fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> { + self.registry.register_listener(event_id) + } + + /// Marks `event_id` as having been delivered, without broadcasting it to + /// any listeners. + pub(crate) fn record_event(&self, event_id: EventId) { + self.registry.record_event(event_id); + } + + /// Broadcasts all previously recorded events to their respective listeners. + /// + /// Returns `true` if an event was delivered to at least one listener. + pub(crate) fn broadcast(&self) -> bool { + self.registry.broadcast() + } + + #[cfg(unix)] + pub(crate) fn storage(&self) -> &OsStorage { + &self.registry.storage + } +} + +pub(crate) fn globals() -> Pin<&'static Globals> +where + OsExtraData: 'static + Send + Sync + Init, + OsStorage: 'static + Send + Sync + Init, +{ + static GLOBALS: Lazy<Pin<Box<Globals>>> = Lazy::new(|| { + Box::pin(Globals { + extra: OsExtraData::init(), + registry: Registry::new(OsStorage::init()), + }) + }); + + GLOBALS.as_ref() +} + +#[cfg(all(test, not(loom)))] +mod tests { + use super::*; + use crate::runtime::{self, Runtime}; + use crate::sync::{oneshot, watch}; + + use futures::future; + + #[test] + fn smoke() { + let rt = rt(); + rt.block_on(async move { + let registry = Registry::new(vec![ + EventInfo::default(), + EventInfo::default(), + EventInfo::default(), + ]); + + let first = registry.register_listener(0); + let second = registry.register_listener(1); + let third = registry.register_listener(2); + + let (fire, wait) = oneshot::channel(); + + crate::spawn(async { + wait.await.expect("wait failed"); + + // Record some events which should get coalesced + registry.record_event(0); + registry.record_event(0); + registry.record_event(1); + registry.record_event(1); + registry.broadcast(); + + // Yield so the previous broadcast can get received + // + // This yields many times since the block_on task is only polled every 61 + // ticks. + for _ in 0..100 { + crate::task::yield_now().await; + } + + // Send subsequent signal + registry.record_event(0); + registry.broadcast(); + + drop(registry); + }); + + let _ = fire.send(()); + let all = future::join3(collect(first), collect(second), collect(third)); + + let (first_results, second_results, third_results) = all.await; + assert_eq!(2, first_results.len()); + assert_eq!(1, second_results.len()); + assert_eq!(0, third_results.len()); + }); + } + + #[test] + #[should_panic = "invalid event_id: 1"] + fn register_panics_on_invalid_input() { + let registry = Registry::new(vec![EventInfo::default()]); + + registry.register_listener(1); + } + + #[test] + fn record_invalid_event_does_nothing() { + let registry = Registry::new(vec![EventInfo::default()]); + registry.record_event(42); + } + + #[test] + fn broadcast_returns_if_at_least_one_event_fired() { + let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]); + + registry.record_event(0); + assert!(!registry.broadcast()); + + let first = registry.register_listener(0); + let second = registry.register_listener(1); + + registry.record_event(0); + assert!(registry.broadcast()); + + drop(first); + registry.record_event(0); + assert!(!registry.broadcast()); + + drop(second); + } + + fn rt() -> Runtime { + runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap() + } + + async fn collect(mut rx: watch::Receiver<()>) -> Vec<()> { + let mut ret = vec![]; + + while let Ok(v) = rx.changed().await { + ret.push(v); + } + + ret + } +} diff --git a/third_party/rust/tokio/src/signal/reusable_box.rs b/third_party/rust/tokio/src/signal/reusable_box.rs new file mode 100644 index 0000000000..02f32474b1 --- /dev/null +++ b/third_party/rust/tokio/src/signal/reusable_box.rs @@ -0,0 +1,228 @@ +use std::alloc::Layout; +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::pin::Pin; +use std::ptr::{self, NonNull}; +use std::task::{Context, Poll}; +use std::{fmt, panic}; + +/// A reusable `Pin<Box<dyn Future<Output = T> + Send>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub(crate) struct ReusableBoxFuture<T> { + boxed: NonNull<dyn Future<Output = T> + Send>, +} + +impl<T> ReusableBoxFuture<T> { + /// Create a new `ReusableBoxFuture<T>` containing the provided future. + pub(crate) fn new<F>(future: F) -> Self + where + F: Future<Output = T> + Send + 'static, + { + let boxed: Box<dyn Future<Output = T> + Send> = Box::new(future); + + let boxed = Box::into_raw(boxed); + + // SAFETY: Box::into_raw does not return null pointers. + let boxed = unsafe { NonNull::new_unchecked(boxed) }; + + Self { boxed } + } + + /// Replaces the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub(crate) fn set<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'static, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replaces the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub(crate) fn try_set<F>(&mut self, future: F) -> Result<(), F> + where + F: Future<Output = T> + Send + 'static, + { + // SAFETY: The pointer is not dangling. + let self_layout = { + let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() }; + Layout::for_value(dyn_future) + }; + + if Layout::new::<F>() == self_layout { + // SAFETY: We just checked that the layout of F is correct. + unsafe { + self.set_same_layout(future); + } + + Ok(()) + } else { + Err(future) + } + } + + /// Sets the current future. + /// + /// # Safety + /// + /// This function requires that the layout of the provided future is the + /// same as `self.layout`. + unsafe fn set_same_layout<F>(&mut self, future: F) + where + F: Future<Output = T> + Send + 'static, + { + // Drop the existing future, catching any panics. + let result = panic::catch_unwind(AssertUnwindSafe(|| { + ptr::drop_in_place(self.boxed.as_ptr()); + })); + + // Overwrite the future behind the pointer. This is safe because the + // allocation was allocated with the same size and alignment as the type F. + let self_ptr: *mut F = self.boxed.as_ptr() as *mut F; + ptr::write(self_ptr, future); + + // Update the vtable of self.boxed. The pointer is not null because we + // just got it from self.boxed, which is not null. + self.boxed = NonNull::new_unchecked(self_ptr); + + // If the old future's destructor panicked, resume unwinding. + match result { + Ok(()) => {} + Err(payload) => { + panic::resume_unwind(payload); + } + } + } + + /// Gets a pinned reference to the underlying future. + pub(crate) fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> { + // SAFETY: The user of this box cannot move the box, and we do not move it + // either. + unsafe { Pin::new_unchecked(self.boxed.as_mut()) } + } + + /// Polls the future stored inside this box. + pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> { + self.get_pin().poll(cx) + } +} + +impl<T> Future for ReusableBoxFuture<T> { + type Output = T; + + /// Polls the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The future stored inside ReusableBoxFuture<T> must be Send. +unsafe impl<T> Send for ReusableBoxFuture<T> {} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl<T> Sync for ReusableBoxFuture<T> {} + +// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type. +impl<T> Unpin for ReusableBoxFuture<T> {} + +impl<T> Drop for ReusableBoxFuture<T> { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.boxed.as_ptr())); + } + } +} + +impl<T> fmt::Debug for ReusableBoxFuture<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} + +#[cfg(test)] +#[cfg(not(miri))] // Miri breaks when you use Pin<&mut dyn Future> +mod test { + use super::ReusableBoxFuture; + use futures::future::FutureExt; + use std::alloc::Layout; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + #[test] + fn test_different_futures() { + let fut = async move { 10 }; + // Not zero sized! + assert_eq!(Layout::for_value(&fut).size(), 1); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(10)); + + b.try_set(async move { 20 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(20)); + + b.try_set(async move { 30 }) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(30)); + } + + #[test] + fn test_different_sizes() { + let fut1 = async move { 10 }; + let val = [0u32; 1000]; + let fut2 = async move { val[0] }; + let fut3 = ZeroSizedFuture {}; + + assert_eq!(Layout::for_value(&fut1).size(), 1); + assert_eq!(Layout::for_value(&fut2).size(), 4004); + assert_eq!(Layout::for_value(&fut3).size(), 0); + + let mut b = ReusableBoxFuture::new(fut1); + assert_eq!(b.get_pin().now_or_never(), Some(10)); + b.set(fut2); + assert_eq!(b.get_pin().now_or_never(), Some(0)); + b.set(fut3); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + } + + struct ZeroSizedFuture {} + impl Future for ZeroSizedFuture { + type Output = u32; + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> { + Poll::Ready(5) + } + } + + #[test] + fn test_zero_sized() { + let fut = ZeroSizedFuture {}; + // Zero sized! + assert_eq!(Layout::for_value(&fut).size(), 0); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + + b.try_set(ZeroSizedFuture {}) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + } +} diff --git a/third_party/rust/tokio/src/signal/unix.rs b/third_party/rust/tokio/src/signal/unix.rs new file mode 100644 index 0000000000..86ea9a93ee --- /dev/null +++ b/third_party/rust/tokio/src/signal/unix.rs @@ -0,0 +1,477 @@ +//! Unix-specific types for signal handling. +//! +//! This module is only defined on Unix platforms and contains the primary +//! `Signal` type for receiving notifications of signals. + +#![cfg(unix)] +#![cfg_attr(docsrs, doc(cfg(all(unix, feature = "signal"))))] + +use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage}; +use crate::signal::RxFuture; +use crate::sync::watch; + +use mio::net::UnixStream; +use std::io::{self, Error, ErrorKind, Write}; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Once; +use std::task::{Context, Poll}; + +pub(crate) mod driver; +use self::driver::Handle; + +pub(crate) type OsStorage = Vec<SignalInfo>; + +// Number of different unix signals +// (FreeBSD has 33) +const SIGNUM: usize = 33; + +impl Init for OsStorage { + fn init() -> Self { + (0..SIGNUM).map(|_| SignalInfo::default()).collect() + } +} + +impl Storage for OsStorage { + fn event_info(&self, id: EventId) -> Option<&EventInfo> { + self.get(id).map(|si| &si.event_info) + } + + fn for_each<'a, F>(&'a self, f: F) + where + F: FnMut(&'a EventInfo), + { + self.iter().map(|si| &si.event_info).for_each(f) + } +} + +#[derive(Debug)] +pub(crate) struct OsExtraData { + sender: UnixStream, + receiver: UnixStream, +} + +impl Init for OsExtraData { + fn init() -> Self { + let (receiver, sender) = UnixStream::pair().expect("failed to create UnixStream"); + + Self { sender, receiver } + } +} + +/// Represents the specific kind of signal to listen for. +#[derive(Debug, Clone, Copy)] +pub struct SignalKind(libc::c_int); + +impl SignalKind { + /// Allows for listening to any valid OS signal. + /// + /// For example, this can be used for listening for platform-specific + /// signals. + /// ```rust,no_run + /// # use tokio::signal::unix::SignalKind; + /// # let signum = -1; + /// // let signum = libc::OS_SPECIFIC_SIGNAL; + /// let kind = SignalKind::from_raw(signum); + /// ``` + // Use `std::os::raw::c_int` on public API to prevent leaking a non-stable + // type alias from libc. + // `libc::c_int` and `std::os::raw::c_int` are currently the same type, and are + // unlikely to change to other types, but technically libc can change this + // in the future minor version. + // See https://github.com/tokio-rs/tokio/issues/3767 for more. + pub fn from_raw(signum: std::os::raw::c_int) -> Self { + Self(signum as libc::c_int) + } + + /// Represents the SIGALRM signal. + /// + /// On Unix systems this signal is sent when a real-time timer has expired. + /// By default, the process is terminated by this signal. + pub fn alarm() -> Self { + Self(libc::SIGALRM) + } + + /// Represents the SIGCHLD signal. + /// + /// On Unix systems this signal is sent when the status of a child process + /// has changed. By default, this signal is ignored. + pub fn child() -> Self { + Self(libc::SIGCHLD) + } + + /// Represents the SIGHUP signal. + /// + /// On Unix systems this signal is sent when the terminal is disconnected. + /// By default, the process is terminated by this signal. + pub fn hangup() -> Self { + Self(libc::SIGHUP) + } + + /// Represents the SIGINFO signal. + /// + /// On Unix systems this signal is sent to request a status update from the + /// process. By default, this signal is ignored. + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + pub fn info() -> Self { + Self(libc::SIGINFO) + } + + /// Represents the SIGINT signal. + /// + /// On Unix systems this signal is sent to interrupt a program. + /// By default, the process is terminated by this signal. + pub fn interrupt() -> Self { + Self(libc::SIGINT) + } + + /// Represents the SIGIO signal. + /// + /// On Unix systems this signal is sent when I/O operations are possible + /// on some file descriptor. By default, this signal is ignored. + pub fn io() -> Self { + Self(libc::SIGIO) + } + + /// Represents the SIGPIPE signal. + /// + /// On Unix systems this signal is sent when the process attempts to write + /// to a pipe which has no reader. By default, the process is terminated by + /// this signal. + pub fn pipe() -> Self { + Self(libc::SIGPIPE) + } + + /// Represents the SIGQUIT signal. + /// + /// On Unix systems this signal is sent to issue a shutdown of the + /// process, after which the OS will dump the process core. + /// By default, the process is terminated by this signal. + pub fn quit() -> Self { + Self(libc::SIGQUIT) + } + + /// Represents the SIGTERM signal. + /// + /// On Unix systems this signal is sent to issue a shutdown of the + /// process. By default, the process is terminated by this signal. + pub fn terminate() -> Self { + Self(libc::SIGTERM) + } + + /// Represents the SIGUSR1 signal. + /// + /// On Unix systems this is a user defined signal. + /// By default, the process is terminated by this signal. + pub fn user_defined1() -> Self { + Self(libc::SIGUSR1) + } + + /// Represents the SIGUSR2 signal. + /// + /// On Unix systems this is a user defined signal. + /// By default, the process is terminated by this signal. + pub fn user_defined2() -> Self { + Self(libc::SIGUSR2) + } + + /// Represents the SIGWINCH signal. + /// + /// On Unix systems this signal is sent when the terminal window is resized. + /// By default, this signal is ignored. + pub fn window_change() -> Self { + Self(libc::SIGWINCH) + } +} + +pub(crate) struct SignalInfo { + event_info: EventInfo, + init: Once, + initialized: AtomicBool, +} + +impl Default for SignalInfo { + fn default() -> SignalInfo { + SignalInfo { + event_info: Default::default(), + init: Once::new(), + initialized: AtomicBool::new(false), + } + } +} + +/// Our global signal handler for all signals registered by this module. +/// +/// The purpose of this signal handler is to primarily: +/// +/// 1. Flag that our specific signal was received (e.g. store an atomic flag) +/// 2. Wake up the driver by writing a byte to a pipe +/// +/// Those two operations should both be async-signal safe. +fn action(globals: Pin<&'static Globals>, signal: libc::c_int) { + globals.record_event(signal as EventId); + + // Send a wakeup, ignore any errors (anything reasonably possible is + // full pipe and then it will wake up anyway). + let mut sender = &globals.sender; + drop(sender.write(&[1])); +} + +/// Enables this module to receive signal notifications for the `signal` +/// provided. +/// +/// This will register the signal handler if it hasn't already been registered, +/// returning any error along the way if that fails. +fn signal_enable(signal: SignalKind, handle: &Handle) -> io::Result<()> { + let signal = signal.0; + if signal < 0 || signal_hook_registry::FORBIDDEN.contains(&signal) { + return Err(Error::new( + ErrorKind::Other, + format!("Refusing to register signal {}", signal), + )); + } + + // Check that we have a signal driver running + handle.check_inner()?; + + let globals = globals(); + let siginfo = match globals.storage().get(signal as EventId) { + Some(slot) => slot, + None => return Err(io::Error::new(io::ErrorKind::Other, "signal too large")), + }; + let mut registered = Ok(()); + siginfo.init.call_once(|| { + registered = unsafe { + signal_hook_registry::register(signal, move || action(globals, signal)).map(|_| ()) + }; + if registered.is_ok() { + siginfo.initialized.store(true, Ordering::Relaxed); + } + }); + registered?; + // If the call_once failed, it won't be retried on the next attempt to register the signal. In + // such case it is not run, registered is still `Ok(())`, initialized is still `false`. + if siginfo.initialized.load(Ordering::Relaxed) { + Ok(()) + } else { + Err(Error::new( + ErrorKind::Other, + "Failed to register signal handler", + )) + } +} + +/// A stream of events for receiving a particular type of OS signal. +/// +/// In general signal handling on Unix is a pretty tricky topic, and this +/// structure is no exception! There are some important limitations to keep in +/// mind when using `Signal` streams: +/// +/// * Signals handling in Unix already necessitates coalescing signals +/// together sometimes. This `Signal` stream is also no exception here in +/// that it will also coalesce signals. That is, even if the signal handler +/// for this process runs multiple times, the `Signal` stream may only return +/// one signal notification. Specifically, before `poll` is called, all +/// signal notifications are coalesced into one item returned from `poll`. +/// Once `poll` has been called, however, a further signal is guaranteed to +/// be yielded as an item. +/// +/// Put another way, any element pulled off the returned stream corresponds to +/// *at least one* signal, but possibly more. +/// +/// * Signal handling in general is relatively inefficient. Although some +/// improvements are possible in this crate, it's recommended to not plan on +/// having millions of signal channels open. +/// +/// If you've got any questions about this feel free to open an issue on the +/// repo! New approaches to alleviate some of these limitations are always +/// appreciated! +/// +/// # Caveats +/// +/// The first time that a `Signal` instance is registered for a particular +/// signal kind, an OS signal-handler is installed which replaces the default +/// platform behavior when that signal is received, **for the duration of the +/// entire process**. +/// +/// For example, Unix systems will terminate a process by default when it +/// receives SIGINT. But, when a `Signal` instance is created to listen for +/// this signal, the next SIGINT that arrives will be translated to a stream +/// event, and the process will continue to execute. **Even if this `Signal` +/// instance is dropped, subsequent SIGINT deliveries will end up captured by +/// Tokio, and the default platform behavior will NOT be reset**. +/// +/// Thus, applications should take care to ensure the expected signal behavior +/// occurs as expected after listening for specific signals. +/// +/// # Examples +/// +/// Wait for SIGHUP +/// +/// ```rust,no_run +/// use tokio::signal::unix::{signal, SignalKind}; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box<dyn std::error::Error>> { +/// // An infinite stream of hangup signals. +/// let mut stream = signal(SignalKind::hangup())?; +/// +/// // Print whenever a HUP signal is received +/// loop { +/// stream.recv().await; +/// println!("got signal HUP"); +/// } +/// } +/// ``` +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct Signal { + inner: RxFuture, +} + +/// Creates a new stream which will receive notifications when the current +/// process receives the specified signal `kind`. +/// +/// This function will create a new stream which binds to the default reactor. +/// The `Signal` stream is an infinite stream which will receive +/// notifications whenever a signal is received. More documentation can be +/// found on `Signal` itself, but to reiterate: +/// +/// * Signals may be coalesced beyond what the kernel already does. +/// * Once a signal handler is registered with the process the underlying +/// libc signal handler is never unregistered. +/// +/// A `Signal` stream can be created for a particular signal number +/// multiple times. When a signal is received then all the associated +/// channels will receive the signal notification. +/// +/// # Errors +/// +/// * If the lower-level C functions fail for some reason. +/// * If the previous initialization of this specific signal failed. +/// * If the signal is one of +/// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics) +pub fn signal(kind: SignalKind) -> io::Result<Signal> { + let rx = signal_with_handle(kind, &Handle::current())?; + + Ok(Signal { + inner: RxFuture::new(rx), + }) +} + +pub(crate) fn signal_with_handle( + kind: SignalKind, + handle: &Handle, +) -> io::Result<watch::Receiver<()>> { + // Turn the signal delivery on once we are ready for it + signal_enable(kind, handle)?; + + Ok(globals().register_listener(kind.0 as EventId)) +} + +impl Signal { + /// Receives the next signal notification event. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// Wait for SIGHUP + /// + /// ```rust,no_run + /// use tokio::signal::unix::{signal, SignalKind}; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// // An infinite stream of hangup signals. + /// let mut stream = signal(SignalKind::hangup())?; + /// + /// // Print whenever a HUP signal is received + /// loop { + /// stream.recv().await; + /// println!("got signal HUP"); + /// } + /// } + /// ``` + pub async fn recv(&mut self) -> Option<()> { + self.inner.recv().await + } + + /// Polls to receive the next signal notification event, outside of an + /// `async` context. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no signals are available but the channel is not + /// closed. + /// * `Poll::Ready(Some(()))` if a signal is available. + /// * `Poll::Ready(None)` if the channel has been closed and all signals + /// sent before it was closed have been received. + /// + /// # Examples + /// + /// Polling from a manually implemented future + /// + /// ```rust,no_run + /// use std::pin::Pin; + /// use std::future::Future; + /// use std::task::{Context, Poll}; + /// use tokio::signal::unix::Signal; + /// + /// struct MyFuture { + /// signal: Signal, + /// } + /// + /// impl Future for MyFuture { + /// type Output = Option<()>; + /// + /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + /// println!("polling MyFuture"); + /// self.signal.poll_recv(cx) + /// } + /// } + /// ``` + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.inner.poll_recv(cx) + } +} + +// Work around for abstracting streams internally +pub(crate) trait InternalStream { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>; +} + +impl InternalStream for Signal { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.poll_recv(cx) + } +} + +pub(crate) fn ctrl_c() -> io::Result<Signal> { + signal(SignalKind::interrupt()) +} + +#[cfg(all(test, not(loom)))] +mod tests { + use super::*; + + #[test] + fn signal_enable_error_on_invalid_input() { + signal_enable(SignalKind::from_raw(-1), &Handle::default()).unwrap_err(); + } + + #[test] + fn signal_enable_error_on_forbidden_input() { + signal_enable( + SignalKind::from_raw(signal_hook_registry::FORBIDDEN[0]), + &Handle::default(), + ) + .unwrap_err(); + } +} diff --git a/third_party/rust/tokio/src/signal/unix/driver.rs b/third_party/rust/tokio/src/signal/unix/driver.rs new file mode 100644 index 0000000000..5fe7c354c5 --- /dev/null +++ b/third_party/rust/tokio/src/signal/unix/driver.rs @@ -0,0 +1,207 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Signal driver + +use crate::io::driver::{Driver as IoDriver, Interest}; +use crate::io::PollEvented; +use crate::park::Park; +use crate::signal::registry::globals; + +use mio::net::UnixStream; +use std::io::{self, Read}; +use std::ptr; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; +use std::time::Duration; + +/// Responsible for registering wakeups when an OS signal is received, and +/// subsequently dispatching notifications to any signal listeners as appropriate. +/// +/// Note: this driver relies on having an enabled IO driver in order to listen to +/// pipe write wakeups. +#[derive(Debug)] +pub(crate) struct Driver { + /// Thread parker. The `Driver` park implementation delegates to this. + park: IoDriver, + + /// A pipe for receiving wake events from the signal handler + receiver: PollEvented<UnixStream>, + + /// Shared state + inner: Arc<Inner>, +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct Handle { + inner: Weak<Inner>, +} + +#[derive(Debug)] +pub(super) struct Inner(()); + +// ===== impl Driver ===== + +impl Driver { + /// Creates a new signal `Driver` instance that delegates wakeups to `park`. + pub(crate) fn new(park: IoDriver) -> io::Result<Self> { + use std::mem::ManuallyDrop; + use std::os::unix::io::{AsRawFd, FromRawFd}; + + // NB: We give each driver a "fresh" receiver file descriptor to avoid + // the issues described in alexcrichton/tokio-process#42. + // + // In the past we would reuse the actual receiver file descriptor and + // swallow any errors around double registration of the same descriptor. + // I'm not sure if the second (failed) registration simply doesn't end + // up receiving wake up notifications, or there could be some race + // condition when consuming readiness events, but having distinct + // descriptors for distinct PollEvented instances appears to mitigate + // this. + // + // Unfortunately we cannot just use a single global PollEvented instance + // either, since we can't compare Handles or assume they will always + // point to the exact same reactor. + // + // Mio 0.7 removed `try_clone()` as an API due to unexpected behavior + // with registering dups with the same reactor. In this case, duping is + // safe as each dup is registered with separate reactors **and** we + // only expect at least one dup to receive the notification. + + // Manually drop as we don't actually own this instance of UnixStream. + let receiver_fd = globals().receiver.as_raw_fd(); + + // safety: there is nothing unsafe about this, but the `from_raw_fd` fn is marked as unsafe. + let original = + ManuallyDrop::new(unsafe { std::os::unix::net::UnixStream::from_raw_fd(receiver_fd) }); + let receiver = UnixStream::from_std(original.try_clone()?); + let receiver = PollEvented::new_with_interest_and_handle( + receiver, + Interest::READABLE | Interest::WRITABLE, + park.handle(), + )?; + + Ok(Self { + park, + receiver, + inner: Arc::new(Inner(())), + }) + } + + /// Returns a handle to this event loop which can be sent across threads + /// and can be used as a proxy to the event loop itself. + pub(crate) fn handle(&self) -> Handle { + Handle { + inner: Arc::downgrade(&self.inner), + } + } + + fn process(&self) { + // Check if the pipe is ready to read and therefore has "woken" us up + // + // To do so, we will `poll_read_ready` with a noop waker, since we don't + // need to actually be notified when read ready... + let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let ev = match self.receiver.registration().poll_read_ready(&mut cx) { + Poll::Ready(Ok(ev)) => ev, + Poll::Ready(Err(e)) => panic!("reactor gone: {}", e), + Poll::Pending => return, // No wake has arrived, bail + }; + + // Drain the pipe completely so we can receive a new readiness event + // if another signal has come in. + let mut buf = [0; 128]; + loop { + match (&*self.receiver).read(&mut buf) { + Ok(0) => panic!("EOF on self-pipe"), + Ok(_) => continue, // Keep reading + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => panic!("Bad read on self-pipe: {}", e), + } + } + + self.receiver.registration().clear_readiness(ev); + + // Broadcast any signals which were received + globals().broadcast(); + } +} + +const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + +unsafe fn noop_clone(_data: *const ()) -> RawWaker { + RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE) +} + +unsafe fn noop(_data: *const ()) {} + +// ===== impl Park for Driver ===== + +impl Park for Driver { + type Unpark = <IoDriver as Park>::Unpark; + type Error = io::Error; + + fn unpark(&self) -> Self::Unpark { + self.park.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park.park()?; + self.process(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park.park_timeout(duration)?; + self.process(); + Ok(()) + } + + fn shutdown(&mut self) { + self.park.shutdown() + } +} + +// ===== impl Handle ===== + +impl Handle { + pub(super) fn check_inner(&self) -> io::Result<()> { + if self.inner.strong_count() > 0 { + Ok(()) + } else { + Err(io::Error::new(io::ErrorKind::Other, "signal driver gone")) + } + } +} + +cfg_rt! { + impl Handle { + /// Returns a handle to the current driver + /// + /// # Panics + /// + /// This function panics if there is no current signal driver set. + pub(super) fn current() -> Self { + crate::runtime::context::signal_handle().expect( + "there is no signal driver running, must be called from the context of Tokio runtime", + ) + } + } +} + +cfg_not_rt! { + impl Handle { + /// Returns a handle to the current driver + /// + /// # Panics + /// + /// This function panics if there is no current signal driver set. + pub(super) fn current() -> Self { + panic!( + "there is no signal driver running, must be called from the context of Tokio runtime or with\ + `rt` enabled.", + ) + } + } +} diff --git a/third_party/rust/tokio/src/signal/windows.rs b/third_party/rust/tokio/src/signal/windows.rs new file mode 100644 index 0000000000..11ec6cb08c --- /dev/null +++ b/third_party/rust/tokio/src/signal/windows.rs @@ -0,0 +1,223 @@ +//! Windows-specific types for signal handling. +//! +//! This module is only defined on Windows and allows receiving "ctrl-c" +//! and "ctrl-break" notifications. These events are listened for via the +//! `SetConsoleCtrlHandler` function which receives events of the type +//! `CTRL_C_EVENT` and `CTRL_BREAK_EVENT`. + +#![cfg(any(windows, docsrs))] +#![cfg_attr(docsrs, doc(cfg(all(windows, feature = "signal"))))] + +use crate::signal::RxFuture; +use std::io; +use std::task::{Context, Poll}; + +#[cfg(not(docsrs))] +#[path = "windows/sys.rs"] +mod imp; +#[cfg(not(docsrs))] +pub(crate) use self::imp::{OsExtraData, OsStorage}; + +#[cfg(docsrs)] +#[path = "windows/stub.rs"] +mod imp; + +/// Creates a new stream which receives "ctrl-c" notifications sent to the +/// process. +/// +/// # Examples +/// +/// ```rust,no_run +/// use tokio::signal::windows::ctrl_c; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box<dyn std::error::Error>> { +/// // An infinite stream of CTRL-C events. +/// let mut stream = ctrl_c()?; +/// +/// // Print whenever a CTRL-C event is received. +/// for countdown in (0..3).rev() { +/// stream.recv().await; +/// println!("got CTRL-C. {} more to exit", countdown); +/// } +/// +/// Ok(()) +/// } +/// ``` +pub fn ctrl_c() -> io::Result<CtrlC> { + Ok(CtrlC { + inner: self::imp::ctrl_c()?, + }) +} + +/// Represents a stream which receives "ctrl-c" notifications sent to the process +/// via `SetConsoleCtrlHandler`. +/// +/// A notification to this process notifies *all* streams listening for +/// this event. Moreover, the notifications **are coalesced** if they aren't processed +/// quickly enough. This means that if two notifications are received back-to-back, +/// then the stream may only receive one item about the two notifications. +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct CtrlC { + inner: RxFuture, +} + +impl CtrlC { + /// Receives the next signal notification event. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// ```rust,no_run + /// use tokio::signal::windows::ctrl_c; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// // An infinite stream of CTRL-C events. + /// let mut stream = ctrl_c()?; + /// + /// // Print whenever a CTRL-C event is received. + /// for countdown in (0..3).rev() { + /// stream.recv().await; + /// println!("got CTRL-C. {} more to exit", countdown); + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn recv(&mut self) -> Option<()> { + self.inner.recv().await + } + + /// Polls to receive the next signal notification event, outside of an + /// `async` context. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// Polling from a manually implemented future + /// + /// ```rust,no_run + /// use std::pin::Pin; + /// use std::future::Future; + /// use std::task::{Context, Poll}; + /// use tokio::signal::windows::CtrlC; + /// + /// struct MyFuture { + /// ctrl_c: CtrlC, + /// } + /// + /// impl Future for MyFuture { + /// type Output = Option<()>; + /// + /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + /// println!("polling MyFuture"); + /// self.ctrl_c.poll_recv(cx) + /// } + /// } + /// ``` + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.inner.poll_recv(cx) + } +} + +/// Represents a stream which receives "ctrl-break" notifications sent to the process +/// via `SetConsoleCtrlHandler`. +/// +/// A notification to this process notifies *all* streams listening for +/// this event. Moreover, the notifications **are coalesced** if they aren't processed +/// quickly enough. This means that if two notifications are received back-to-back, +/// then the stream may only receive one item about the two notifications. +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct CtrlBreak { + inner: RxFuture, +} + +impl CtrlBreak { + /// Receives the next signal notification event. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// ```rust,no_run + /// use tokio::signal::windows::ctrl_break; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// // An infinite stream of CTRL-BREAK events. + /// let mut stream = ctrl_break()?; + /// + /// // Print whenever a CTRL-BREAK event is received. + /// loop { + /// stream.recv().await; + /// println!("got signal CTRL-BREAK"); + /// } + /// } + /// ``` + pub async fn recv(&mut self) -> Option<()> { + self.inner.recv().await + } + + /// Polls to receive the next signal notification event, outside of an + /// `async` context. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// Polling from a manually implemented future + /// + /// ```rust,no_run + /// use std::pin::Pin; + /// use std::future::Future; + /// use std::task::{Context, Poll}; + /// use tokio::signal::windows::CtrlBreak; + /// + /// struct MyFuture { + /// ctrl_break: CtrlBreak, + /// } + /// + /// impl Future for MyFuture { + /// type Output = Option<()>; + /// + /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + /// println!("polling MyFuture"); + /// self.ctrl_break.poll_recv(cx) + /// } + /// } + /// ``` + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.inner.poll_recv(cx) + } +} + +/// Creates a new stream which receives "ctrl-break" notifications sent to the +/// process. +/// +/// # Examples +/// +/// ```rust,no_run +/// use tokio::signal::windows::ctrl_break; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box<dyn std::error::Error>> { +/// // An infinite stream of CTRL-BREAK events. +/// let mut stream = ctrl_break()?; +/// +/// // Print whenever a CTRL-BREAK event is received. +/// loop { +/// stream.recv().await; +/// println!("got signal CTRL-BREAK"); +/// } +/// } +/// ``` +pub fn ctrl_break() -> io::Result<CtrlBreak> { + Ok(CtrlBreak { + inner: self::imp::ctrl_break()?, + }) +} diff --git a/third_party/rust/tokio/src/signal/windows/stub.rs b/third_party/rust/tokio/src/signal/windows/stub.rs new file mode 100644 index 0000000000..88630543da --- /dev/null +++ b/third_party/rust/tokio/src/signal/windows/stub.rs @@ -0,0 +1,13 @@ +//! Stub implementations for the platform API so that rustdoc can build linkable +//! documentation on non-windows platforms. + +use crate::signal::RxFuture; +use std::io; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + panic!() +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + panic!() +} diff --git a/third_party/rust/tokio/src/signal/windows/sys.rs b/third_party/rust/tokio/src/signal/windows/sys.rs new file mode 100644 index 0000000000..8d29c357b6 --- /dev/null +++ b/third_party/rust/tokio/src/signal/windows/sys.rs @@ -0,0 +1,153 @@ +use std::convert::TryFrom; +use std::io; +use std::sync::Once; + +use crate::signal::registry::{globals, EventId, EventInfo, Init, Storage}; +use crate::signal::RxFuture; + +use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; +use winapi::um::consoleapi::SetConsoleCtrlHandler; +use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + new(CTRL_C_EVENT) +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + new(CTRL_BREAK_EVENT) +} + +fn new(signum: DWORD) -> io::Result<RxFuture> { + global_init()?; + let rx = globals().register_listener(signum as EventId); + Ok(RxFuture::new(rx)) +} + +#[derive(Debug)] +pub(crate) struct OsStorage { + ctrl_c: EventInfo, + ctrl_break: EventInfo, +} + +impl Init for OsStorage { + fn init() -> Self { + Self { + ctrl_c: EventInfo::default(), + ctrl_break: EventInfo::default(), + } + } +} + +impl Storage for OsStorage { + fn event_info(&self, id: EventId) -> Option<&EventInfo> { + match DWORD::try_from(id) { + Ok(CTRL_C_EVENT) => Some(&self.ctrl_c), + Ok(CTRL_BREAK_EVENT) => Some(&self.ctrl_break), + _ => None, + } + } + + fn for_each<'a, F>(&'a self, mut f: F) + where + F: FnMut(&'a EventInfo), + { + f(&self.ctrl_c); + f(&self.ctrl_break); + } +} + +#[derive(Debug)] +pub(crate) struct OsExtraData {} + +impl Init for OsExtraData { + fn init() -> Self { + Self {} + } +} + +fn global_init() -> io::Result<()> { + static INIT: Once = Once::new(); + + let mut init = None; + + INIT.call_once(|| unsafe { + let rc = SetConsoleCtrlHandler(Some(handler), TRUE); + let ret = if rc == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + }; + + init = Some(ret); + }); + + init.unwrap_or_else(|| Ok(())) +} + +unsafe extern "system" fn handler(ty: DWORD) -> BOOL { + let globals = globals(); + globals.record_event(ty as EventId); + + // According to https://docs.microsoft.com/en-us/windows/console/handlerroutine + // the handler routine is always invoked in a new thread, thus we don't + // have the same restrictions as in Unix signal handlers, meaning we can + // go ahead and perform the broadcast here. + if globals.broadcast() { + TRUE + } else { + // No one is listening for this notification any more + // let the OS fire the next (possibly the default) handler. + FALSE + } +} + +#[cfg(all(test, not(loom)))] +mod tests { + use super::*; + use crate::runtime::Runtime; + + use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; + + #[test] + fn ctrl_c() { + let rt = rt(); + let _enter = rt.enter(); + + let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); + + assert_pending!(ctrl_c.poll()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_C_EVENT); + } + + assert_ready_ok!(ctrl_c.poll()); + } + + #[test] + fn ctrl_break() { + let rt = rt(); + + rt.block_on(async { + let mut ctrl_break = assert_ok!(crate::signal::windows::ctrl_break()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_BREAK_EVENT); + } + + ctrl_break.recv().await.unwrap(); + }); + } + + fn rt() -> Runtime { + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + } +} diff --git a/third_party/rust/tokio/src/sync/barrier.rs b/third_party/rust/tokio/src/sync/barrier.rs new file mode 100644 index 0000000000..dfc76a40eb --- /dev/null +++ b/third_party/rust/tokio/src/sync/barrier.rs @@ -0,0 +1,206 @@ +use crate::loom::sync::Mutex; +use crate::sync::watch; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +/// A barrier enables multiple tasks to synchronize the beginning of some computation. +/// +/// ``` +/// # #[tokio::main] +/// # async fn main() { +/// use tokio::sync::Barrier; +/// use std::sync::Arc; +/// +/// let mut handles = Vec::with_capacity(10); +/// let barrier = Arc::new(Barrier::new(10)); +/// for _ in 0..10 { +/// let c = barrier.clone(); +/// // The same messages will be printed together. +/// // You will NOT see any interleaving. +/// handles.push(tokio::spawn(async move { +/// println!("before wait"); +/// let wait_result = c.wait().await; +/// println!("after wait"); +/// wait_result +/// })); +/// } +/// +/// // Will not resolve until all "after wait" messages have been printed +/// let mut num_leaders = 0; +/// for handle in handles { +/// let wait_result = handle.await.unwrap(); +/// if wait_result.is_leader() { +/// num_leaders += 1; +/// } +/// } +/// +/// // Exactly one barrier will resolve as the "leader" +/// assert_eq!(num_leaders, 1); +/// # } +/// ``` +#[derive(Debug)] +pub struct Barrier { + state: Mutex<BarrierState>, + wait: watch::Receiver<usize>, + n: usize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +#[derive(Debug)] +struct BarrierState { + waker: watch::Sender<usize>, + arrived: usize, + generation: usize, +} + +impl Barrier { + /// Creates a new barrier that can block a given number of tasks. + /// + /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all + /// tasks at once when the `n`th task calls `wait`. + #[track_caller] + pub fn new(mut n: usize) -> Barrier { + let (waker, wait) = crate::sync::watch::channel(0); + + if n == 0 { + // if n is 0, it's not clear what behavior the user wants. + // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every + // .wait() immediately unblocks, so we adopt that here as well. + n = 1; + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Barrier", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + size = n, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 0, + ) + }); + resource_span + }; + + Barrier { + state: Mutex::new(BarrierState { + waker, + arrived: 0, + generation: 1, + }), + n, + wait, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span, + } + } + + /// Does not resolve until all tasks have rendezvoused here. + /// + /// Barriers are re-usable after all tasks have rendezvoused once, and can + /// be used continuously. + /// + /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from + /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks + /// will receive a result that will return `false` from `is_leader`. + pub async fn wait(&self) -> BarrierWaitResult { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace::async_op( + || self.wait_internal(), + self.resource_span.clone(), + "Barrier::wait", + "poll", + false, + ) + .await; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return self.wait_internal().await; + } + async fn wait_internal(&self) -> BarrierWaitResult { + // NOTE: we are taking a _synchronous_ lock here. + // It is okay to do so because the critical section is fast and never yields, so it cannot + // deadlock even if another future is concurrently holding the lock. + // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than + // the asynchronous counter-parts, so we should use them where possible [citation needed]. + // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across + // a yield point, and thus marks the returned future as !Send. + let generation = { + let mut state = self.state.lock(); + let generation = state.generation; + state.arrived += 1; + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::state_update", + arrived = 1, + arrived.op = "add", + ); + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + arrived = true, + ); + if state.arrived == self.n { + #[cfg(all(tokio_unstable, feature = "tracing"))] + tracing::trace!( + target: "runtime::resource::async_op::state_update", + is_leader = true, + ); + // we are the leader for this generation + // wake everyone, increment the generation, and return + state + .waker + .send(state.generation) + .expect("there is at least one receiver"); + state.arrived = 0; + state.generation += 1; + return BarrierWaitResult(true); + } + + generation + }; + + // we're going to have to wait for the last of the generation to arrive + let mut wait = self.wait.clone(); + + loop { + let _ = wait.changed().await; + + // note that the first time through the loop, this _will_ yield a generation + // immediately, since we cloned a receiver that has never seen any values. + if *wait.borrow() >= generation { + break; + } + } + + BarrierWaitResult(false) + } +} + +/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. +#[derive(Debug, Clone)] +pub struct BarrierWaitResult(bool); + +impl BarrierWaitResult { + /// Returns `true` if this task from wait is the "leader task". + /// + /// Only one task will have `true` returned from their result, all other tasks will have + /// `false` returned. + pub fn is_leader(&self) -> bool { + self.0 + } +} diff --git a/third_party/rust/tokio/src/sync/batch_semaphore.rs b/third_party/rust/tokio/src/sync/batch_semaphore.rs new file mode 100644 index 0000000000..4f5effff31 --- /dev/null +++ b/third_party/rust/tokio/src/sync/batch_semaphore.rs @@ -0,0 +1,727 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] +//! # Implementation Details. +//! +//! The semaphore is implemented using an intrusive linked list of waiters. An +//! atomic counter tracks the number of available permits. If the semaphore does +//! not contain the required number of permits, the task attempting to acquire +//! permits places its waker at the end of a queue. When new permits are made +//! available (such as by releasing an initial acquisition), they are assigned +//! to the task at the front of the queue, waking that task if its requested +//! number of permits is met. +//! +//! Because waiters are enqueued at the back of the linked list and dequeued +//! from the front, the semaphore is fair. Tasks trying to acquire large numbers +//! of permits at a time will always be woken eventually, even if many other +//! tasks are acquiring smaller numbers of permits. This means that in a +//! use-case like tokio's read-write lock, writers will not be starved by +//! readers. +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Mutex, MutexGuard}; +use crate::util::linked_list::{self, LinkedList}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use crate::util::WakeList; + +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::*; +use std::task::Poll::*; +use std::task::{Context, Poll, Waker}; +use std::{cmp, fmt}; + +/// An asynchronous counting semaphore which permits waiting on multiple permits at once. +pub(crate) struct Semaphore { + waiters: Mutex<Waitlist>, + /// The current number of available permits in the semaphore. + permits: AtomicUsize, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +struct Waitlist { + queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, + closed: bool, +} + +/// Error returned from the [`Semaphore::try_acquire`] function. +/// +/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire +#[derive(Debug, PartialEq)] +pub enum TryAcquireError { + /// The semaphore has been [closed] and cannot issue new permits. + /// + /// [closed]: crate::sync::Semaphore::close + Closed, + + /// The semaphore has no available permits. + NoPermits, +} +/// Error returned from the [`Semaphore::acquire`] function. +/// +/// An `acquire` operation can only fail if the semaphore has been +/// [closed]. +/// +/// [closed]: crate::sync::Semaphore::close +/// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire +#[derive(Debug)] +pub struct AcquireError(()); + +pub(crate) struct Acquire<'a> { + node: Waiter, + semaphore: &'a Semaphore, + num_permits: u32, + queued: bool, +} + +/// An entry in the wait queue. +struct Waiter { + /// The current state of the waiter. + /// + /// This is either the number of remaining permits required by + /// the waiter, or a flag indicating that the waiter is not yet queued. + state: AtomicUsize, + + /// The waker to notify the task awaiting permits. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + waker: UnsafeCell<Option<Waker>>, + + /// Intrusive linked-list pointers. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + /// + /// TODO: Ideally, we would be able to use loom to enforce that + /// this isn't accessed concurrently. However, it is difficult to + /// use a `UnsafeCell` here, since the `Link` trait requires _returning_ + /// references to `Pointers`, and `UnsafeCell` requires that checked access + /// take place inside a closure. We should consider changing `Pointers` to + /// use `UnsafeCell` internally. + pointers: linked_list::Pointers<Waiter>, + + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx: trace::AsyncOpTracingCtx, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Semaphore { + /// The maximum number of permits which a semaphore can hold. + /// + /// Note that this reserves three bits of flags in the permit counter, but + /// we only actually use one of them. However, the previous semaphore + /// implementation used three bits, so we will continue to reserve them to + /// avoid a breaking change if additional flags need to be added in the + /// future. + pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; + const CLOSED: usize = 1; + // The least-significant bit in the number of permits is reserved to use + // as a flag indicating that the semaphore has been closed. Consequently + // PERMIT_SHIFT is used to leave that bit for that purpose. + const PERMIT_SHIFT: usize = 1; + + /// Creates a new semaphore with the initial number of permits + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + pub(crate) fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + is_internal = true + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = permits, + permits.op = "override", + ) + }); + resource_span + }; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new semaphore with the initial number of permits. + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + /// + /// If the specified number of permits exceeds the maximum permit amount + /// Then the value will get clamped to the maximum number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + pub(crate) const fn const_new(mut permits: usize) -> Self { + // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925 + // currently we just clamp the permit count when it exceeds the max + permits &= Self::MAX_PERMITS; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::const_new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Returns the current number of available permits. + pub(crate) fn available_permits(&self) -> usize { + self.permits.load(Acquire) >> Self::PERMIT_SHIFT + } + + /// Adds `added` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + pub(crate) fn release(&self, added: usize) { + if added == 0 { + return; + } + + // Assign permits to the wait queue + self.add_permits_locked(added, self.waiters.lock()); + } + + /// Closes the semaphore. This prevents the semaphore from issuing new + /// permits and notifies all pending waiters. + pub(crate) fn close(&self) { + let mut waiters = self.waiters.lock(); + // If the semaphore's permits counter has enough permits for an + // unqueued waiter to acquire all the permits it needs immediately, + // it won't touch the wait list. Therefore, we have to set a bit on + // the permit counter as well. However, we must do this while + // holding the lock --- otherwise, if we set the bit and then wait + // to acquire the lock we'll enter an inconsistent state where the + // permit counter is closed, but the wait list is not. + self.permits.fetch_or(Self::CLOSED, Release); + waiters.closed = true; + while let Some(mut waiter) = waiters.queue.pop_back() { + let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = waker { + waker.wake(); + } + } + } + + /// Returns true if the semaphore is closed. + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { + assert!( + num_permits as usize <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; + let mut curr = self.permits.load(Acquire); + loop { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { + return Err(TryAcquireError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits { + return Err(TryAcquireError::NoPermits); + } + + let next = curr - num_permits; + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + // TODO: Instrument once issue has been solved} + return Ok(()); + } + Err(actual) => curr = actual, + } + } + } + + pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> { + Acquire::new(self, num_permits) + } + + /// Release `rem` permits to the semaphore's wait list, starting from the + /// end of the queue. + /// + /// If `rem` exceeds the number of permits needed by the wait list, the + /// remainder are assigned back to the semaphore. + fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { + let mut wakers = WakeList::new(); + let mut lock = Some(waiters); + let mut is_empty = false; + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); + 'inner: while wakers.can_push() { + // Was the waiter assigned enough permits to wake it? + match waiters.queue.last() { + Some(waiter) => { + if !waiter.assign_permits(&mut rem) { + break 'inner; + } + } + None => { + is_empty = true; + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + break 'inner; + } + }; + let mut waiter = waiters.queue.pop_back().unwrap(); + if let Some(waker) = + unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } + { + wakers.push(waker); + } + } + + if rem > 0 && is_empty { + let permits = rem; + assert!( + permits <= Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); + let prev = prev >> Self::PERMIT_SHIFT; + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + + // add remaining permits back + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = rem, + permits.op = "add", + ) + }); + + rem = 0; + } + + drop(waiters); // release the lock + + wakers.wake_all(); + } + + assert_eq!(rem, 0); + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: u32, + node: Pin<&mut Waiter>, + queued: bool, + ) -> Poll<Result<(), AcquireError>> { + let mut acquired = 0; + + let needed = if queued { + node.state.load(Acquire) << Self::PERMIT_SHIFT + } else { + (num_permits as usize) << Self::PERMIT_SHIFT + }; + + let mut lock = None; + // First, try to take the requested number of permits from the + // semaphore. + let mut curr = self.permits.load(Acquire); + let mut waiters = loop { + // Has the semaphore closed? + if curr & Self::CLOSED > 0 { + return Ready(Err(AcquireError::closed())); + } + + let mut remaining = 0; + let total = curr + .checked_add(acquired) + .expect("number of permits must not overflow"); + let (next, acq) = if total >= needed { + let next = curr - (needed - acquired); + (next, needed >> Self::PERMIT_SHIFT) + } else { + remaining = (needed - acquired) - curr; + (0, curr >> Self::PERMIT_SHIFT) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(self.waiters.lock()); + } + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + acquired += acq; + if remaining == 0 { + if !queued { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ); + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = acquired, + permits.op = "add", + ) + }); + + return Ready(Ok(())); + } else if lock.is_none() { + break self.waiters.lock(); + } + } + break lock.expect("lock must be acquired before waiting"); + } + Err(actual) => curr = actual, + } + }; + + if waiters.closed { + return Ready(Err(AcquireError::closed())); + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + permits = acquired, + permits.op = "sub", + ) + }); + + if node.assign_permits(&mut acquired) { + self.add_permits_locked(acquired, waiters); + return Ready(Ok(())); + } + + assert_eq!(acquired, 0); + + // Otherwise, register the waker & enqueue the node. + node.waker.with_mut(|waker| { + // Safety: the wait list is locked, so we may modify the waker. + let waker = unsafe { &mut *waker }; + // Do we need to register the new waker? + if waker + .as_ref() + .map(|waker| !waker.will_wake(cx.waker())) + .unwrap_or(true) + { + *waker = Some(cx.waker().clone()); + } + }); + + // If the waiter is not already in the wait queue, enqueue it. + if !queued { + let node = unsafe { + let node = Pin::into_inner_unchecked(node) as *mut _; + NonNull::new_unchecked(node) + }; + + waiters.queue.push_front(node); + } + + Pending + } +} + +impl fmt::Debug for Semaphore { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Semaphore") + .field("permits", &self.available_permits()) + .finish() + } +} + +impl Waiter { + fn new( + num_permits: u32, + #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx, + ) -> Self { + Waiter { + waker: UnsafeCell::new(None), + state: AtomicUsize::new(num_permits as usize), + pointers: linked_list::Pointers::new(), + #[cfg(all(tokio_unstable, feature = "tracing"))] + ctx, + _p: PhantomPinned, + } + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize) -> bool { + let mut curr = self.state.load(Acquire); + loop { + let assign = cmp::min(curr, *n); + let next = curr - assign; + match self.state.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + *n -= assign; + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.ctx.async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = assign, + permits.op = "add", + ); + }); + return next == 0; + } + Err(actual) => curr = actual, + } + } + } +} + +impl Future for Acquire<'_> { + type Output = Result<(), AcquireError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _resource_span = self.node.ctx.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_span = self.node.ctx.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered(); + + let (node, semaphore, needed, queued) = self.project(); + + // First, ensure the current task has enough budget to proceed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let coop = ready!(trace_poll_op!( + "poll_acquire", + crate::coop::poll_proceed(cx), + )); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let coop = ready!(crate::coop::poll_proceed(cx)); + + let result = match semaphore.poll_acquire(cx, needed, node, *queued) { + Pending => { + *queued = true; + Pending + } + Ready(r) => { + coop.made_progress(); + r?; + *queued = false; + Ready(Ok(())) + } + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace_poll_op!("poll_acquire", result); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + return result; + } +} + +impl<'a> Acquire<'a> { + fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self { + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { + node: Waiter::new(num_permits), + semaphore, + num_permits, + queued: false, + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return semaphore.resource_span.in_scope(|| { + let async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Acquire::new"); + let async_op_poll_span = async_op_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_requested = num_permits, + permits.op = "override", + ); + + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = 0 as usize, + permits.op = "override", + ); + + tracing::trace_span!("runtime.resource.async_op.poll") + }); + + let ctx = trace::AsyncOpTracingCtx { + async_op_span, + async_op_poll_span, + resource_span: semaphore.resource_span.clone(), + }; + + Self { + node: Waiter::new(num_permits, ctx), + semaphore, + num_permits, + queued: false, + } + }); + } + + fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) { + fn is_unpin<T: Unpin>() {} + unsafe { + // Safety: all fields other than `node` are `Unpin` + + is_unpin::<&Semaphore>(); + is_unpin::<&mut bool>(); + is_unpin::<u32>(); + + let this = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut this.node), + this.semaphore, + this.num_permits, + &mut this.queued, + ) + } + } +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + // If the future is completed, there is no node in the wait list, so we + // can skip acquiring the lock. + if !self.queued { + return; + } + + // This is where we ensure safety. The future is being dropped, + // which means we must ensure that the waiter entry is no longer stored + // in the linked list. + let mut waiters = self.semaphore.waiters.lock(); + + // remove the entry from the list + let node = NonNull::from(&mut self.node); + // Safety: we have locked the wait list. + unsafe { waiters.queue.remove(node) }; + + let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire); + if acquired_permits > 0 { + self.semaphore.add_permits_locked(acquired_permits, waiters); + } + } +} + +// Safety: the `Acquire` future is not `Sync` automatically because it contains +// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the +// `UnsafeCell` is only accessed when the future is borrowed mutably (either in +// `poll` or in `drop`). Therefore, it is safe (although not particularly +// _useful_) for the future to be borrowed immutably across threads. +unsafe impl Sync for Acquire<'_> {} + +// ===== impl AcquireError ==== + +impl AcquireError { + fn closed() -> AcquireError { + AcquireError(()) + } +} + +impl fmt::Display for AcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "semaphore closed") + } +} + +impl std::error::Error for AcquireError {} + +// ===== impl TryAcquireError ===== + +impl TryAcquireError { + /// Returns `true` if the error was caused by a closed semaphore. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_closed(&self) -> bool { + matches!(self, TryAcquireError::Closed) + } + + /// Returns `true` if the error was caused by calling `try_acquire` on a + /// semaphore with no available permits. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_no_permits(&self) -> bool { + matches!(self, TryAcquireError::NoPermits) + } +} + +impl fmt::Display for TryAcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryAcquireError::Closed => write!(fmt, "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "no permits available"), + } + } +} + +impl std::error::Error for TryAcquireError {} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + // XXX: ideally, we would be able to use `Pin` here, to enforce the + // invariant that list entries may not move while in the list. However, we + // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` + // type would require `Semaphore` to be generic over a lifetime. We can't + // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether + // or not they dereference to an `!Unpin` target. + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} diff --git a/third_party/rust/tokio/src/sync/broadcast.rs b/third_party/rust/tokio/src/sync/broadcast.rs new file mode 100644 index 0000000000..0d9cd3bc17 --- /dev/null +++ b/third_party/rust/tokio/src/sync/broadcast.rs @@ -0,0 +1,1078 @@ +//! A multi-producer, multi-consumer broadcast queue. Each sent value is seen by +//! all consumers. +//! +//! A [`Sender`] is used to broadcast values to **all** connected [`Receiver`] +//! values. [`Sender`] handles are clone-able, allowing concurrent send and +//! receive actions. [`Sender`] and [`Receiver`] are both `Send` and `Sync` as +//! long as `T` is also `Send` or `Sync` respectively. +//! +//! When a value is sent, **all** [`Receiver`] handles are notified and will +//! receive the value. The value is stored once inside the channel and cloned on +//! demand for each receiver. Once all receivers have received a clone of the +//! value, the value is released from the channel. +//! +//! A channel is created by calling [`channel`], specifying the maximum number +//! of messages the channel can retain at any given time. +//! +//! New [`Receiver`] handles are created by calling [`Sender::subscribe`]. The +//! returned [`Receiver`] will receive values sent **after** the call to +//! `subscribe`. +//! +//! ## Lagging +//! +//! As sent messages must be retained until **all** [`Receiver`] handles receive +//! a clone, broadcast channels are susceptible to the "slow receiver" problem. +//! In this case, all but one receiver are able to receive values at the rate +//! they are sent. Because one receiver is stalled, the channel starts to fill +//! up. +//! +//! This broadcast channel implementation handles this case by setting a hard +//! upper bound on the number of values the channel may retain at any given +//! time. This upper bound is passed to the [`channel`] function as an argument. +//! +//! If a value is sent when the channel is at capacity, the oldest value +//! currently held by the channel is released. This frees up space for the new +//! value. Any receiver that has not yet seen the released value will return +//! [`RecvError::Lagged`] the next time [`recv`] is called. +//! +//! Once [`RecvError::Lagged`] is returned, the lagging receiver's position is +//! updated to the oldest value contained by the channel. The next call to +//! [`recv`] will return this value. +//! +//! This behavior enables a receiver to detect when it has lagged so far behind +//! that data has been dropped. The caller may decide how to respond to this: +//! either by aborting its task or by tolerating lost messages and resuming +//! consumption of the channel. +//! +//! ## Closing +//! +//! When **all** [`Sender`] handles have been dropped, no new values may be +//! sent. At this point, the channel is "closed". Once a receiver has received +//! all values retained by the channel, the next call to [`recv`] will return +//! with [`RecvError::Closed`]. +//! +//! [`Sender`]: crate::sync::broadcast::Sender +//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +//! [`Receiver`]: crate::sync::broadcast::Receiver +//! [`channel`]: crate::sync::broadcast::channel +//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged +//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed +//! [`recv`]: crate::sync::broadcast::Receiver::recv +//! +//! # Examples +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! Handling lag +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx) = broadcast::channel(2); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! tx.send(30).unwrap(); +//! +//! // The receiver lagged behind +//! assert!(rx.recv().await.is_err()); +//! +//! // At this point, we can abort or continue with lost messages +//! +//! assert_eq!(20, rx.recv().await.unwrap()); +//! assert_eq!(30, rx.recv().await.unwrap()); +//! } +//! ``` + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use crate::util::linked_list::{self, LinkedList}; + +use std::fmt; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; +use std::usize; + +/// Sending-half of the [`broadcast`] channel. +/// +/// May be used from many threads. Messages can be sent with +/// [`send`][Sender::send]. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Sender<T> { + shared: Arc<Shared<T>>, +} + +/// Receiving-half of the [`broadcast`] channel. +/// +/// Must not be used concurrently. Messages may be retrieved using +/// [`recv`][Receiver::recv]. +/// +/// To turn this receiver into a `Stream`, you can use the [`BroadcastStream`] +/// wrapper. +/// +/// [`BroadcastStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.BroadcastStream.html +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +/// +/// [`broadcast`]: crate::sync::broadcast +pub struct Receiver<T> { + /// State shared with all receivers and senders. + shared: Arc<Shared<T>>, + + /// Next position to read from + next: u64, +} + +pub mod error { + //! Broadcast error types + + use std::fmt; + + /// Error returned by from the [`send`] function on a [`Sender`]. + /// + /// A **send** operation can only fail if there are no active receivers, + /// implying that the message could never be received. The error contains the + /// message being sent as a payload so it can be recovered. + /// + /// [`send`]: crate::sync::broadcast::Sender::send + /// [`Sender`]: crate::sync::broadcast::Sender + #[derive(Debug)] + pub struct SendError<T>(pub T); + + impl<T> fmt::Display for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// An error returned from the [`recv`] function on a [`Receiver`]. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum RecvError { + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvError::Closed => write!(f, "channel closed"), + RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for RecvError {} + + /// An error returned from the [`try_recv`] function on a [`Receiver`]. + /// + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum TryRecvError { + /// The channel is currently empty. There are still active + /// [`Sender`] handles, so data may yet become available. + /// + /// [`Sender`]: crate::sync::broadcast::Sender + Empty, + + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind and has been forcibly disconnected. + /// Attempting to receive again will return the oldest message still + /// retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(f, "channel empty"), + TryRecvError::Closed => write!(f, "channel closed"), + TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for TryRecvError {} +} + +use self::error::*; + +/// Data shared between senders and receivers. +struct Shared<T> { + /// slots in the channel. + buffer: Box<[RwLock<Slot<T>>]>, + + /// Mask a position -> index. + mask: usize, + + /// Tail of the queue. Includes the rx wait list. + tail: Mutex<Tail>, + + /// Number of outstanding Sender handles. + num_tx: AtomicUsize, +} + +/// Next position to write a value. +struct Tail { + /// Next position to write to. + pos: u64, + + /// Number of active receivers. + rx_cnt: usize, + + /// True if the channel is closed. + closed: bool, + + /// Receivers waiting for a value. + waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, +} + +/// Slot in the buffer. +struct Slot<T> { + /// Remaining number of receivers that are expected to see this value. + /// + /// When this goes to zero, the value is released. + /// + /// An atomic is used as it is mutated concurrently with the slot read lock + /// acquired. + rem: AtomicUsize, + + /// Uniquely identifies the `send` stored in the slot. + pos: u64, + + /// True signals the channel is closed. + closed: bool, + + /// The value being broadcast. + /// + /// The value is set by `send` when the write lock is held. When a reader + /// drops, `rem` is decremented. When it hits zero, the value is dropped. + val: UnsafeCell<Option<T>>, +} + +/// An entry in the wait queue. +struct Waiter { + /// True if queued. + queued: bool, + + /// Task waiting on the broadcast channel. + waker: Option<Waker>, + + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers<Waiter>, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +struct RecvGuard<'a, T> { + slot: RwLockReadGuard<'a, Slot<T>>, +} + +/// Receive a value future. +struct Recv<'a, T> { + /// Receiver being waited on. + receiver: &'a mut Receiver<T>, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, +} + +unsafe impl<'a, T: Send> Send for Recv<'a, T> {} +unsafe impl<'a, T: Send> Sync for Recv<'a, T> {} + +/// Max number of receivers. Reserve space to lock. +const MAX_RECEIVERS: usize = usize::MAX >> 2; + +/// Create a bounded, multi-producer, multi-consumer channel where each sent +/// value is broadcasted to all active receivers. +/// +/// All data sent on [`Sender`] will become available on every active +/// [`Receiver`] in the same order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple +/// points in the process or it can be used concurrently from an `Arc`. New +/// `Receiver` handles are created by calling [`Sender::subscribe`]. +/// +/// If all [`Receiver`] handles are dropped, the `send` method will return a +/// [`SendError`]. Similarly, if all [`Sender`] handles are dropped, the [`recv`] +/// method will return a [`RecvError`]. +/// +/// [`Sender`]: crate::sync::broadcast::Sender +/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe +/// [`Receiver`]: crate::sync::broadcast::Receiver +/// [`recv`]: crate::sync::broadcast::Receiver::recv +/// [`SendError`]: crate::sync::broadcast::error::SendError +/// [`RecvError`]: crate::sync::broadcast::error::RecvError +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::broadcast; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx1) = broadcast::channel(16); +/// let mut rx2 = tx.subscribe(); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx1.recv().await.unwrap(), 10); +/// assert_eq!(rx1.recv().await.unwrap(), 20); +/// }); +/// +/// tokio::spawn(async move { +/// assert_eq!(rx2.recv().await.unwrap(), 10); +/// assert_eq!(rx2.recv().await.unwrap(), 20); +/// }); +/// +/// tx.send(10).unwrap(); +/// tx.send(20).unwrap(); +/// } +/// ``` +pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { + assert!(capacity > 0, "capacity is empty"); + assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); + + // Round to a power of two + capacity = capacity.next_power_of_two(); + + let mut buffer = Vec::with_capacity(capacity); + + for i in 0..capacity { + buffer.push(RwLock::new(Slot { + rem: AtomicUsize::new(0), + pos: (i as u64).wrapping_sub(capacity as u64), + closed: false, + val: UnsafeCell::new(None), + })); + } + + let shared = Arc::new(Shared { + buffer: buffer.into_boxed_slice(), + mask: capacity - 1, + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: 1, + closed: false, + waiters: LinkedList::new(), + }), + num_tx: AtomicUsize::new(1), + }); + + let rx = Receiver { + shared: shared.clone(), + next: 0, + }; + + let tx = Sender { shared }; + + (tx, rx) +} + +unsafe impl<T: Send> Send for Sender<T> {} +unsafe impl<T: Send> Sync for Sender<T> {} + +unsafe impl<T: Send> Send for Receiver<T> {} +unsafe impl<T: Send> Sync for Receiver<T> {} + +impl<T> Sender<T> { + /// Attempts to send a value to all active [`Receiver`] handles, returning + /// it back if it could not be sent. + /// + /// A successful send occurs when there is at least one active [`Receiver`] + /// handle. An unsuccessful send would be one where all associated + /// [`Receiver`] handles have already been dropped. + /// + /// # Return + /// + /// On success, the number of subscribed [`Receiver`] handles is returned. + /// This does not mean that this number of receivers will see the message as + /// a receiver may drop before receiving the message. + /// + /// # Note + /// + /// A return value of `Ok` **does not** mean that the sent value will be + /// observed by all or any of the active [`Receiver`] handles. [`Receiver`] + /// handles may be dropped before receiving the sent message. + /// + /// A return value of `Err` **does not** mean that future calls to `send` + /// will fail. New [`Receiver`] handles may be created by calling + /// [`subscribe`]. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + pub fn send(&self, value: T) -> Result<usize, SendError<T>> { + self.send2(Some(value)) + .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + } + + /// Creates a new [`Receiver`] handle that will receive values sent **after** + /// this call to `subscribe`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = broadcast::channel(16); + /// + /// // Will not be seen + /// tx.send(10).unwrap(); + /// + /// let mut rx = tx.subscribe(); + /// + /// tx.send(20).unwrap(); + /// + /// let value = rx.recv().await.unwrap(); + /// assert_eq!(20, value); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + new_receiver(shared) + } + + /// Returns the number of active receivers + /// + /// An active receiver is a [`Receiver`] handle returned from [`channel`] or + /// [`subscribe`]. These are the handles that will receive values sent on + /// this [`Sender`]. + /// + /// # Note + /// + /// It is not guaranteed that a sent message will reach this number of + /// receivers. Active receivers may never call [`recv`] again before + /// dropping. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`Sender`]: crate::sync::broadcast::Sender + /// [`subscribe`]: crate::sync::broadcast::Sender::subscribe + /// [`channel`]: crate::sync::broadcast::channel + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx1) = broadcast::channel(16); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = tx.subscribe(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// + /// tx.send(10).unwrap(); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + let tail = self.shared.tail.lock(); + tail.rx_cnt + } + + fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { + let mut tail = self.shared.tail.lock(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let mut slot = self.shared.buffer[idx].write().unwrap(); + + // Track the position + slot.pos = pos; + + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); + + // Set the closed bit if the value is `None`; otherwise write the value + if value.is_none() { + tail.closed = true; + slot.closed = true; + } else { + slot.val.with_mut(|ptr| unsafe { *ptr = value }); + } + + // Release the slot lock before notifying the receivers. + drop(slot); + + tail.notify_rx(); + + // Release the mutex. This must happen after the slot lock is released, + // otherwise the writer lock bit could be cleared while another thread + // is in the critical section. + drop(tail); + + Ok(rem) + } +} + +fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { + let mut tail = shared.tail.lock(); + + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + + let next = tail.pos; + + drop(tail); + + Receiver { shared, next } +} + +impl Tail { + fn notify_rx(&mut self) { + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.queued); + waiter.queued = false; + + let waker = waiter.waker.take().unwrap(); + waker.wake(); + } + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Sender<T> { + let shared = self.shared.clone(); + shared.num_tx.fetch_add(1, SeqCst); + + Sender { shared } + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { + let _ = self.send2(None); + } + } +} + +impl<T> Receiver<T> { + /// Locks the next value if there is one. + fn recv_ref( + &mut self, + waiter: Option<(&UnsafeCell<Waiter>, &Waker)>, + ) -> Result<RecvGuard<'_, T>, TryRecvError> { + let idx = (self.next & self.shared.mask as u64) as usize; + + // The slot holding the next value to read + let mut slot = self.shared.buffer[idx].read().unwrap(); + + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + // The receiver has read all current values in the channel and there + // is no waiter to register + if waiter.is_none() && next_pos == self.next { + return Err(TryRecvError::Empty); + } + + // Release the `slot` lock before attempting to acquire the `tail` + // lock. This is required because `send2` acquires the tail lock + // first followed by the slot lock. Acquiring the locks in reverse + // order here would result in a potential deadlock: `recv_ref` + // acquires the `slot` lock and attempts to acquire the `tail` lock + // while `send2` acquired the `tail` lock and attempts to acquire + // the slot lock. + drop(slot); + + let mut tail = self.shared.tail.lock(); + + // Acquire slot lock again + slot = self.shared.buffer[idx].read().unwrap(); + + // Make sure the position did not change. This could happen in the + // unlikely event that the buffer is wrapped between dropping the + // read lock and acquiring the tail lock. + if slot.pos != self.next { + let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); + + if next_pos == self.next { + // Store the waker + if let Some((waiter, waker)) = waiter { + // Safety: called while locked. + unsafe { + // Only queue if not already queued + waiter.with_mut(|ptr| { + // If there is no waker **or** if the currently + // stored waker references a **different** task, + // track the tasks' waker to be notified on + // receipt of a new value. + match (*ptr).waker { + Some(ref w) if w.will_wake(waker) => {} + _ => { + (*ptr).waker = Some(waker.clone()); + } + } + + if !(*ptr).queued { + (*ptr).queued = true; + tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr)); + } + }); + } + } + + return Err(TryRecvError::Empty); + } + + // At this point, the receiver has lagged behind the sender by + // more than the channel capacity. The receiver will attempt to + // catch up by skipping dropped messages and setting the + // internal cursor to the **oldest** message stored by the + // channel. + // + // However, finding the oldest position is a bit more + // complicated than `tail-position - buffer-size`. When + // the channel is closed, the tail position is incremented to + // signal a new `None` message, but `None` is not stored in the + // channel itself (see issue #2425 for why). + // + // To account for this, if the channel is closed, the tail + // position is decremented by `buffer-size + 1`. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + + let missed = next.wrapping_sub(self.next); + + drop(tail); + + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); + + return Ok(RecvGuard { slot }); + } + + self.next = next; + + return Err(TryRecvError::Lagged(missed)); + } + } + + self.next = self.next.wrapping_add(1); + + if slot.closed { + return Err(TryRecvError::Closed); + } + + Ok(RecvGuard { slot }) + } +} + +impl<T: Clone> Receiver<T> { + /// Receives the next value for this receiver. + /// + /// Each [`Receiver`] handle will receive a clone of all values sent + /// **after** it has subscribed. + /// + /// `Err(RecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(RecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`recv`] will return this value + /// **unless** it has been since overwritten. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// }); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// assert_eq!(rx2.recv().await.unwrap(), 20); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// } + /// ``` + /// + /// Handling lag + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// // The receiver lagged behind + /// assert!(rx.recv().await.is_err()); + /// + /// // At this point, we can abort or continue with lost messages + /// + /// assert_eq!(20, rx.recv().await.unwrap()); + /// assert_eq!(30, rx.recv().await.unwrap()); + /// } + /// ``` + pub async fn recv(&mut self) -> Result<T, RecvError> { + let fut = Recv::new(self); + fut.await + } + + /// Attempts to return a pending value on this receiver without awaiting. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// await on a receiver. + /// + /// Compared with [`recv`], this function has three failure cases instead of two + /// (one for closed, one for an empty buffer, one for a lagging receiver). + /// + /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`try_recv`] will return this value + /// **unless** it has been since overwritten. If there are no values to + /// receive, `Err(TryRecvError::Empty)` is returned. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// assert!(rx.try_recv().is_err()); + /// + /// tx.send(10).unwrap(); + /// + /// let value = rx.try_recv().unwrap(); + /// assert_eq!(10, value); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let guard = self.recv_ref(None)?; + guard.clone_value().ok_or(TryRecvError::Closed) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + let mut tail = self.shared.tail.lock(); + + tail.rx_cnt -= 1; + let until = tail.pos; + + drop(tail); + + while self.next < until { + match self.recv_ref(None) { + Ok(_) => {} + // The channel is closed + Err(TryRecvError::Closed) => break, + // Ignore lagging, we will catch up + Err(TryRecvError::Lagged(..)) => {} + // Can't be empty + Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"), + } + } + } +} + +impl<'a, T> Recv<'a, T> { + fn new(receiver: &'a mut Receiver<T>) -> Recv<'a, T> { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + } + } + + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) { + unsafe { + // Safety: Receiver is Unpin + is_unpin::<&mut Receiver<T>>(); + + let me = self.get_unchecked_mut(); + (me.receiver, &me.waiter) + } + } +} + +impl<'a, T> Future for Recv<'a, T> +where + T: Clone, +{ + type Output = Result<T, RecvError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + let (receiver, waiter) = self.project(); + + let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { + Ok(value) => value, + Err(TryRecvError::Empty) => return Poll::Pending, + Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + }; + + Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) + } +} + +impl<'a, T> Drop for Recv<'a, T> { + fn drop(&mut self) { + // Acquire the tail lock. This is required for safety before accessing + // the waiter node. + let mut tail = self.receiver.shared.tail.lock(); + + // safety: tail lock is held + let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); + + if queued { + // Remove the node + // + // safety: tail lock is held and the wait node is verified to be in + // the list. + unsafe { + self.waiter.with_mut(|ptr| { + tail.waiters.remove((&mut *ptr).into()); + }); + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Sender") + } +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "broadcast::Receiver") + } +} + +impl<'a, T> RecvGuard<'a, T> { + fn clone_value(&self) -> Option<T> + where + T: Clone, + { + self.slot.val.with(|ptr| unsafe { (*ptr).clone() }) + } +} + +impl<'a, T> Drop for RecvGuard<'a, T> { + fn drop(&mut self) { + // Decrement the remaining counter + if 1 == self.slot.rem.fetch_sub(1, SeqCst) { + // Safety: Last receiver, drop the value + self.slot.val.with_mut(|ptr| unsafe { *ptr = None }); + } + } +} + +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/sync/mod.rs b/third_party/rust/tokio/src/sync/mod.rs new file mode 100644 index 0000000000..457e6ab294 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mod.rs @@ -0,0 +1,499 @@ +#![cfg_attr(loom, allow(dead_code, unreachable_pub, unused_imports))] + +//! Synchronization primitives for use in asynchronous contexts. +//! +//! Tokio programs tend to be organized as a set of [tasks] where each task +//! operates independently and may be executed on separate physical threads. The +//! synchronization primitives provided in this module permit these independent +//! tasks to communicate together. +//! +//! [tasks]: crate::task +//! +//! # Message passing +//! +//! The most common form of synchronization in a Tokio program is message +//! passing. Two tasks operate independently and send messages to each other to +//! synchronize. Doing so has the advantage of avoiding shared state. +//! +//! Message passing is implemented using channels. A channel supports sending a +//! message from one producer task to one or more consumer tasks. There are a +//! few flavors of channels provided by Tokio. Each channel flavor supports +//! different message passing patterns. When a channel supports multiple +//! producers, many separate tasks may **send** messages. When a channel +//! supports multiple consumers, many different separate tasks may **receive** +//! messages. +//! +//! Tokio provides many different channel flavors as different message passing +//! patterns are best handled with different implementations. +//! +//! ## `oneshot` channel +//! +//! The [`oneshot` channel][oneshot] supports sending a **single** value from a +//! single producer to a single consumer. This channel is usually used to send +//! the result of a computation to a waiter. +//! +//! **Example:** using a [`oneshot` channel][oneshot] to receive the result of a +//! computation. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! async fn some_computation() -> String { +//! "represents the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! let res = some_computation().await; +//! tx.send(res).unwrap(); +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = rx.await.unwrap(); +//! } +//! ``` +//! +//! Note, if the task produces a computation result as its final +//! action before terminating, the [`JoinHandle`] can be used to +//! receive that value instead of allocating resources for the +//! `oneshot` channel. Awaiting on [`JoinHandle`] returns `Result`. If +//! the task panics, the `Joinhandle` yields `Err` with the panic +//! cause. +//! +//! **Example:** +//! +//! ``` +//! async fn some_computation() -> String { +//! "the result of the computation".to_string() +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let join_handle = tokio::spawn(async move { +//! some_computation().await +//! }); +//! +//! // Do other work while the computation is happening in the background +//! +//! // Wait for the computation result +//! let res = join_handle.await.unwrap(); +//! } +//! ``` +//! +//! [oneshot]: oneshot +//! [`JoinHandle`]: crate::task::JoinHandle +//! +//! ## `mpsc` channel +//! +//! The [`mpsc` channel][mpsc] supports sending **many** values from **many** +//! producers to a single consumer. This channel is often used to send work to a +//! task or to receive the result of many computations. +//! +//! **Example:** using an mpsc to incrementally stream the results of a series +//! of computations. +//! +//! ``` +//! use tokio::sync::mpsc; +//! +//! async fn some_computation(input: u32) -> String { +//! format!("the result of computation {}", input) +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx) = mpsc::channel(100); +//! +//! tokio::spawn(async move { +//! for i in 0..10 { +//! let res = some_computation(i).await; +//! tx.send(res).await.unwrap(); +//! } +//! }); +//! +//! while let Some(res) = rx.recv().await { +//! println!("got = {}", res); +//! } +//! } +//! ``` +//! +//! The argument to `mpsc::channel` is the channel capacity. This is the maximum +//! number of values that can be stored in the channel pending receipt at any +//! given time. Properly setting this value is key in implementing robust +//! programs as the channel capacity plays a critical part in handling back +//! pressure. +//! +//! A common concurrency pattern for resource management is to spawn a task +//! dedicated to managing that resource and using message passing between other +//! tasks to interact with the resource. The resource may be anything that may +//! not be concurrently used. Some examples include a socket and program state. +//! For example, if multiple tasks need to send data over a single socket, spawn +//! a task to manage the socket and use a channel to synchronize. +//! +//! **Example:** sending data from many tasks over a single socket using message +//! passing. +//! +//! ```no_run +//! use tokio::io::{self, AsyncWriteExt}; +//! use tokio::net::TcpStream; +//! use tokio::sync::mpsc; +//! +//! #[tokio::main] +//! async fn main() -> io::Result<()> { +//! let mut socket = TcpStream::connect("www.example.com:1234").await?; +//! let (tx, mut rx) = mpsc::channel(100); +//! +//! for _ in 0..10 { +//! // Each task needs its own `tx` handle. This is done by cloning the +//! // original handle. +//! let tx = tx.clone(); +//! +//! tokio::spawn(async move { +//! tx.send(&b"data to write"[..]).await.unwrap(); +//! }); +//! } +//! +//! // The `rx` half of the channel returns `None` once **all** `tx` clones +//! // drop. To ensure `None` is returned, drop the handle owned by the +//! // current task. If this `tx` handle is not dropped, there will always +//! // be a single outstanding `tx` handle. +//! drop(tx); +//! +//! while let Some(res) = rx.recv().await { +//! socket.write_all(res).await?; +//! } +//! +//! Ok(()) +//! } +//! ``` +//! +//! The [`mpsc`][mpsc] and [`oneshot`][oneshot] channels can be combined to +//! provide a request / response type synchronization pattern with a shared +//! resource. A task is spawned to synchronize a resource and waits on commands +//! received on a [`mpsc`][mpsc] channel. Each command includes a +//! [`oneshot`][oneshot] `Sender` on which the result of the command is sent. +//! +//! **Example:** use a task to synchronize a `u64` counter. Each task sends an +//! "fetch and increment" command. The counter value **before** the increment is +//! sent over the provided `oneshot` channel. +//! +//! ``` +//! use tokio::sync::{oneshot, mpsc}; +//! use Command::Increment; +//! +//! enum Command { +//! Increment, +//! // Other commands can be added here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (cmd_tx, mut cmd_rx) = mpsc::channel::<(Command, oneshot::Sender<u64>)>(100); +//! +//! // Spawn a task to manage the counter +//! tokio::spawn(async move { +//! let mut counter: u64 = 0; +//! +//! while let Some((cmd, response)) = cmd_rx.recv().await { +//! match cmd { +//! Increment => { +//! let prev = counter; +//! counter += 1; +//! response.send(prev).unwrap(); +//! } +//! } +//! } +//! }); +//! +//! let mut join_handles = vec![]; +//! +//! // Spawn tasks that will send the increment command. +//! for _ in 0..10 { +//! let cmd_tx = cmd_tx.clone(); +//! +//! join_handles.push(tokio::spawn(async move { +//! let (resp_tx, resp_rx) = oneshot::channel(); +//! +//! cmd_tx.send((Increment, resp_tx)).await.ok().unwrap(); +//! let res = resp_rx.await.unwrap(); +//! +//! println!("previous value = {}", res); +//! })); +//! } +//! +//! // Wait for all tasks to complete +//! for join_handle in join_handles.drain(..) { +//! join_handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! [mpsc]: mpsc +//! +//! ## `broadcast` channel +//! +//! The [`broadcast` channel] supports sending **many** values from +//! **many** producers to **many** consumers. Each consumer will receive +//! **each** value. This channel can be used to implement "fan out" style +//! patterns common with pub / sub or "chat" systems. +//! +//! This channel tends to be used less often than `oneshot` and `mpsc` but still +//! has its use cases. +//! +//! Basic usage +//! +//! ``` +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, mut rx1) = broadcast::channel(16); +//! let mut rx2 = tx.subscribe(); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx1.recv().await.unwrap(), 10); +//! assert_eq!(rx1.recv().await.unwrap(), 20); +//! }); +//! +//! tokio::spawn(async move { +//! assert_eq!(rx2.recv().await.unwrap(), 10); +//! assert_eq!(rx2.recv().await.unwrap(), 20); +//! }); +//! +//! tx.send(10).unwrap(); +//! tx.send(20).unwrap(); +//! } +//! ``` +//! +//! [`broadcast` channel]: crate::sync::broadcast +//! +//! ## `watch` channel +//! +//! The [`watch` channel] supports sending **many** values from a **single** +//! producer to **many** consumers. However, only the **most recent** value is +//! stored in the channel. Consumers are notified when a new value is sent, but +//! there is no guarantee that consumers will see **all** values. +//! +//! The [`watch` channel] is similar to a [`broadcast` channel] with capacity 1. +//! +//! Use cases for the [`watch` channel] include broadcasting configuration +//! changes or signalling program state changes, such as transitioning to +//! shutdown. +//! +//! **Example:** use a [`watch` channel] to notify tasks of configuration +//! changes. In this example, a configuration file is checked periodically. When +//! the file changes, the configuration changes are signalled to consumers. +//! +//! ``` +//! use tokio::sync::watch; +//! use tokio::time::{self, Duration, Instant}; +//! +//! use std::io; +//! +//! #[derive(Debug, Clone, Eq, PartialEq)] +//! struct Config { +//! timeout: Duration, +//! } +//! +//! impl Config { +//! async fn load_from_file() -> io::Result<Config> { +//! // file loading and deserialization logic here +//! # Ok(Config { timeout: Duration::from_secs(1) }) +//! } +//! } +//! +//! async fn my_async_operation() { +//! // Do something here +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! // Load initial configuration value +//! let mut config = Config::load_from_file().await.unwrap(); +//! +//! // Create the watch channel, initialized with the loaded configuration +//! let (tx, rx) = watch::channel(config.clone()); +//! +//! // Spawn a task to monitor the file. +//! tokio::spawn(async move { +//! loop { +//! // Wait 10 seconds between checks +//! time::sleep(Duration::from_secs(10)).await; +//! +//! // Load the configuration file +//! let new_config = Config::load_from_file().await.unwrap(); +//! +//! // If the configuration changed, send the new config value +//! // on the watch channel. +//! if new_config != config { +//! tx.send(new_config.clone()).unwrap(); +//! config = new_config; +//! } +//! } +//! }); +//! +//! let mut handles = vec![]; +//! +//! // Spawn tasks that runs the async operation for at most `timeout`. If +//! // the timeout elapses, restart the operation. +//! // +//! // The task simultaneously watches the `Config` for changes. When the +//! // timeout duration changes, the timeout is updated without restarting +//! // the in-flight operation. +//! for _ in 0..5 { +//! // Clone a config watch handle for use in this task +//! let mut rx = rx.clone(); +//! +//! let handle = tokio::spawn(async move { +//! // Start the initial operation and pin the future to the stack. +//! // Pinning to the stack is required to resume the operation +//! // across multiple calls to `select!` +//! let op = my_async_operation(); +//! tokio::pin!(op); +//! +//! // Get the initial config value +//! let mut conf = rx.borrow().clone(); +//! +//! let mut op_start = Instant::now(); +//! let sleep = time::sleep_until(op_start + conf.timeout); +//! tokio::pin!(sleep); +//! +//! loop { +//! tokio::select! { +//! _ = &mut sleep => { +//! // The operation elapsed. Restart it +//! op.set(my_async_operation()); +//! +//! // Track the new start time +//! op_start = Instant::now(); +//! +//! // Restart the timeout +//! sleep.set(time::sleep_until(op_start + conf.timeout)); +//! } +//! _ = rx.changed() => { +//! conf = rx.borrow().clone(); +//! +//! // The configuration has been updated. Update the +//! // `sleep` using the new `timeout` value. +//! sleep.as_mut().reset(op_start + conf.timeout); +//! } +//! _ = &mut op => { +//! // The operation completed! +//! return +//! } +//! } +//! } +//! }); +//! +//! handles.push(handle); +//! } +//! +//! for handle in handles.drain(..) { +//! handle.await.unwrap(); +//! } +//! } +//! ``` +//! +//! [`watch` channel]: mod@crate::sync::watch +//! [`broadcast` channel]: mod@crate::sync::broadcast +//! +//! # State synchronization +//! +//! The remaining synchronization primitives focus on synchronizing state. +//! These are asynchronous equivalents to versions provided by `std`. They +//! operate in a similar way as their `std` counterparts but will wait +//! asynchronously instead of blocking the thread. +//! +//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to +//! reach a point in the program, before continuing execution all together. +//! +//! * [`Mutex`](Mutex) Mutual Exclusion mechanism, which ensures that at most +//! one thread at a time is able to access some data. +//! +//! * [`Notify`](Notify) Basic task notification. `Notify` supports notifying a +//! receiving task without sending data. In this case, the task wakes up and +//! resumes processing. +//! +//! * [`RwLock`](RwLock) Provides a mutual exclusion mechanism which allows +//! multiple readers at the same time, while allowing only one writer at a +//! time. In some cases, this can be more efficient than a mutex. +//! +//! * [`Semaphore`](Semaphore) Limits the amount of concurrency. A semaphore +//! holds a number of permits, which tasks may request in order to enter a +//! critical section. Semaphores are useful for implementing limiting or +//! bounding of any kind. + +cfg_sync! { + /// Named future types. + pub mod futures { + pub use super::notify::Notified; + } + + mod barrier; + pub use barrier::{Barrier, BarrierWaitResult}; + + pub mod broadcast; + + pub mod mpsc; + + mod mutex; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard, MappedMutexGuard}; + + pub(crate) mod notify; + pub use notify::Notify; + + pub mod oneshot; + + pub(crate) mod batch_semaphore; + pub use batch_semaphore::{AcquireError, TryAcquireError}; + + mod semaphore; + pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; + + mod rwlock; + pub use rwlock::RwLock; + pub use rwlock::owned_read_guard::OwnedRwLockReadGuard; + pub use rwlock::owned_write_guard::OwnedRwLockWriteGuard; + pub use rwlock::owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; + pub use rwlock::read_guard::RwLockReadGuard; + pub use rwlock::write_guard::RwLockWriteGuard; + pub use rwlock::write_guard_mapped::RwLockMappedWriteGuard; + + mod task; + pub(crate) use task::AtomicWaker; + + mod once_cell; + pub use self::once_cell::{OnceCell, SetError}; + + pub mod watch; +} + +cfg_not_sync! { + cfg_fs! { + pub(crate) mod batch_semaphore; + mod mutex; + pub(crate) use mutex::Mutex; + } + + #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod notify; + + #[cfg(any(feature = "rt", all(windows, feature = "process")))] + pub(crate) mod oneshot; + + cfg_atomic_waker_impl! { + mod task; + pub(crate) use task::AtomicWaker; + } + + #[cfg(any(feature = "signal", all(unix, feature = "process")))] + pub(crate) mod watch; +} + +/// Unit tests +#[cfg(test)] +mod tests; diff --git a/third_party/rust/tokio/src/sync/mpsc/block.rs b/third_party/rust/tokio/src/sync/mpsc/block.rs new file mode 100644 index 0000000000..58f4a9f6cc --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/block.rs @@ -0,0 +1,385 @@ +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; + +use std::mem::MaybeUninit; +use std::ops; +use std::ptr::{self, NonNull}; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Release}; + +/// A block in a linked list. +/// +/// Each block in the list can hold up to `BLOCK_CAP` messages. +pub(crate) struct Block<T> { + /// The start index of this block. + /// + /// Slots in this block have indices in `start_index .. start_index + BLOCK_CAP`. + start_index: usize, + + /// The next block in the linked list. + next: AtomicPtr<Block<T>>, + + /// Bitfield tracking slots that are ready to have their values consumed. + ready_slots: AtomicUsize, + + /// The observed `tail_position` value *after* the block has been passed by + /// `block_tail`. + observed_tail_position: UnsafeCell<usize>, + + /// Array containing values pushed into the block. Values are stored in a + /// continuous array in order to improve cache line behavior when reading. + /// The values must be manually dropped. + values: Values<T>, +} + +pub(crate) enum Read<T> { + Value(T), + Closed, +} + +struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); + +use super::BLOCK_CAP; + +/// Masks an index to get the block identifier. +const BLOCK_MASK: usize = !(BLOCK_CAP - 1); + +/// Masks an index to get the value offset in a block. +const SLOT_MASK: usize = BLOCK_CAP - 1; + +/// Flag tracking that a block has gone through the sender's release routine. +/// +/// When this is set, the receiver may consider freeing the block. +const RELEASED: usize = 1 << BLOCK_CAP; + +/// Flag tracking all senders dropped. +/// +/// When this flag is set, the send half of the channel has closed. +const TX_CLOSED: usize = RELEASED << 1; + +/// Mask covering all bits used to track slot readiness. +const READY_MASK: usize = RELEASED - 1; + +/// Returns the index of the first slot in the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn start_index(slot_index: usize) -> usize { + BLOCK_MASK & slot_index +} + +/// Returns the offset into the block referenced by `slot_index`. +#[inline(always)] +pub(crate) fn offset(slot_index: usize) -> usize { + SLOT_MASK & slot_index +} + +impl<T> Block<T> { + pub(crate) fn new(start_index: usize) -> Block<T> { + Block { + // The absolute index in the channel of the first slot in the block. + start_index, + + // Pointer to the next block in the linked list. + next: AtomicPtr::new(ptr::null_mut()), + + ready_slots: AtomicUsize::new(0), + + observed_tail_position: UnsafeCell::new(0), + + // Value storage + values: unsafe { Values::uninitialized() }, + } + } + + /// Returns `true` if the block matches the given index. + pub(crate) fn is_at_index(&self, index: usize) -> bool { + debug_assert!(offset(index) == 0); + self.start_index == index + } + + /// Returns the number of blocks between `self` and the block at the + /// specified index. + /// + /// `start_index` must represent a block *after* `self`. + pub(crate) fn distance(&self, other_index: usize) -> usize { + debug_assert!(offset(other_index) == 0); + other_index.wrapping_sub(self.start_index) / BLOCK_CAP + } + + /// Reads the value at the given offset. + /// + /// Returns `None` if the slot is empty. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * No concurrent access to the slot. + pub(crate) unsafe fn read(&self, slot_index: usize) -> Option<Read<T>> { + let offset = offset(slot_index); + + let ready_bits = self.ready_slots.load(Acquire); + + if !is_ready(ready_bits, offset) { + if is_tx_closed(ready_bits) { + return Some(Read::Closed); + } + + return None; + } + + // Get the value + let value = self.values[offset].with(|ptr| ptr::read(ptr)); + + Some(Read::Value(value.assume_init())) + } + + /// Writes a value to the block at the given offset. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The slot is empty. + /// * No concurrent access to the slot. + pub(crate) unsafe fn write(&self, slot_index: usize, value: T) { + // Get the offset into the block + let slot_offset = offset(slot_index); + + self.values[slot_offset].with_mut(|ptr| { + ptr::write(ptr, MaybeUninit::new(value)); + }); + + // Release the value. After this point, the slot ref may no longer + // be used. It is possible for the receiver to free the memory at + // any point. + self.set_ready(slot_offset); + } + + /// Signal to the receiver that the sender half of the list is closed. + pub(crate) unsafe fn tx_close(&self) { + self.ready_slots.fetch_or(TX_CLOSED, Release); + } + + /// Resets the block to a blank state. This enables reusing blocks in the + /// channel. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * All slots are empty. + /// * The caller holds a unique pointer to the block. + pub(crate) unsafe fn reclaim(&mut self) { + self.start_index = 0; + self.next = AtomicPtr::new(ptr::null_mut()); + self.ready_slots = AtomicUsize::new(0); + } + + /// Releases the block to the rx half for freeing. + /// + /// This function is called by the tx half once it can be guaranteed that no + /// more senders will attempt to access the block. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * The block will no longer be accessed by any sender. + pub(crate) unsafe fn tx_release(&self, tail_position: usize) { + // Track the observed tail_position. Any sender targeting a greater + // tail_position is guaranteed to not access this block. + self.observed_tail_position + .with_mut(|ptr| *ptr = tail_position); + + // Set the released bit, signalling to the receiver that it is safe to + // free the block's memory as soon as all slots **prior** to + // `observed_tail_position` have been filled. + self.ready_slots.fetch_or(RELEASED, Release); + } + + /// Mark a slot as ready + fn set_ready(&self, slot: usize) { + let mask = 1 << slot; + self.ready_slots.fetch_or(mask, Release); + } + + /// Returns `true` when all slots have their `ready` bits set. + /// + /// This indicates that the block is in its final state and will no longer + /// be mutated. + /// + /// # Implementation + /// + /// The implementation walks each slot checking the `ready` flag. It might + /// be that it would make more sense to coalesce ready flags as bits in a + /// single atomic cell. However, this could have negative impact on cache + /// behavior as there would be many more mutations to a single slot. + pub(crate) fn is_final(&self) -> bool { + self.ready_slots.load(Acquire) & READY_MASK == READY_MASK + } + + /// Returns the `observed_tail_position` value, if set + pub(crate) fn observed_tail_position(&self) -> Option<usize> { + if 0 == RELEASED & self.ready_slots.load(Acquire) { + None + } else { + Some(self.observed_tail_position.with(|ptr| unsafe { *ptr })) + } + } + + /// Loads the next block + pub(crate) fn load_next(&self, ordering: Ordering) -> Option<NonNull<Block<T>>> { + let ret = NonNull::new(self.next.load(ordering)); + + debug_assert!(unsafe { + ret.map(|block| block.as_ref().start_index == self.start_index.wrapping_add(BLOCK_CAP)) + .unwrap_or(true) + }); + + ret + } + + /// Pushes `block` as the next block in the link. + /// + /// Returns Ok if successful, otherwise, a pointer to the next block in + /// the list is returned. + /// + /// This requires that the next pointer is null. + /// + /// # Ordering + /// + /// This performs a compare-and-swap on `next` using AcqRel ordering. + /// + /// # Safety + /// + /// To maintain safety, the caller must ensure: + /// + /// * `block` is not freed until it has been removed from the list. + pub(crate) unsafe fn try_push( + &self, + block: &mut NonNull<Block<T>>, + success: Ordering, + failure: Ordering, + ) -> Result<(), NonNull<Block<T>>> { + block.as_mut().start_index = self.start_index.wrapping_add(BLOCK_CAP); + + let next_ptr = self + .next + .compare_exchange(ptr::null_mut(), block.as_ptr(), success, failure) + .unwrap_or_else(|x| x); + + match NonNull::new(next_ptr) { + Some(next_ptr) => Err(next_ptr), + None => Ok(()), + } + } + + /// Grows the `Block` linked list by allocating and appending a new block. + /// + /// The next block in the linked list is returned. This may or may not be + /// the one allocated by the function call. + /// + /// # Implementation + /// + /// It is assumed that `self.next` is null. A new block is allocated with + /// `start_index` set to be the next block. A compare-and-swap is performed + /// with AcqRel memory ordering. If the compare-and-swap is successful, the + /// newly allocated block is released to other threads walking the block + /// linked list. If the compare-and-swap fails, the current thread acquires + /// the next block in the linked list, allowing the current thread to access + /// the slots. + pub(crate) fn grow(&self) -> NonNull<Block<T>> { + // Create the new block. It is assumed that the block will become the + // next one after `&self`. If this turns out to not be the case, + // `start_index` is updated accordingly. + let new_block = Box::new(Block::new(self.start_index + BLOCK_CAP)); + + let mut new_block = unsafe { NonNull::new_unchecked(Box::into_raw(new_block)) }; + + // Attempt to store the block. The first compare-and-swap attempt is + // "unrolled" due to minor differences in logic + // + // `AcqRel` is used as the ordering **only** when attempting the + // compare-and-swap on self.next. + // + // If the compare-and-swap fails, then the actual value of the cell is + // returned from this function and accessed by the caller. Given this, + // the memory must be acquired. + // + // `Release` ensures that the newly allocated block is available to + // other threads acquiring the next pointer. + let next = NonNull::new( + self.next + .compare_exchange(ptr::null_mut(), new_block.as_ptr(), AcqRel, Acquire) + .unwrap_or_else(|x| x), + ); + + let next = match next { + Some(next) => next, + None => { + // The compare-and-swap succeeded and the newly allocated block + // is successfully pushed. + return new_block; + } + }; + + // There already is a next block in the linked list. The newly allocated + // block could be dropped and the discovered next block returned; + // however, that would be wasteful. Instead, the linked list is walked + // by repeatedly attempting to compare-and-swap the pointer into the + // `next` register until the compare-and-swap succeed. + // + // Care is taken to update new_block's start_index field as appropriate. + + let mut curr = next; + + // TODO: Should this iteration be capped? + loop { + let actual = unsafe { curr.as_ref().try_push(&mut new_block, AcqRel, Acquire) }; + + curr = match actual { + Ok(_) => { + return next; + } + Err(curr) => curr, + }; + + crate::loom::thread::yield_now(); + } + } +} + +/// Returns `true` if the specified slot has a value ready to be consumed. +fn is_ready(bits: usize, slot: usize) -> bool { + let mask = 1 << slot; + mask == mask & bits +} + +/// Returns `true` if the closed flag has been set. +fn is_tx_closed(bits: usize) -> bool { + TX_CLOSED == bits & TX_CLOSED +} + +impl<T> Values<T> { + unsafe fn uninitialized() -> Values<T> { + let mut vals = MaybeUninit::uninit(); + + // When fuzzing, `UnsafeCell` needs to be initialized. + if_loom! { + let p = vals.as_mut_ptr() as *mut UnsafeCell<MaybeUninit<T>>; + for i in 0..BLOCK_CAP { + p.add(i) + .write(UnsafeCell::new(MaybeUninit::uninit())); + } + } + + Values(vals.assume_init()) + } +} + +impl<T> ops::Index<usize> for Values<T> { + type Output = UnsafeCell<MaybeUninit<T>>; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/bounded.rs b/third_party/rust/tokio/src/sync/mpsc/bounded.rs new file mode 100644 index 0000000000..ddded8ebb3 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/bounded.rs @@ -0,0 +1,1197 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; + +cfg_time! { + use crate::sync::mpsc::error::SendTimeoutError; + use crate::time::Duration; +} + +use std::fmt; +use std::task::{Context, Poll}; + +/// Sends values to the associated `Receiver`. +/// +/// Instances are created by the [`channel`](channel) function. +/// +/// To convert the `Sender` into a `Sink` or use it in a poll function, you can +/// use the [`PollSender`] utility. +/// +/// [`PollSender`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSender.html +pub struct Sender<T> { + chan: chan::Tx<T, Semaphore>, +} + +/// Permits to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] +/// and are used to guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +/// [`Sender::try_reserve()`]: Sender::try_reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx<T, Semaphore>, +} + +/// Owned permit to send one value into the channel. +/// +/// This is identical to the [`Permit`] type, except that it moves the sender +/// rather than borrowing it. +/// +/// `OwnedPermit` values are returned by [`Sender::reserve_owned()`] and +/// [`Sender::try_reserve_owned()`] and are used to guarantee channel capacity +/// before generating a message to send. +/// +/// [`Permit`]: Permit +/// [`Sender::reserve_owned()`]: Sender::reserve_owned +/// [`Sender::try_reserve_owned()`]: Sender::try_reserve_owned +pub struct OwnedPermit<T> { + chan: Option<chan::Tx<T, Semaphore>>, +} + +/// Receives values from the associated `Sender`. +/// +/// Instances are created by the [`channel`](channel) function. +/// +/// This receiver can be turned into a `Stream` using [`ReceiverStream`]. +/// +/// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html +pub struct Receiver<T> { + /// The channel receiver. + chan: chan::Rx<T, Semaphore>, +} + +/// Creates a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to send new messages will wait until a message is +/// received from the channel. The provided buffer capacity must be at least 1. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is supported. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return `None`. +/// +/// # Panics +/// +/// Panics if the buffer capacity is 0. +/// +/// # Examples +/// +/// ```rust +/// use tokio::sync::mpsc; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx) = mpsc::channel(100); +/// +/// tokio::spawn(async move { +/// for i in 0..10 { +/// if let Err(_) = tx.send(i).await { +/// println!("receiver dropped"); +/// return; +/// } +/// } +/// }); +/// +/// while let Some(i) = rx.recv().await { +/// println!("got = {}", i); +/// } +/// } +/// ``` +pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) { + assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); + let semaphore = (semaphore::Semaphore::new(buffer), buffer); + let (tx, rx) = chan::channel(semaphore); + + let tx = Sender::new(tx); + let rx = Receiver::new(rx); + + (tx, rx) +} + +/// Channel semaphore is a tuple of the semaphore implementation and a `usize` +/// representing the channel bound. +type Semaphore = (semaphore::Semaphore, usize); + +impl<T> Receiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> Receiver<T> { + Receiver { chan } + } + + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. Note that if [`close`] is called, but there are + /// still outstanding [`Permits`] from before it was closed, the channel is + /// not considered closed by `recv` until the permits are released. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`close`]: Self::close + /// [`Permits`]: struct@crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tokio::spawn(async move { + /// tx.send("hello").await.unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// tx.send("world").await.unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + poll_fn(|cx| self.chan.recv(cx)).await + } + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").await.unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will block until a message is sent or + /// the channel is closed. + /// + /// This method is intended for use cases where you are sending from + /// asynchronous code to synchronous code, and will work even if the sender + /// is not using [`blocking_send`] to send the message. + /// + /// Note that if [`close`] is called, but there are still outstanding + /// [`Permits`] from before it was closed, the channel is not considered + /// closed by `blocking_recv` until the permits are released. + /// + /// [`close`]: Self::close + /// [`Permits`]: struct@crate::sync::mpsc::Permit + /// [`blocking_send`]: fn@crate::sync::mpsc::Sender::blocking_send + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(10); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// Runtime::new() + /// .unwrap() + /// .block_on(async move { + /// let _ = tx.send(10).await; + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. If there are + /// outstanding [`Permit`] or [`OwnedPermit`] values, the `recv` method will + /// not return `None` until those are released. + /// + /// [`Permit`]: Permit + /// [`OwnedPermit`]: OwnedPermit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` + pub fn close(&mut self) { + self.chan.close(); + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } +} + +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + +impl<T> Unpin for Receiver<T> {} + +impl<T> Sender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> Sender<T> { + Sender { chan } + } + + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Cancel safety + /// + /// If `send` is used as the event in a [`tokio::select!`](crate::select) + /// statement and some other branch completes first, then it is guaranteed + /// that the message was not sent. + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `send` makes you lose your place in the queue. + /// + /// # Examples + /// + /// In the following example, each call to `send` will block until the + /// previously sent value was received. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(_) = tx.send(i).await { + /// println!("receiver dropped"); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// } + /// } + /// ``` + pub async fn send(&self, value: T) -> Result<(), SendError<T>> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::channel::<()>(1); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + /// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + + /// Attempts to immediately send a message on this `Sender` + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`channel`], then an + /// error is returned. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`send`]: Sender::send + /// [`channel`]: channel + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// // Create a channel with buffer size 1 + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); + /// + /// tokio::spawn(async move { + /// tx1.send(1).await.unwrap(); + /// tx1.send(2).await.unwrap(); + /// // task waits until the receiver receives a value. + /// }); + /// + /// tokio::spawn(async move { + /// // This will return an error and send + /// // no message if the buffer is full + /// let _ = tx2.try_send(3); + /// }); + /// + /// let mut msg; + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// msg = rx.recv().await.unwrap(); + /// println!("message {} received", msg); + /// + /// // Third message may have never been sent + /// match rx.recv().await { + /// Some(msg) => println!("message {} received", msg), + /// None => println!("the third message was never sent"), + /// } + /// } + /// ``` + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); + Ok(()) + } + + /// Sends a value, waiting until there is capacity, but only for a limited time. + /// + /// Shares the same success and error conditions as [`send`], adding one more + /// condition for an unsuccessful send, which is when the provided timeout has + /// elapsed, and there is no capacity available. + /// + /// [`send`]: Sender::send + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] having been dropped, + /// the function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + /// + /// # Panics + /// + /// This function panics if it is called outside the context of a Tokio + /// runtime [with time enabled](crate::runtime::Builder::enable_time). + /// + /// # Examples + /// + /// In the following example, each call to `send_timeout` will block until the + /// previously sent value was received, unless the timeout has elapsed. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// use tokio::time::{sleep, Duration}; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// tokio::spawn(async move { + /// for i in 0..10 { + /// if let Err(e) = tx.send_timeout(i, Duration::from_millis(100)).await { + /// println!("send error: #{:?}", e); + /// return; + /// } + /// } + /// }); + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// sleep(Duration::from_millis(200)).await; + /// } + /// } + /// ``` + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + pub async fn send_timeout( + &self, + value: T, + timeout: Duration, + ) -> Result<(), SendTimeoutError<T>> { + let permit = match crate::time::timeout(timeout, self.reserve()).await { + Err(_) => { + return Err(SendTimeoutError::Timeout(value)); + } + Ok(Err(_)) => { + return Err(SendTimeoutError::Closed(value)); + } + Ok(Ok(permit)) => permit, + }; + + permit.send(value); + Ok(()) + } + + /// Blocking send to call outside of asynchronous contexts. + /// + /// This method is intended for use cases where you are sending from + /// synchronous code to asynchronous code, and will work even if the + /// receiver is not using [`blocking_recv`] to receive the message. + /// + /// [`blocking_recv`]: fn@crate::sync::mpsc::Receiver::blocking_recv + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(1); + /// + /// let sync_code = thread::spawn(move || { + /// tx.blocking_send(10).unwrap(); + /// }); + /// + /// Runtime::new().unwrap().block_on(async move { + /// assert_eq!(Some(10), rx.recv().await); + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> { + crate::future::block_on(self.send(value)) + } + + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + /// + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Waits for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve` makes you lose your place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> { + self.reserve_inner().await?; + Ok(Permit { chan: &self.chan }) + } + + /// Waits for channel capacity, moving the `Sender` and returning an owned + /// permit. Once capacity to send one message is available, it is reserved + /// for the caller. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. `Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `reserve_owned`. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. An [`OwnedPermit`] is returned to + /// track the reserved capacity. The [`send`] function on [`OwnedPermit`] + /// consumes the reserved capacity. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the + /// capacity back to the channel. + /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve_owned` makes you lose your place in the queue. + /// + /// # Examples + /// Sending a message using an [`OwnedPermit`]: + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity, moving the sender. + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message, consuming the permit and returning + /// // the moved sender. + /// let tx = permit.send(123); + /// + /// // The value sent on the permit is received. + /// assert_eq!(rx.recv().await.unwrap(), 123); + /// + /// // The sender can now be used again. + /// tx.send(456).await.unwrap(); + /// } + /// ``` + /// + /// When multiple [`OwnedPermit`]s are needed, or the sender cannot be moved + /// by value, it can be inexpensively cloned before calling `reserve_owned`: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity. + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds. + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + /// + /// [`Sender::reserve`]: Sender::reserve + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`Arc::clone`]: std::sync::Arc::clone + pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> { + self.reserve_inner().await?; + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + + async fn reserve_inner(&self) -> Result<(), SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => Ok(()), + Err(_) => Err(SendError(())), + } + } + + /// Tries to acquire a slot in the channel without waiting for the slot to become + /// available. + /// + /// If the channel is full this function will return [`TrySendError`], otherwise + /// if there is a slot available it will return a [`Permit`] that will then allow you + /// to [`send`] on the channel with a guaranteed slot. This function is similar to + /// [`reserve`] except it does not await for the slot to become available. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// [`reserve`]: Sender::reserve + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.try_reserve().unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Trying to reserve an additional slot on the `tx` will + /// // fail because there is no capacity. + /// assert!(tx.try_reserve().is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// } + /// ``` + pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())), + } + + Ok(Permit { chan: &self.chan }) + } + + /// Tries to acquire a slot in the channel without waiting for the slot to become + /// available, returning an owned permit. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::try_reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. `Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `try_reserve_owned`. + /// + /// If the channel is full this function will return a [`TrySendError`]. + /// Since the sender is taken by value, the `TrySendError` returned in this + /// case contains the sender, so that it may be used again. Otherwise, if + /// there is a slot available, this method will return an [`OwnedPermit`] + /// that can then be used to [`send`] on the channel with a guaranteed slot. + /// This function is similar to [`reserve_owned`] except it does not await + /// for the slot to become available. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`reserve_owned`]: Sender::reserve_owned + /// [`Arc::clone`]: std::sync::Arc::clone + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.clone().try_reserve_owned().unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Trying to reserve an additional slot on the `tx` will + /// // fail because there is no capacity. + /// assert!(tx.try_reserve().is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// } + /// ``` + pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(self)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(self)), + } + + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(1); + /// let tx2 = tx.clone(); + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, rx3) = tokio::sync::mpsc::channel::<()>(1); + /// assert!(!tx3.same_channel(&tx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + self.chan.same_channel(&other.chan) + } + + /// Returns the current capacity of the channel. + /// + /// The capacity goes down when sending a value by calling [`send`] or by reserving capacity + /// with [`reserve`]. The capacity goes up when values are received by the [`Receiver`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel::<()>(5); + /// + /// assert_eq!(tx.capacity(), 5); + /// + /// // Making a reservation drops the capacity by one. + /// let permit = tx.reserve().await.unwrap(); + /// assert_eq!(tx.capacity(), 4); + /// + /// // Sending and receiving a value increases the capacity by one. + /// permit.send(()); + /// rx.recv().await.unwrap(); + /// assert_eq!(tx.capacity(), 5); + /// } + /// ``` + /// + /// [`send`]: Sender::send + /// [`reserve`]: Sender::reserve + pub fn capacity(&self) -> usize { + self.chan.semaphore().0.available_permits() + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl<T> Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + // If this is the last sender for this channel, wake the receiver so + // that it can be notified that the channel is closed. + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); + } + } +} + +impl<T> fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> OwnedPermit<T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// Unlike [`Permit::send`], this method returns the [`Sender`] from which + /// the `OwnedPermit` was reserved. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message on the permit, returning the sender. + /// let tx = permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// // We may now reuse `tx` to send another message. + /// tx.send(789).await.unwrap(); + /// } + /// ``` + pub fn send(mut self, value: T) -> Sender<T> { + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + chan.send(value); + + Sender { chan } + } + + /// Releases the reserved capacity *without* sending a message, returning the + /// [`Sender`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send on the original `tx` will fail, since the `permit` + /// // has reserved all the available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Release the permit without sending a message, returning the clone + /// // of the sender. + /// let tx2 = permit.release(); + /// + /// // We may now reuse `tx` to send another message. + /// tx.send(789).await.unwrap(); + /// # drop(rx); drop(tx2); + /// } + /// ``` + /// + /// [`Sender`]: Sender + pub fn release(mut self) -> Sender<T> { + use chan::Semaphore; + + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + + // Add the permit back to the semaphore + chan.semaphore().add_permit(); + Sender { chan } + } +} + +impl<T> Drop for OwnedPermit<T> { + fn drop(&mut self) { + use chan::Semaphore; + + // Are we still holding onto the sender? + if let Some(chan) = self.chan.take() { + let semaphore = chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + // If this `OwnedPermit` is holding the last sender for this + // channel, wake the receiver so that it can be notified that the + // channel is closed. + if semaphore.is_closed() && semaphore.is_idle() { + chan.wake_rx(); + } + } + + // Otherwise, do nothing. + } +} + +impl<T> fmt::Debug for OwnedPermit<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OwnedPermit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/chan.rs b/third_party/rust/tokio/src/sync/mpsc/chan.rs new file mode 100644 index 0000000000..c3007de89c --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/chan.rs @@ -0,0 +1,405 @@ +use crate::loom::cell::UnsafeCell; +use crate::loom::future::AtomicWaker; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +use crate::park::thread::CachedParkThread; +use crate::park::Park; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; +use crate::sync::notify::Notify; + +use std::fmt; +use std::process; +use std::sync::atomic::Ordering::{AcqRel, Relaxed}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +/// Channel sender. +pub(crate) struct Tx<T, S> { + inner: Arc<Chan<T, S>>, +} + +impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx").field("inner", &self.inner).finish() + } +} + +/// Channel receiver. +pub(crate) struct Rx<T, S: Semaphore> { + inner: Arc<Chan<T, S>>, +} + +impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx").field("inner", &self.inner).finish() + } +} + +pub(crate) trait Semaphore { + fn is_idle(&self) -> bool; + + fn add_permit(&self); + + fn close(&self); + + fn is_closed(&self) -> bool; +} + +struct Chan<T, S> { + /// Notifies all tasks listening for the receiver being dropped. + notify_rx_closed: Notify, + + /// Handle to the push half of the lock-free list. + tx: list::Tx<T>, + + /// Coordinates access to channel's capacity. + semaphore: S, + + /// Receiver waker. Notified when a value is pushed into the channel. + rx_waker: AtomicWaker, + + /// Tracks the number of outstanding sender handles. + /// + /// When this drops to zero, the send half of the channel is closed. + tx_count: AtomicUsize, + + /// Only accessed by `Rx` handle. + rx_fields: UnsafeCell<RxFields<T>>, +} + +impl<T, S> fmt::Debug for Chan<T, S> +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Chan") + .field("tx", &self.tx) + .field("semaphore", &self.semaphore) + .field("rx_waker", &self.rx_waker) + .field("tx_count", &self.tx_count) + .field("rx_fields", &"...") + .finish() + } +} + +/// Fields only accessed by `Rx` handle. +struct RxFields<T> { + /// Channel receiver. This field is only accessed by the `Receiver` type. + list: list::Rx<T>, + + /// `true` if `Rx::close` is called. + rx_closed: bool, +} + +impl<T> fmt::Debug for RxFields<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("RxFields") + .field("list", &self.list) + .field("rx_closed", &self.rx_closed) + .finish() + } +} + +unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} +unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} + +pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) { + let (tx, rx) = list::channel(); + + let chan = Arc::new(Chan { + notify_rx_closed: Notify::new(), + tx, + semaphore, + rx_waker: AtomicWaker::new(), + tx_count: AtomicUsize::new(1), + rx_fields: UnsafeCell::new(RxFields { + list: rx, + rx_closed: false, + }), + }); + + (Tx::new(chan.clone()), Rx::new(chan)) +} + +// ===== impl Tx ===== + +impl<T, S> Tx<T, S> { + fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { + Tx { inner: chan } + } + + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore + } + + /// Send a message and notify the receiver. + pub(crate) fn send(&self, value: T) { + self.inner.send(value); + } + + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); + } + + /// Returns `true` if senders belong to the same channel. + pub(crate) fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<T, S: Semaphore> Tx<T, S> { + pub(crate) fn is_closed(&self) -> bool { + self.inner.semaphore.is_closed() + } + + pub(crate) async fn closed(&self) { + // In order to avoid a race condition, we first request a notification, + // **then** check whether the semaphore is closed. If the semaphore is + // closed the notification request is dropped. + let notified = self.inner.notify_rx_closed.notified(); + + if self.inner.semaphore.is_closed() { + return; + } + notified.await; + } +} + +impl<T, S> Clone for Tx<T, S> { + fn clone(&self) -> Tx<T, S> { + // Using a Relaxed ordering here is sufficient as the caller holds a + // strong ref to `self`, preventing a concurrent decrement to zero. + self.inner.tx_count.fetch_add(1, Relaxed); + + Tx { + inner: self.inner.clone(), + } + } +} + +impl<T, S> Drop for Tx<T, S> { + fn drop(&mut self) { + if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { + return; + } + + // Close the list, which sends a `Close` message + self.inner.tx.close(); + + // Notify the receiver + self.wake_rx(); + } +} + +// ===== impl Rx ===== + +impl<T, S: Semaphore> Rx<T, S> { + fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { + Rx { inner: chan } + } + + pub(crate) fn close(&mut self) { + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + if rx_fields.rx_closed { + return; + } + + rx_fields.rx_closed = true; + }); + + self.inner.semaphore.close(); + self.inner.notify_rx_closed.notify_waiters(); + } + + /// Receive the next value + pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + use super::block::Read::*; + + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.pop(&self.inner.tx) { + Some(Value(value)) => { + self.inner.semaphore.add_permit(); + coop.made_progress(); + return Ready(Some(value)); + } + Some(Closed) => { + // TODO: This check may not be required as it most + // likely can only return `true` at this point. A + // channel is closed when all tx handles are + // dropped. Dropping a tx handle releases memory, + // which ensures that if dropping the tx handle is + // visible, then all messages sent are also visible. + assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); + return Ready(None); + } + None => {} // fall through + } + }; + } + + try_recv!(); + + self.inner.rx_waker.register_by_ref(cx.waker()); + + // It is possible that a value was pushed between attempting to read + // and registering the task, so we have to check the channel a + // second time here. + try_recv!(); + + if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + coop.made_progress(); + Ready(None) + } else { + Pending + } + }) + } + + /// Try to receive the next value. + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::list::TryPopResult; + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.try_pop(&self.inner.tx) { + TryPopResult::Ok(value) => { + self.inner.semaphore.add_permit(); + return Ok(value); + } + TryPopResult::Closed => return Err(TryRecvError::Disconnected), + TryPopResult::Empty => return Err(TryRecvError::Empty), + TryPopResult::Busy => {} // fall through + } + }; + } + + try_recv!(); + + // If a previous `poll_recv` call has set a waker, we wake it here. + // This allows us to put our own CachedParkThread waker in the + // AtomicWaker slot instead. + // + // This is not a spurious wakeup to `poll_recv` since we just got a + // Busy from `try_pop`, which only happens if there are messages in + // the queue. + self.inner.rx_waker.wake(); + + // Park the thread until the problematic send has completed. + let mut park = CachedParkThread::new(); + let waker = park.unpark().into_waker(); + loop { + self.inner.rx_waker.register_by_ref(&waker); + // It is possible that the problematic send has now completed, + // so we have to check for messages again. + try_recv!(); + park.park().expect("park failed"); + } + }) + } +} + +impl<T, S: Semaphore> Drop for Rx<T, S> { + fn drop(&mut self) { + use super::block::Read::Value; + + self.close(); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) { + self.inner.semaphore.add_permit(); + } + }) + } +} + +// ===== impl Chan ===== + +impl<T, S> Chan<T, S> { + fn send(&self, value: T) { + // Push the value + self.tx.push(value); + + // Notify the rx task + self.rx_waker.wake(); + } +} + +impl<T, S> Drop for Chan<T, S> { + fn drop(&mut self) { + use super::block::Read::Value; + + // Safety: the only owner of the rx fields is Chan, and eing + // inside its own Drop means we're the last ones to touch it. + self.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {} + unsafe { rx_fields.list.free_blocks() }; + }); + } +} + +// ===== impl Semaphore for (::Semaphore, capacity) ===== + +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { + fn add_permit(&self) { + self.0.release(1) + } + + fn is_idle(&self) -> bool { + self.0.available_permits() == self.1 + } + + fn close(&self) { + self.0.close(); + } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } +} + +// ===== impl Semaphore for AtomicUsize ===== + +use std::sync::atomic::Ordering::{Acquire, Release}; +use std::usize; + +impl Semaphore for AtomicUsize { + fn add_permit(&self) { + let prev = self.fetch_sub(2, Release); + + if prev >> 1 == 0 { + // Something went wrong + process::abort(); + } + } + + fn is_idle(&self) -> bool { + self.load(Acquire) >> 1 == 0 + } + + fn close(&self) { + self.fetch_or(1, Release); + } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/error.rs b/third_party/rust/tokio/src/sync/mpsc/error.rs new file mode 100644 index 0000000000..3fe6bac5e1 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/error.rs @@ -0,0 +1,125 @@ +//! Channel error types. + +use std::error::Error; +use std::fmt; + +/// Error returned by the `Sender`. +#[derive(Debug)] +pub struct SendError<T>(pub T); + +impl<T> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl<T: fmt::Debug> std::error::Error for SendError<T> {} + +// ===== TrySendError ===== + +/// This enumeration is the list of the possible error outcomes for the +/// [try_send](super::Sender::try_send) method. +#[derive(Debug, Eq, PartialEq)] +pub enum TrySendError<T> { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), +} + +impl<T: fmt::Debug> Error for TrySendError<T> {} + +impl<T> fmt::Display for TrySendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TrySendError::Full(..) => "no available capacity", + TrySendError::Closed(..) => "channel closed", + } + ) + } +} + +impl<T> From<SendError<T>> for TrySendError<T> { + fn from(src: SendError<T>) -> TrySendError<T> { + TrySendError::Closed(src.0) + } +} + +// ===== TryRecvError ===== + +/// Error returned by `try_recv`. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// This **channel** is currently empty, but the **Sender**(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The **channel**'s sending half has become disconnected, and there will + /// never be any more data received on it. + Disconnected, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), + } + } +} + +impl Error for TryRecvError {} + +// ===== RecvError ===== + +/// Error returned by `Receiver`. +#[derive(Debug)] +#[doc(hidden)] +#[deprecated(note = "This type is unused because recv returns an Option.")] +pub struct RecvError(()); + +#[allow(deprecated)] +impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +#[allow(deprecated)] +impl Error for RecvError {} + +cfg_time! { + // ===== SendTimeoutError ===== + + #[derive(Debug, Eq, PartialEq)] + /// Error returned by [`Sender::send_timeout`](super::Sender::send_timeout)]. + pub enum SendTimeoutError<T> { + /// The data could not be sent on the channel because the channel is + /// full, and the timeout to send has elapsed. + Timeout(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), + } + + impl<T: fmt::Debug> Error for SendTimeoutError<T> {} + + impl<T> fmt::Display for SendTimeoutError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + SendTimeoutError::Timeout(..) => "timed out waiting on send operation", + SendTimeoutError::Closed(..) => "channel closed", + } + ) + } + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/list.rs b/third_party/rust/tokio/src/sync/mpsc/list.rs new file mode 100644 index 0000000000..e4eeb45411 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/list.rs @@ -0,0 +1,371 @@ +//! A concurrent, lock-free, FIFO list. + +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; +use crate::sync::mpsc::block::{self, Block}; + +use std::fmt; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; + +/// List queue transmit handle. +pub(crate) struct Tx<T> { + /// Tail in the `Block` mpmc list. + block_tail: AtomicPtr<Block<T>>, + + /// Position to push the next message. This references a block and offset + /// into the block. + tail_position: AtomicUsize, +} + +/// List queue receive handle +pub(crate) struct Rx<T> { + /// Pointer to the block being processed. + head: NonNull<Block<T>>, + + /// Next slot index to process. + index: usize, + + /// Pointer to the next block pending release. + free_head: NonNull<Block<T>>, +} + +/// Return value of `Rx::try_pop`. +pub(crate) enum TryPopResult<T> { + /// Successfully popped a value. + Ok(T), + /// The channel is empty. + Empty, + /// The channel is empty and closed. + Closed, + /// The channel is not empty, but the first value is being written. + Busy, +} + +pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) { + // Create the initial block shared between the tx and rx halves. + let initial_block = Box::new(Block::new(0)); + let initial_block_ptr = Box::into_raw(initial_block); + + let tx = Tx { + block_tail: AtomicPtr::new(initial_block_ptr), + tail_position: AtomicUsize::new(0), + }; + + let head = NonNull::new(initial_block_ptr).unwrap(); + + let rx = Rx { + head, + index: 0, + free_head: head, + }; + + (tx, rx) +} + +impl<T> Tx<T> { + /// Pushes a value into the list. + pub(crate) fn push(&self, value: T) { + // First, claim a slot for the value. `Acquire` is used here to + // synchronize with the `fetch_add` in `reclaim_blocks`. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + // Load the current block and write the value + let block = self.find_block(slot_index); + + unsafe { + // Write the value to the block + block.as_ref().write(slot_index, value); + } + } + + /// Closes the send half of the list. + /// + /// Similar process as pushing a value, but instead of writing the value & + /// setting the ready flag, the TX_CLOSED flag is set on the block. + pub(crate) fn close(&self) { + // First, claim a slot for the value. This is the last slot that will be + // claimed. + let slot_index = self.tail_position.fetch_add(1, Acquire); + + let block = self.find_block(slot_index); + + unsafe { block.as_ref().tx_close() } + } + + fn find_block(&self, slot_index: usize) -> NonNull<Block<T>> { + // The start index of the block that contains `index`. + let start_index = block::start_index(slot_index); + + // The index offset into the block + let offset = block::offset(slot_index); + + // Load the current head of the block + let mut block_ptr = self.block_tail.load(Acquire); + + let block = unsafe { &*block_ptr }; + + // Calculate the distance between the tail ptr and the target block + let distance = block.distance(start_index); + + // Decide if this call to `find_block` should attempt to update the + // `block_tail` pointer. + // + // Updating `block_tail` is not always performed in order to reduce + // contention. + // + // When set, as the routine walks the linked list, it attempts to update + // `block_tail`. If the update cannot be performed, `try_updating_tail` + // is unset. + let mut try_updating_tail = distance > offset; + + // Walk the linked list of blocks until the block with `start_index` is + // found. + loop { + let block = unsafe { &(*block_ptr) }; + + if block.is_at_index(start_index) { + return unsafe { NonNull::new_unchecked(block_ptr) }; + } + + let next_block = block + .load_next(Acquire) + // There is no allocated next block, grow the linked list. + .unwrap_or_else(|| block.grow()); + + // If the block is **not** final, then the tail pointer cannot be + // advanced any more. + try_updating_tail &= block.is_final(); + + if try_updating_tail { + // Advancing `block_tail` must happen when walking the linked + // list. `block_tail` may not advance passed any blocks that are + // not "final". At the point a block is finalized, it is unknown + // if there are any prior blocks that are unfinalized, which + // makes it impossible to advance `block_tail`. + // + // While walking the linked list, `block_tail` can be advanced + // as long as finalized blocks are traversed. + // + // Release ordering is used to ensure that any subsequent reads + // are able to see the memory pointed to by `block_tail`. + // + // Acquire is not needed as any "actual" value is not accessed. + // At this point, the linked list is walked to acquire blocks. + if self + .block_tail + .compare_exchange(block_ptr, next_block.as_ptr(), Release, Relaxed) + .is_ok() + { + // Synchronize with any senders + let tail_position = self.tail_position.fetch_add(0, Release); + + unsafe { + block.tx_release(tail_position); + } + } else { + // A concurrent sender is also working on advancing + // `block_tail` and this thread is falling behind. + // + // Stop trying to advance the tail pointer + try_updating_tail = false; + } + } + + block_ptr = next_block.as_ptr(); + + thread::yield_now(); + } + } + + pub(crate) unsafe fn reclaim_block(&self, mut block: NonNull<Block<T>>) { + // The block has been removed from the linked list and ownership + // is reclaimed. + // + // Before dropping the block, see if it can be reused by + // inserting it back at the end of the linked list. + // + // First, reset the data + block.as_mut().reclaim(); + + let mut reused = false; + + // Attempt to insert the block at the end + // + // Walk at most three times + // + let curr_ptr = self.block_tail.load(Acquire); + + // The pointer can never be null + debug_assert!(!curr_ptr.is_null()); + + let mut curr = NonNull::new_unchecked(curr_ptr); + + // TODO: Unify this logic with Block::grow + for _ in 0..3 { + match curr.as_ref().try_push(&mut block, AcqRel, Acquire) { + Ok(_) => { + reused = true; + break; + } + Err(next) => { + curr = next; + } + } + } + + if !reused { + let _ = Box::from_raw(block.as_ptr()); + } + } +} + +impl<T> fmt::Debug for Tx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Tx") + .field("block_tail", &self.block_tail.load(Relaxed)) + .field("tail_position", &self.tail_position.load(Relaxed)) + .finish() + } +} + +impl<T> Rx<T> { + /// Pops the next value off the queue. + pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> { + // Advance `head`, if needed + if !self.try_advancing_head() { + return None; + } + + self.reclaim_blocks(tx); + + unsafe { + let block = self.head.as_ref(); + + let ret = block.read(self.index); + + if let Some(block::Read::Value(..)) = ret { + self.index = self.index.wrapping_add(1); + } + + ret + } + } + + /// Pops the next value off the queue, detecting whether the block + /// is busy or empty on failure. + /// + /// This function exists because `Rx::pop` can return `None` even if the + /// channel's queue contains a message that has been completely written. + /// This can happen if the fully delivered message is behind another message + /// that is in the middle of being written to the block, since the channel + /// can't return the messages out of order. + pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> { + let tail_position = tx.tail_position.load(Acquire); + let result = self.pop(tx); + + match result { + Some(block::Read::Value(t)) => TryPopResult::Ok(t), + Some(block::Read::Closed) => TryPopResult::Closed, + None if tail_position == self.index => TryPopResult::Empty, + None => TryPopResult::Busy, + } + } + + /// Tries advancing the block pointer to the block referenced by `self.index`. + /// + /// Returns `true` if successful, `false` if there is no next block to load. + fn try_advancing_head(&mut self) -> bool { + let block_index = block::start_index(self.index); + + loop { + let next_block = { + let block = unsafe { self.head.as_ref() }; + + if block.is_at_index(block_index) { + return true; + } + + block.load_next(Acquire) + }; + + let next_block = match next_block { + Some(next_block) => next_block, + None => { + return false; + } + }; + + self.head = next_block; + + thread::yield_now(); + } + } + + fn reclaim_blocks(&mut self, tx: &Tx<T>) { + while self.free_head != self.head { + unsafe { + // Get a handle to the block that will be freed and update + // `free_head` to point to the next block. + let block = self.free_head; + + let observed_tail_position = block.as_ref().observed_tail_position(); + + let required_index = match observed_tail_position { + Some(i) => i, + None => return, + }; + + if required_index > self.index { + return; + } + + // We may read the next pointer with `Relaxed` ordering as it is + // guaranteed that the `reclaim_blocks` routine trails the `recv` + // routine. Any memory accessed by `reclaim_blocks` has already + // been acquired by `recv`. + let next_block = block.as_ref().load_next(Relaxed); + + // Update the free list head + self.free_head = next_block.unwrap(); + + // Push the emptied block onto the back of the queue, making it + // available to senders. + tx.reclaim_block(block); + } + + thread::yield_now(); + } + } + + /// Effectively `Drop` all the blocks. Should only be called once, when + /// the list is dropping. + pub(super) unsafe fn free_blocks(&mut self) { + debug_assert_ne!(self.free_head, NonNull::dangling()); + + let mut cur = Some(self.free_head); + + #[cfg(debug_assertions)] + { + // to trigger the debug assert above so as to catch that we + // don't call `free_blocks` more than once. + self.free_head = NonNull::dangling(); + self.head = NonNull::dangling(); + } + + while let Some(block) = cur { + cur = block.as_ref().load_next(Relaxed); + drop(Box::from_raw(block.as_ptr())); + } + } +} + +impl<T> fmt::Debug for Rx<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Rx") + .field("head", &self.head) + .field("index", &self.index) + .field("free_head", &self.free_head) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/mpsc/mod.rs b/third_party/rust/tokio/src/sync/mpsc/mod.rs new file mode 100644 index 0000000000..b1513a9da5 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/mod.rs @@ -0,0 +1,115 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A multi-producer, single-consumer queue for sending values between +//! asynchronous tasks. +//! +//! This module provides two variants of the channel: bounded and unbounded. The +//! bounded variant has a limit on the number of messages that the channel can +//! store, and if this limit is reached, trying to send another message will +//! wait until a message is received from the channel. An unbounded channel has +//! an infinite capacity, so the `send` method will always complete immediately. +//! This makes the [`UnboundedSender`] usable from both synchronous and +//! asynchronous code. +//! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and +//! [`UnboundedReceiver`] for the unbounded channel. If there is no message to read, +//! the current task will be notified when a new value is sent. [`Sender`] and +//! [`UnboundedSender`] allow sending values into the channel. If the bounded +//! channel is at capacity, the send is rejected and the task will be notified +//! when additional capacity is available. In other words, the channel provides +//! backpressure. +//! +//! +//! # Disconnection +//! +//! When all [`Sender`] handles have been dropped, it is no longer +//! possible to send values into the channel. This is considered the termination +//! event of the stream. As such, `Receiver::poll` returns `Ok(Ready(None))`. +//! +//! If the [`Receiver`] handle is dropped, then messages can no longer +//! be read out of the channel. In this case, all further attempts to send will +//! result in an error. +//! +//! # Clean Shutdown +//! +//! When the [`Receiver`] is dropped, it is possible for unprocessed messages to +//! remain in the channel. Instead, it is usually desirable to perform a "clean" +//! shutdown. To do this, the receiver first calls `close`, which will prevent +//! any further messages to be sent into the channel. Then, the receiver +//! consumes the channel to completion, at which point the receiver can be +//! dropped. +//! +//! # Communicating between sync and async code +//! +//! When you want to communicate between synchronous and asynchronous code, there +//! are two situations to consider: +//! +//! **Bounded channel**: If you need a bounded channel, you should use a bounded +//! Tokio `mpsc` channel for both directions of communication. Instead of calling +//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in +//! synchronous code you will need to use the [`blocking_send`][blocking-send] or +//! [`blocking_recv`][blocking-recv] methods. +//! +//! **Unbounded channel**: You should use the kind of channel that matches where +//! the receiver is. So for sending a message _from async to sync_, you should +//! use [the standard library unbounded channel][std-unbounded] or +//! [crossbeam][crossbeam-unbounded]. Similarly, for sending a message _from sync +//! to async_, you should use an unbounded Tokio `mpsc` channel. +//! +//! Please be aware that the above remarks were written with the `mpsc` channel +//! in mind, but they can also be generalized to other kinds of channels. In +//! general, any channel method that isn't marked async can be called anywhere, +//! including outside of the runtime. For example, sending a message on a +//! oneshot channel from outside the runtime is perfectly fine. +//! +//! # Multiple runtimes +//! +//! The mpsc channel does not care about which runtime you use it in, and can be +//! used to send messages from one runtime to another. It can also be used in +//! non-Tokio runtimes. +//! +//! There is one exception to the above: the [`send_timeout`] must be used from +//! within a Tokio runtime, however it is still not tied to one specific Tokio +//! runtime, and the sender may be moved from one Tokio runtime to another. +//! +//! [`Sender`]: crate::sync::mpsc::Sender +//! [`Receiver`]: crate::sync::mpsc::Receiver +//! [bounded-send]: crate::sync::mpsc::Sender::send() +//! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send() +//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv() +//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver +//! [`Handle::block_on`]: crate::runtime::Handle::block_on() +//! [std-unbounded]: std::sync::mpsc::channel +//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html +//! [`send_timeout`]: crate::sync::mpsc::Sender::send_timeout + +pub(super) mod block; + +mod bounded; +pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender}; + +mod chan; + +pub(super) mod list; + +mod unbounded; +pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender}; + +pub mod error; + +/// The number of values a block can contain. +/// +/// This value must be a power of 2. It also must be smaller than the number of +/// bits in `usize`. +#[cfg(all(target_pointer_width = "64", not(loom)))] +const BLOCK_CAP: usize = 32; + +#[cfg(all(not(target_pointer_width = "64"), not(loom)))] +const BLOCK_CAP: usize = 16; + +#[cfg(loom)] +const BLOCK_CAP: usize = 2; diff --git a/third_party/rust/tokio/src/sync/mpsc/unbounded.rs b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs new file mode 100644 index 0000000000..b133f9f35e --- /dev/null +++ b/third_party/rust/tokio/src/sync/mpsc/unbounded.rs @@ -0,0 +1,373 @@ +use crate::loom::sync::atomic::AtomicUsize; +use crate::sync::mpsc::chan; +use crate::sync::mpsc::error::{SendError, TryRecvError}; + +use std::fmt; +use std::task::{Context, Poll}; + +/// Send values to the associated `UnboundedReceiver`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +pub struct UnboundedSender<T> { + chan: chan::Tx<T, Semaphore>, +} + +impl<T> Clone for UnboundedSender<T> { + fn clone(&self) -> Self { + UnboundedSender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for UnboundedSender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedSender") + .field("chan", &self.chan) + .finish() + } +} + +/// Receive values from the associated `UnboundedSender`. +/// +/// Instances are created by the +/// [`unbounded_channel`](unbounded_channel) function. +/// +/// This receiver can be turned into a `Stream` using [`UnboundedReceiverStream`]. +/// +/// [`UnboundedReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.UnboundedReceiverStream.html +pub struct UnboundedReceiver<T> { + /// The channel receiver + chan: chan::Rx<T, Semaphore>, +} + +impl<T> fmt::Debug for UnboundedReceiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("UnboundedReceiver") + .field("chan", &self.chan) + .finish() + } +} + +/// Creates an unbounded mpsc channel for communicating between asynchronous +/// tasks without backpressure. +/// +/// A `send` on this channel will always succeed as long as the receive half has +/// not been closed. If the receiver falls behind, messages will be arbitrarily +/// buffered. +/// +/// **Note** that the amount of available system memory is an implicit bound to +/// the channel. Using an `unbounded` channel has the ability of causing the +/// process to run out of memory. In this case, the process will be aborted. +pub fn unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) { + let (tx, rx) = chan::channel(AtomicUsize::new(0)); + + let tx = UnboundedSender::new(tx); + let rx = UnboundedReceiver::new(rx); + + (tx, rx) +} + +/// No capacity +type Semaphore = AtomicUsize; + +impl<T> UnboundedReceiver<T> { + pub(crate) fn new(chan: chan::Rx<T, Semaphore>) -> UnboundedReceiver<T> { + UnboundedReceiver { chan } + } + + /// Receives the next value for this receiver. + /// + /// `None` is returned when all `Sender` halves have dropped, indicating + /// that no further values can be sent on the channel. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tokio::spawn(async move { + /// tx.send("hello").unwrap(); + /// }); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(None, rx.recv().await); + /// } + /// ``` + /// + /// Values are buffered: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// tx.send("world").unwrap(); + /// + /// assert_eq!(Some("hello"), rx.recv().await); + /// assert_eq!(Some("world"), rx.recv().await); + /// } + /// ``` + pub async fn recv(&mut self) -> Option<T> { + use crate::future::poll_fn; + + poll_fn(|cx| self.poll_recv(cx)).await + } + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel::<u8>(); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + + /// Closes the receiving half of a channel, without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + pub fn close(&mut self) { + self.chan.close(); + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed, or if a spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } +} + +impl<T> UnboundedSender<T> { + pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> UnboundedSender<T> { + UnboundedSender { chan } + } + + /// Attempts to send a message on this `UnboundedSender` without blocking. + /// + /// This method is not marked async because sending a message to an unbounded channel + /// never requires any form of waiting. Because of this, the `send` method can be + /// used in both synchronous and asynchronous code without problems. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`UnboundedReceiver`] having been dropped, this + /// function returns an error. The error includes the value passed to `send`. + /// + /// [`close`]: UnboundedReceiver::close + /// [`UnboundedReceiver`]: UnboundedReceiver + pub fn send(&self, message: T) -> Result<(), SendError<T>> { + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); + Ok(()) + } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::unbounded_channel::<()>(); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + + /// Checks if the channel has been closed. This happens when the + /// [`UnboundedReceiver`] is dropped, or when the + /// [`UnboundedReceiver::close`] method is called. + /// + /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver + /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Returns `true` if senders belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// let tx2 = tx.clone(); + /// assert!(tx.same_channel(&tx2)); + /// + /// let (tx3, rx3) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx3.same_channel(&tx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + self.chan.same_channel(&other.chan) + } +} diff --git a/third_party/rust/tokio/src/sync/mutex.rs b/third_party/rust/tokio/src/sync/mutex.rs new file mode 100644 index 0000000000..b8d5ba74e7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/mutex.rs @@ -0,0 +1,967 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + +use crate::sync::batch_semaphore as semaphore; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +use std::cell::UnsafeCell; +use std::error::Error; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::{fmt, marker, mem}; + +/// An asynchronous `Mutex`-like type. +/// +/// This type acts similarly to [`std::sync::Mutex`], with two major +/// differences: [`lock`] is an async method so does not block, and the lock +/// guard is designed to be held across `.await` points. +/// +/// # Which kind of mutex should you use? +/// +/// Contrary to popular belief, it is ok and often preferred to use the ordinary +/// [`Mutex`][std] from the standard library in asynchronous code. +/// +/// The feature that the async mutex offers over the blocking mutex is the +/// ability to keep it locked across an `.await` point. This makes the async +/// mutex more expensive than the blocking mutex, so the blocking mutex should +/// be preferred in the cases where it can be used. The primary use case for the +/// async mutex is to provide shared mutable access to IO resources such as a +/// database connection. If the value behind the mutex is just data, it's +/// usually appropriate to use a blocking mutex such as the one in the standard +/// library or [`parking_lot`]. +/// +/// Note that, although the compiler will not prevent the std `Mutex` from holding +/// its guard across `.await` points in situations where the task is not movable +/// between threads, this virtually never leads to correct concurrent code in +/// practice as it can easily lead to deadlocks. +/// +/// A common pattern is to wrap the `Arc<Mutex<...>>` in a struct that provides +/// non-async methods for performing operations on the data within, and only +/// lock the mutex inside these methods. The [mini-redis] example provides an +/// illustration of this pattern. +/// +/// Additionally, when you _do_ want shared access to an IO resource, it is +/// often better to spawn a task to manage the IO resource, and to use message +/// passing to communicate with that task. +/// +/// [std]: std::sync::Mutex +/// [`parking_lot`]: https://docs.rs/parking_lot +/// [mini-redis]: https://github.com/tokio-rs/mini-redis/blob/master/src/db.rs +/// +/// # Examples: +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let data1 = Arc::new(Mutex::new(0)); +/// let data2 = Arc::clone(&data1); +/// +/// tokio::spawn(async move { +/// let mut lock = data2.lock().await; +/// *lock += 1; +/// }); +/// +/// let mut lock = data1.lock().await; +/// *lock += 1; +/// } +/// ``` +/// +/// +/// ```rust,no_run +/// use tokio::sync::Mutex; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let count = Arc::new(Mutex::new(0)); +/// +/// for i in 0..5 { +/// let my_count = Arc::clone(&count); +/// tokio::spawn(async move { +/// for j in 0..10 { +/// let mut lock = my_count.lock().await; +/// *lock += 1; +/// println!("{} {} {}", i, j, lock); +/// } +/// }); +/// } +/// +/// loop { +/// if *count.lock().await >= 50 { +/// break; +/// } +/// } +/// println!("Count hit 50."); +/// } +/// ``` +/// There are a few things of note here to pay attention to in this example. +/// 1. The mutex is wrapped in an [`Arc`] to allow it to be shared across +/// threads. +/// 2. Each spawned task obtains a lock and releases it on every iteration. +/// 3. Mutation of the data protected by the Mutex is done by de-referencing +/// the obtained lock as seen on lines 12 and 19. +/// +/// Tokio's Mutex works in a simple FIFO (first in, first out) style where all +/// calls to [`lock`] complete in the order they were performed. In that way the +/// Mutex is "fair" and predictable in how it distributes the locks to inner +/// data. Locks are released and reacquired after every iteration, so basically, +/// each thread goes to the back of the line after it increments the value once. +/// Note that there's some unpredictability to the timing between when the +/// threads are started, but once they are going they alternate predictably. +/// Finally, since there is only a single valid lock at any given time, there is +/// no possibility of a race condition when mutating the inner value. +/// +/// Note that in contrast to [`std::sync::Mutex`], this implementation does not +/// poison the mutex when a thread holding the [`MutexGuard`] panics. In such a +/// case, the mutex will be unlocked. If the panic is caught, this might leave +/// the data protected by the mutex in an inconsistent state. +/// +/// [`Mutex`]: struct@Mutex +/// [`MutexGuard`]: struct@MutexGuard +/// [`Arc`]: struct@std::sync::Arc +/// [`std::sync::Mutex`]: struct@std::sync::Mutex +/// [`Send`]: trait@std::marker::Send +/// [`lock`]: method@Mutex::lock +pub struct Mutex<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + s: semaphore::Semaphore, + c: UnsafeCell<T>, +} + +/// A handle to a held `Mutex`. The guard can be held across any `.await` point +/// as it is [`Send`]. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally borrows the `Mutex`, so the mutex will not be +/// dropped while a guard exists. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +pub struct MutexGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: &'a Mutex<T>, +} + +/// An owned handle to a held `Mutex`. +/// +/// This guard is only available from a `Mutex` that is wrapped in an [`Arc`]. It +/// is identical to `MutexGuard`, except that rather than borrowing the `Mutex`, +/// it clones the `Arc`, incrementing the reference count. This means that +/// unlike `MutexGuard`, it will have the `'static` lifetime. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally keeps a reference-counted pointer to the original +/// `Mutex`, so even if the lock goes away, the guard remains valid. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +/// +/// [`Arc`]: std::sync::Arc +pub struct OwnedMutexGuard<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + lock: Arc<Mutex<T>>, +} + +/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`]. +/// +/// This can be used to hold a subfield of the protected data. +/// +/// [`MutexGuard::map`]: method@MutexGuard::map +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct MappedMutexGuard<'a, T: ?Sized> { + s: &'a semaphore::Semaphore, + data: *mut T, + // Needed to tell the borrow checker that we are holding a `&mut T` + marker: marker::PhantomData<&'a mut T>, +} + +// As long as T: Send, it's fine to send and share Mutex<T> between threads. +// If T was not Send, sending and sharing a Mutex<T> would be bad, since you can +// access T through Mutex<T>. +unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for OwnedMutexGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<'a, T> Sync for MappedMutexGuard<'a, T> where T: ?Sized + Sync + 'a {} +unsafe impl<'a, T> Send for MappedMutexGuard<'a, T> where T: ?Sized + Send + 'a {} + +/// Error returned from the [`Mutex::try_lock`], [`RwLock::try_read`] and +/// [`RwLock::try_write`] functions. +/// +/// `Mutex::try_lock` operation will only fail if the mutex is already locked. +/// +/// `RwLock::try_read` operation will only fail if the lock is currently held +/// by an exclusive writer. +/// +/// `RwLock::try_write` operation will if lock is held by any reader or by an +/// exclusive writer. +/// +/// [`Mutex::try_lock`]: Mutex::try_lock +/// [`RwLock::try_read`]: fn@super::RwLock::try_read +/// [`RwLock::try_write`]: fn@super::RwLock::try_write +#[derive(Debug)] +pub struct TryLockError(pub(super) ()); + +impl fmt::Display for TryLockError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "operation would block") + } +} + +impl Error for TryLockError {} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + fn check_static<T: 'static>() {} + fn check_static_val<T: 'static>(_t: T) {} + + check_send::<MutexGuard<'_, u32>>(); + check_send::<OwnedMutexGuard<u32>>(); + check_unpin::<Mutex<u32>>(); + check_send_sync::<Mutex<u32>>(); + check_static::<OwnedMutexGuard<u32>>(); + + let mutex = Mutex::new(1); + check_send_sync_val(mutex.lock()); + let arc_mutex = Arc::new(Mutex::new(1)); + check_send_sync_val(arc_mutex.clone().lock_owned()); + check_static_val(arc_mutex.lock_owned()); +} + +impl<T: ?Sized> Mutex<T> { + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// let lock = Mutex::new(5); + /// ``` + #[track_caller] + pub fn new(t: T) -> Self + where + T: Sized, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Mutex", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + semaphore::Semaphore::new(1) + }); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = semaphore::Semaphore::new(1); + + Self { + c: UnsafeCell::new(t), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// static LOCK: Mutex<i32> = Mutex::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(t: T) -> Self + where + T: Sized, + { + Self { + c: UnsafeCell::new(t), + s: semaphore::Semaphore::const_new(1), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock` makes you lose your place in + /// the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let mut n = mutex.lock().await; + /// *n = 2; + /// } + /// ``` + pub async fn lock(&self) -> MutexGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.acquire(), + self.resource_span.clone(), + "Mutex::lock", + "poll", + false, + ) + .await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + self.acquire().await; + + MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `Mutex`. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// let lock = mutex.lock().await; + /// + /// let mutex1 = Arc::clone(&mutex); + /// let blocking_task = tokio::task::spawn_blocking(move || { + /// // This shall block until the `lock` is released. + /// let mut n = mutex1.blocking_lock(); + /// *n = 2; + /// }); + /// + /// assert_eq!(*lock, 1); + /// // Release the lock. + /// drop(lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let n = mutex.try_lock().unwrap(); + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + crate::future::block_on(self.lock()) + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, this returns an + /// [`OwnedMutexGuard`]. + /// + /// This method is identical to [`Mutex::lock`], except that the returned + /// guard references the `Mutex` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `Mutex` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `Mutex` alive by holding an `Arc`. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mut n = mutex.clone().lock_owned().await; + /// *n = 2; + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + pub async fn lock_owned(self: Arc<Self>) -> OwnedMutexGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.acquire(), + self.resource_span.clone(), + "Mutex::lock_owned", + "poll", + false, + ) + .await; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + self.acquire().await; + + OwnedMutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + async fn acquire(&self) { + self.s.acquire(1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and + // we own it exclusively, which means that this can never happen. + unreachable!() + }); + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the + /// lock is currently held somewhere else. + /// + /// [`TryLockError`]: TryLockError + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.try_lock()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + /// ``` + pub fn try_lock(&self) -> Result<MutexGuard<'_, T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + Ok(MutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + Err(_) => Err(TryLockError(())), + } + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// fn main() { + /// let mut mutex = Mutex::new(1); + /// + /// let n = mutex.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock + /// is currently held somewhere else. + /// + /// This method is identical to [`Mutex::try_lock`], except that the + /// returned guard references the `Mutex` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `Mutex` must be wrapped in an `Arc` to call + /// this method, and the guard will live for the `'static` lifetime, as it + /// keeps the `Mutex` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// [`Arc`]: std::sync::Arc + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let n = mutex.clone().try_lock_owned()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + pub fn try_lock_owned(self: Arc<Self>) -> Result<OwnedMutexGuard<T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = true, + ); + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedMutexGuard { + lock: self, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + Err(_) => Err(TryLockError(())), + } + } + + /// Consumes the mutex, returning the underlying data. + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Mutex::new(1); + /// + /// let n = mutex.into_inner(); + /// assert_eq!(n, 1); + /// } + /// ``` + pub fn into_inner(self) -> T + where + T: Sized, + { + self.c.into_inner() + } +} + +impl<T> From<T> for Mutex<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T> Default for Mutex<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} + +impl<T: ?Sized> std::fmt::Debug for Mutex<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("Mutex"); + match self.try_lock() { + Ok(inner) => d.field("data", &&*inner), + Err(_) => d.field("data", &format_args!("<locked>")), + }; + d.finish() + } +} + +// === impl MutexGuard === + +impl<'a, T: ?Sized> MutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::map(...)`. A method + /// would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::map(foo.lock().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = &this.lock.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::try_map(foo.lock().await, |f| Some(&mut f.0)) + /// .expect("should not fail"); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = &this.lock.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } + + /// Returns a reference to the original `Mutex`. + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// async fn unlock_and_relock<'l>(guard: MutexGuard<'l, u32>) -> MutexGuard<'l, u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex = MutexGuard::mutex(&guard); + /// drop(guard); + /// let guard = mutex.lock().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Mutex::new(0u32); + /// # let guard = mutex.lock().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &'a Mutex<T> { + this.lock + } +} + +impl<T: ?Sized> Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); + self.lock.s.release(1); + } +} + +impl<T: ?Sized> Deref for MutexGuard<'_, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl<T: ?Sized> DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +// === impl OwnedMutexGuard === + +impl<T: ?Sized> OwnedMutexGuard<T> { + /// Returns a reference to the original `Arc<Mutex>`. + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// + /// async fn unlock_and_relock(guard: OwnedMutexGuard<u32>) -> OwnedMutexGuard<u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex: Arc<Mutex<u32>> = OwnedMutexGuard::mutex(&guard).clone(); + /// drop(guard); + /// let guard = mutex.lock_owned().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Arc::new(Mutex::new(0u32)); + /// # let guard = mutex.lock_owned().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &Arc<Mutex<T>> { + &this.lock + } +} + +impl<T: ?Sized> Drop for OwnedMutexGuard<T> { + fn drop(&mut self) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + locked = false, + ); + }); + self.lock.s.release(1) + } +} + +impl<T: ?Sized> Deref for OwnedMutexGuard<T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl<T: ?Sized> DerefMut for OwnedMutexGuard<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl<T: ?Sized + fmt::Debug> fmt::Debug for OwnedMutexGuard<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized + fmt::Display> fmt::Display for OwnedMutexGuard<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +// === impl MappedMutexGuard === + +impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MappedMutexGuard::map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MappedMutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T: ?Sized> Drop for MappedMutexGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1) + } +} + +impl<'a, T: ?Sized> Deref for MappedMutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.data } + } +} + +impl<'a, T: ?Sized> DerefMut for MappedMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized + fmt::Debug> fmt::Debug for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/third_party/rust/tokio/src/sync/notify.rs b/third_party/rust/tokio/src/sync/notify.rs new file mode 100644 index 0000000000..83d0de4fbe --- /dev/null +++ b/third_party/rust/tokio/src/sync/notify.rs @@ -0,0 +1,740 @@ +// Allow `unreachable_pub` warnings when sync is not enabled +// due to the usage of `Notify` within the `rt` feature set. +// When this module is compiled with `sync` enabled we will warn on +// this lint. When `rt` is enabled we use `pub(crate)` which +// triggers this warning but it is safe to ignore in this case. +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::util::linked_list::{self, LinkedList}; +use crate::util::WakeList; + +use std::cell::UnsafeCell; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::SeqCst; +use std::task::{Context, Poll, Waker}; + +type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; + +/// Notifies a single task to wake up. +/// +/// `Notify` provides a basic mechanism to notify a single task of an event. +/// `Notify` itself does not carry any data. Instead, it is to be used to signal +/// another task to perform an operation. +/// +/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. +/// [`notified().await`] waits for a permit to become available, and [`notify_one()`] +/// sets a permit **if there currently are no available permits**. +/// +/// The synchronization details of `Notify` are similar to +/// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] +/// value contains a single permit. [`notified().await`] waits for the permit to +/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the +/// permit, waking a pending task if there is one. +/// +/// If `notify_one()` is called **before** `notified().await`, then the next call to +/// `notified().await` will complete immediately, consuming the permit. Any +/// subsequent calls to `notified().await` will wait for a new permit. +/// +/// If `notify_one()` is called **multiple** times before `notified().await`, only a +/// **single** permit is stored. The next call to `notified().await` will +/// complete immediately, but the one after will wait for a new permit. +/// +/// # Examples +/// +/// Basic usage. +/// +/// ``` +/// use tokio::sync::Notify; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() { +/// let notify = Arc::new(Notify::new()); +/// let notify2 = notify.clone(); +/// +/// let handle = tokio::spawn(async move { +/// notify2.notified().await; +/// println!("received notification"); +/// }); +/// +/// println!("sending notification"); +/// notify.notify_one(); +/// +/// // Wait for task to receive notification. +/// handle.await.unwrap(); +/// } +/// ``` +/// +/// Unbound mpsc channel. +/// +/// ``` +/// use tokio::sync::Notify; +/// +/// use std::collections::VecDeque; +/// use std::sync::Mutex; +/// +/// struct Channel<T> { +/// values: Mutex<VecDeque<T>>, +/// notify: Notify, +/// } +/// +/// impl<T> Channel<T> { +/// pub fn send(&self, value: T) { +/// self.values.lock().unwrap() +/// .push_back(value); +/// +/// // Notify the consumer a value is available +/// self.notify.notify_one(); +/// } +/// +/// pub async fn recv(&self) -> T { +/// loop { +/// // Drain values +/// if let Some(value) = self.values.lock().unwrap().pop_front() { +/// return value; +/// } +/// +/// // Wait for values to be available +/// self.notify.notified().await; +/// } +/// } +/// } +/// ``` +/// +/// [park]: std::thread::park +/// [unpark]: std::thread::Thread::unpark +/// [`notified().await`]: Notify::notified() +/// [`notify_one()`]: Notify::notify_one() +/// [`Semaphore`]: crate::sync::Semaphore +#[derive(Debug)] +pub struct Notify { + // This uses 2 bits to store one of `EMPTY`, + // `WAITING` or `NOTIFIED`. The rest of the bits + // are used to store the number of times `notify_waiters` + // was called. + state: AtomicUsize, + waiters: Mutex<WaitList>, +} + +#[derive(Debug, Clone, Copy)] +enum NotificationType { + // Notification triggered by calling `notify_waiters` + AllWaiters, + // Notification triggered by calling `notify_one` + OneWaiter, +} + +#[derive(Debug)] +#[repr(C)] // required by `linked_list::Link` impl +struct Waiter { + /// Intrusive linked-list pointers. + pointers: linked_list::Pointers<Waiter>, + + /// Waiting task's waker. + waker: Option<Waker>, + + /// `true` if the notification has been assigned to this waiter. + notified: Option<NotificationType>, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +/// Future returned from [`Notify::notified()`] +#[derive(Debug)] +pub struct Notified<'a> { + /// The `Notify` being received on. + notify: &'a Notify, + + /// The current state of the receiving process. + state: State, + + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, +} + +unsafe impl<'a> Send for Notified<'a> {} +unsafe impl<'a> Sync for Notified<'a> {} + +#[derive(Debug)] +enum State { + Init(usize), + Waiting, + Done, +} + +const NOTIFY_WAITERS_SHIFT: usize = 2; +const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; +const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; + +/// Initial "idle" state. +const EMPTY: usize = 0; + +/// One or more threads are currently waiting to be notified. +const WAITING: usize = 1; + +/// Pending notification. +const NOTIFIED: usize = 2; + +fn set_state(data: usize, state: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK) +} + +fn get_state(data: usize) -> usize { + data & STATE_MASK +} + +fn get_num_notify_waiters_calls(data: usize) -> usize { + (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT +} + +fn inc_num_notify_waiters_calls(data: usize) -> usize { + data + (1 << NOTIFY_WAITERS_SHIFT) +} + +fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) { + data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst); +} + +impl Notify { + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// let notify = Notify::new(); + /// ``` + pub fn new() -> Notify { + Notify { + state: AtomicUsize::new(0), + waiters: Mutex::new(LinkedList::new()), + } + } + + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// static NOTIFY: Notify = Notify::const_new(); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Notify { + Notify { + state: AtomicUsize::new(0), + waiters: Mutex::const_new(LinkedList::new()), + } + } + + /// Wait for a notification. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn notified(&self); + /// ``` + /// + /// Each `Notify` value holds a single permit. If a permit is available from + /// an earlier call to [`notify_one()`], then `notified().await` will complete + /// immediately, consuming that permit. Otherwise, `notified().await` waits + /// for a permit to be made available by the next call to `notify_one()`. + /// + /// [`notify_one()`]: Notify::notify_one + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute notifications in the order + /// they were requested. Cancelling a call to `notified` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify_one(); + /// } + /// ``` + pub fn notified(&self) -> Notified<'_> { + // we load the number of times notify_waiters + // was called and store that in our initial state + let state = self.state.load(SeqCst); + Notified { + notify: self, + state: State::Init(state >> NOTIFY_WAITERS_SHIFT), + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + notified: None, + _p: PhantomPinned, + }), + } + } + + /// Notifies a waiting task. + /// + /// If a task is currently waiting, that task is notified. Otherwise, a + /// permit is stored in this `Notify` value and the **next** call to + /// [`notified().await`] will complete immediately consuming the permit made + /// available by this call to `notify_one()`. + /// + /// At most one permit may be stored by `Notify`. Many sequential calls to + /// `notify_one` will result in a single permit being stored. The next call to + /// `notified().await` will complete immediately, but the one after that + /// will wait. + /// + /// [`notified().await`]: Notify::notified() + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// tokio::spawn(async move { + /// notify2.notified().await; + /// println!("received notification"); + /// }); + /// + /// println!("sending notification"); + /// notify.notify_one(); + /// } + /// ``` + // Alias for old name in 0.x + #[cfg_attr(docsrs, doc(alias = "notify"))] + pub fn notify_one(&self) { + // Load the current state + let mut curr = self.state.load(SeqCst); + + // If the state is `EMPTY`, transition to `NOTIFIED` and return. + while let EMPTY | NOTIFIED = get_state(curr) { + // The compare-exchange from `NOTIFIED` -> `NOTIFIED` is intended. A + // happens-before synchronization must happen between this atomic + // operation and a task calling `notified().await`. + let new = set_state(curr, NOTIFIED); + let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst); + + match res { + // No waiters, no further work to do + Ok(_) => return, + Err(actual) => { + curr = actual; + } + } + } + + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + curr = self.state.load(SeqCst); + + if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) { + drop(waiters); + waker.wake(); + } + } + + /// Notifies all waiting tasks. + /// + /// If a task is currently waiting, that task is notified. Unlike with + /// `notify_one()`, no permit is stored to be used by the next call to + /// `notified().await`. The purpose of this method is to notify all + /// already registered waiters. Registering for notification is done by + /// acquiring an instance of the `Notified` future via calling `notified()`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// let notify2 = notify.clone(); + /// + /// let notified1 = notify.notified(); + /// let notified2 = notify.notified(); + /// + /// let handle = tokio::spawn(async move { + /// println!("sending notifications"); + /// notify2.notify_waiters(); + /// }); + /// + /// notified1.await; + /// notified2.await; + /// println!("received notifications"); + /// } + /// ``` + pub fn notify_waiters(&self) { + let mut wakers = WakeList::new(); + + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + let curr = self.state.load(SeqCst); + + if let EMPTY | NOTIFIED = get_state(curr) { + // There are no waiting tasks. All we need to do is increment the + // number of times this method was called. + atomic_inc_num_notify_waiters_calls(&self.state); + return; + } + + // At this point, it is guaranteed that the state will not + // concurrently change, as holding the lock is required to + // transition **out** of `WAITING`. + 'outer: loop { + while wakers.can_push() { + match waiters.pop_back() { + Some(mut waiter) => { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::AllWaiters); + + if let Some(waker) = waiter.waker.take() { + wakers.push(waker); + } + } + None => { + break 'outer; + } + } + } + + drop(waiters); + + wakers.wake_all(); + + // Acquire the lock again. + waiters = self.waiters.lock(); + } + + // All waiters will be notified, the state must be transitioned to + // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be + // held, a `store` is sufficient. + let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY); + self.state.store(new, SeqCst); + + // Release the lock before notifying + drop(waiters); + + wakers.wake_all(); + } +} + +impl Default for Notify { + fn default() -> Notify { + Notify::new() + } +} + +fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> { + loop { + match get_state(curr) { + EMPTY | NOTIFIED => { + let res = state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst); + + match res { + Ok(_) => return None, + Err(actual) => { + let actual_state = get_state(actual); + assert!(actual_state == EMPTY || actual_state == NOTIFIED); + state.store(set_state(actual, NOTIFIED), SeqCst); + return None; + } + } + } + WAITING => { + // At this point, it is guaranteed that the state will not + // concurrently change as holding the lock is required to + // transition **out** of `WAITING`. + // + // Get a pending waiter + let mut waiter = waiters.pop_back().unwrap(); + + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::OneWaiter); + let waker = waiter.waker.take(); + + if waiters.is_empty() { + // As this the **final** waiter in the list, the state + // must be transitioned to `EMPTY`. As transitioning + // **from** `WAITING` requires the lock to be held, a + // `store` is sufficient. + state.store(set_state(curr, EMPTY), SeqCst); + } + + return waker; + } + _ => unreachable!(), + } + } +} + +// ===== impl Notified ===== + +impl Notified<'_> { + /// A custom `project` implementation is used in place of `pin-project-lite` + /// as a custom drop implementation is needed. + fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) { + unsafe { + // Safety: both `notify` and `state` are `Unpin`. + + is_unpin::<&Notify>(); + is_unpin::<AtomicUsize>(); + + let me = self.get_unchecked_mut(); + (me.notify, &mut me.state, &me.waiter) + } + } +} + +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + use State::*; + + let (notify, state, waiter) = self.project(); + + loop { + match *state { + Init(initial_notify_waiters_calls) => { + let curr = notify.state.load(SeqCst); + + // Optimistically try acquiring a pending notification + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); + + if res.is_ok() { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = cx.waker().clone(); + + // Acquire the lock and attempt to transition to the waiting + // state. + let mut waiters = notify.waiters.lock(); + + // Reload the state with the lock held + let mut curr = notify.state.load(SeqCst); + + // if notify_waiters has been called after the future + // was created, then we are done + if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls { + *state = Done; + return Poll::Ready(()); + } + + // Transition the state to WAITING. + loop { + match get_state(curr) { + EMPTY => { + // Transition to WAITING + let res = notify.state.compare_exchange( + set_state(curr, EMPTY), + set_state(curr, WAITING), + SeqCst, + SeqCst, + ); + + if let Err(actual) = res { + assert_eq!(get_state(actual), NOTIFIED); + curr = actual; + } else { + break; + } + } + WAITING => break, + NOTIFIED => { + // Try consuming the notification + let res = notify.state.compare_exchange( + set_state(curr, NOTIFIED), + set_state(curr, EMPTY), + SeqCst, + SeqCst, + ); + + match res { + Ok(_) => { + // Acquired the notification + *state = Done; + return Poll::Ready(()); + } + Err(actual) => { + assert_eq!(get_state(actual), EMPTY); + curr = actual; + } + } + } + _ => unreachable!(), + } + } + + // Safety: called while locked. + unsafe { + (*waiter.get()).waker = Some(waker); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + + *state = Waiting; + + return Poll::Pending; + } + Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = notify.waiters.lock(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.notified.is_some() { + // Our waker has been notified. Reset the fields and + // remove it from the list. + w.waker = None; + w.notified = None; + + *state = Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + Done => { + return Poll::Ready(()); + } + } + } + } +} + +impl Drop for Notified<'_> { + fn drop(&mut self) { + use State::*; + + // Safety: The type only transitions to a "Waiting" state when pinned. + let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() }; + + // This is where we ensure safety. The `Notified` value is being + // dropped, which means we must ensure that the waiter entry is no + // longer stored in the linked list. + if let Waiting = *state { + let mut waiters = notify.waiters.lock(); + let mut notify_state = notify.state.load(SeqCst); + + // remove the entry from the list (if not already removed) + // + // safety: the waiter is only added to `waiters` by virtue of it + // being the only `LinkedList` available to the type. + unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) }; + + if waiters.is_empty() { + if let WAITING = get_state(notify_state) { + notify_state = set_state(notify_state, EMPTY); + notify.state.store(notify_state, SeqCst); + } + } + + // See if the node was notified but not received. In this case, if + // the notification was triggered via `notify_one`, it must be sent + // to the next waiter. + // + // Safety: with the entry removed from the linked list, there can be + // no concurrent access to the entry + if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } { + if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { + drop(waiters); + waker.wake(); + } + } + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + target.cast() + } +} + +fn is_unpin<T: Unpin>() {} diff --git a/third_party/rust/tokio/src/sync/once_cell.rs b/third_party/rust/tokio/src/sync/once_cell.rs new file mode 100644 index 0000000000..d31a40e2c8 --- /dev/null +++ b/third_party/rust/tokio/src/sync/once_cell.rs @@ -0,0 +1,457 @@ +use super::{Semaphore, SemaphorePermit, TryAcquireError}; +use crate::loom::cell::UnsafeCell; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::ops::Drop; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; + +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. +/// +/// A `OnceCell` is typically used for global variables that need to be +/// initialized once on first use, but need no further changes. The `OnceCell` +/// in Tokio allows the initialization procedure to be asynchronous. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// async fn some_computation() -> u32 { +/// 1 + 1 +/// } +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// #[tokio::main] +/// async fn main() { +/// let result = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result, 2); +/// } +/// ``` +/// +/// It is often useful to write a wrapper method for accessing the value. +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// async fn get_global_integer() -> &'static u32 { +/// ONCE.get_or_init(|| async { +/// 1 + 1 +/// }).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = get_global_integer().await; +/// assert_eq!(*result, 2); +/// } +/// ``` +pub struct OnceCell<T> { + value_set: AtomicBool, + value: UnsafeCell<MaybeUninit<T>>, + semaphore: Semaphore, +} + +impl<T> Default for OnceCell<T> { + fn default() -> OnceCell<T> { + OnceCell::new() + } +} + +impl<T: fmt::Debug> fmt::Debug for OnceCell<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OnceCell") + .field("value", &self.get()) + .finish() + } +} + +impl<T: Clone> Clone for OnceCell<T> { + fn clone(&self) -> OnceCell<T> { + OnceCell::new_with(self.get().cloned()) + } +} + +impl<T: PartialEq> PartialEq for OnceCell<T> { + fn eq(&self, other: &OnceCell<T>) -> bool { + self.get() == other.get() + } +} + +impl<T: Eq> Eq for OnceCell<T> {} + +impl<T> Drop for OnceCell<T> { + fn drop(&mut self) { + if self.initialized_mut() { + unsafe { + self.value + .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); + }; + } + } +} + +impl<T> From<T> for OnceCell<T> { + fn from(value: T) -> Self { + let semaphore = Semaphore::new(0); + semaphore.close(); + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore, + } + } +} + +impl<T> OnceCell<T> { + /// Creates a new empty `OnceCell` instance. + pub fn new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::new(1), + } + } + + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. + /// + /// [`OnceCell::new`]: crate::sync::OnceCell::new + pub fn new_with(value: Option<T>) -> Self { + if let Some(v) = value { + OnceCell::from(v) + } else { + OnceCell::new() + } + } + + /// Creates a new empty `OnceCell` instance. + /// + /// Equivalent to `OnceCell::new`, except that it can be used in static + /// variables. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::OnceCell; + /// + /// static ONCE: OnceCell<u32> = OnceCell::const_new(); + /// + /// async fn get_global_integer() -> &'static u32 { + /// ONCE.get_or_init(|| async { + /// 1 + 1 + /// }).await + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let result = get_global_integer().await; + /// assert_eq!(*result, 2); + /// } + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::const_new(1), + } + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. + self.value_set.load(Ordering::Acquire) + } + + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked(&self) -> &T { + &*self.value.with(|ptr| (*ptr).as_ptr()) + } + + // SAFETY: The OnceCell must not be empty. + unsafe fn get_unchecked_mut(&mut self) -> &mut T { + &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) + } + + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. + self.value_set.store(true, Ordering::Release); + self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } + } + + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. + pub fn get(&self) -> Option<&T> { + if self.initialized() { + Some(unsafe { self.get_unchecked() }) + } else { + None + } + } + + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. + /// + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. + pub fn get_mut(&mut self) -> Option<&mut T> { + if self.initialized_mut() { + Some(unsafe { self.get_unchecked_mut() }) + } else { + None + } + } + + /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. + /// + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. + /// + /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError + /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError + pub fn set(&self, value: T) -> Result<(), SetError<T>> { + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); + } + + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. + Err(SetError::AlreadyInitializedError(value)) + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_init<F, Fut>(&self, f: F) -> &T + where + F: FnOnce() -> Fut, + Fut: Future<Output = T>, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { self.get_unchecked() } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. + let value = f().await; + + self.set_value(value, permit) + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } + } + } + } + } + + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. + pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E> + where + F: FnOnce() -> Fut, + Fut: Future<Output = Result<T, E>>, + { + if self.initialized() { + // SAFETY: The OnceCell has been fully initialized. + unsafe { Ok(self.get_unchecked()) } + } else { + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. + match self.semaphore.acquire().await { + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. + let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), + } + } + Err(_) => { + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } + } + } + } + } + + /// Takes the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. + pub fn into_inner(mut self) -> Option<T> { + if self.initialized_mut() { + // Set to uninitialized for the destructor of `OnceCell` to work properly + *self.value_set.get_mut() = false; + Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) + } else { + None + } + } + + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. + pub fn take(&mut self) -> Option<T> { + std::mem::take(self).into_inner() + } +} + +// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell<T>` to send values +// (of type T) across threads. +unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} + +// Access to OnceCell's value is guarded by the semaphore permit +// and atomic operations on `value_set`, so as long as T itself is Send +// it's safe to send it to another thread +unsafe impl<T: Send> Send for OnceCell<T> {} + +/// Errors that can be returned from [`OnceCell::set`]. +/// +/// [`OnceCell::set`]: crate::sync::OnceCell::set +#[derive(Debug, PartialEq)] +pub enum SetError<T> { + /// The cell was already initialized when [`OnceCell::set`] was called. + /// + /// [`OnceCell::set`]: crate::sync::OnceCell::set + AlreadyInitializedError(T), + + /// The cell is currently being initialized. + InitializingError(T), +} + +impl<T> fmt::Display for SetError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SetError::AlreadyInitializedError(_) => write!(f, "AlreadyInitializedError"), + SetError::InitializingError(_) => write!(f, "InitializingError"), + } + } +} + +impl<T: fmt::Debug> Error for SetError<T> {} + +impl<T> SetError<T> { + /// Whether `SetError` is `SetError::AlreadyInitializedError`. + pub fn is_already_init_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => true, + SetError::InitializingError(_) => false, + } + } + + /// Whether `SetError` is `SetError::InitializingError` + pub fn is_initializing_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => false, + SetError::InitializingError(_) => true, + } + } +} diff --git a/third_party/rust/tokio/src/sync/oneshot.rs b/third_party/rust/tokio/src/sync/oneshot.rs new file mode 100644 index 0000000000..2240074e73 --- /dev/null +++ b/third_party/rust/tokio/src/sync/oneshot.rs @@ -0,0 +1,1366 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A one-shot channel is used for sending a single message between +//! asynchronous tasks. The [`channel`] function is used to create a +//! [`Sender`] and [`Receiver`] handle pair that form the channel. +//! +//! The `Sender` handle is used by the producer to send the value. +//! The `Receiver` handle is used by the consumer to receive the value. +//! +//! Each handle can be used on separate tasks. +//! +//! Since the `send` method is not async, it can be used anywhere. This includes +//! sending between two runtimes, and using it from non-async code. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! if let Err(_) = tx.send(3) { +//! println!("the receiver dropped"); +//! } +//! }); +//! +//! match rx.await { +//! Ok(v) => println!("got = {:?}", v), +//! Err(_) => println!("the sender dropped"), +//! } +//! } +//! ``` +//! +//! If the sender is dropped without sending, the receiver will fail with +//! [`error::RecvError`]: +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! #[tokio::main] +//! async fn main() { +//! let (tx, rx) = oneshot::channel::<u32>(); +//! +//! tokio::spawn(async move { +//! drop(tx); +//! }); +//! +//! match rx.await { +//! Ok(_) => panic!("This doesn't happen"), +//! Err(_) => println!("the sender dropped"), +//! } +//! } +//! ``` +//! +//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of +//! the channel. +//! +//! ``` +//! use tokio::sync::oneshot; +//! use tokio::time::{interval, sleep, Duration}; +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread", start_paused = true)] +//! async fn main() { +//! let (send, mut recv) = oneshot::channel(); +//! let mut interval = interval(Duration::from_millis(100)); +//! +//! # let handle = +//! tokio::spawn(async move { +//! sleep(Duration::from_secs(1)).await; +//! send.send("shut down").unwrap(); +//! }); +//! +//! loop { +//! tokio::select! { +//! _ = interval.tick() => println!("Another 100ms"), +//! msg = &mut recv => { +//! println!("Got message: {}", msg.unwrap()); +//! break; +//! } +//! } +//! } +//! # handle.await.unwrap(); +//! } +//! ``` +//! +//! To use a `Sender` from a destructor, put it in an [`Option`] and call +//! [`Option::take`]. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! struct SendOnDrop { +//! sender: Option<oneshot::Sender<&'static str>>, +//! } +//! impl Drop for SendOnDrop { +//! fn drop(&mut self) { +//! if let Some(sender) = self.sender.take() { +//! // Using `let _ =` to ignore send errors. +//! let _ = sender.send("I got dropped!"); +//! } +//! } +//! } +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread")] +//! async fn main() { +//! let (send, recv) = oneshot::channel(); +//! +//! let send_on_drop = SendOnDrop { sender: Some(send) }; +//! drop(send_on_drop); +//! +//! assert_eq!(recv.await, Ok("I got dropped!")); +//! } +//! ``` + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Arc; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; + +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::sync::atomic::Ordering::{self, AcqRel, Acquire}; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll, Waker}; + +/// Sends a value to the associated [`Receiver`]. +/// +/// A pair of both a [`Sender`] and a [`Receiver`] are created by the +/// [`channel`](fn@channel) function. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Sender` from a destructor, put it in an [`Option`] and call +/// [`Option::take`]. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// struct SendOnDrop { +/// sender: Option<oneshot::Sender<&'static str>>, +/// } +/// impl Drop for SendOnDrop { +/// fn drop(&mut self) { +/// if let Some(sender) = self.sender.take() { +/// // Using `let _ =` to ignore send errors. +/// let _ = sender.send("I got dropped!"); +/// } +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread")] +/// async fn main() { +/// let (send, recv) = oneshot::channel(); +/// +/// let send_on_drop = SendOnDrop { sender: Some(send) }; +/// drop(send_on_drop); +/// +/// assert_eq!(recv.await, Ok("I got dropped!")); +/// } +/// ``` +/// +/// [`Option`]: std::option::Option +/// [`Option::take`]: std::option::Option::take +#[derive(Debug)] +pub struct Sender<T> { + inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +/// Receives a value from the associated [`Sender`]. +/// +/// A pair of both a [`Sender`] and a [`Receiver`] are created by the +/// [`channel`](fn@channel) function. +/// +/// This channel has no `recv` method because the receiver itself implements the +/// [`Future`] trait. To receive a value, `.await` the `Receiver` object directly. +/// +/// [`Future`]: trait@std::future::Future +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the +/// channel. +/// +/// ``` +/// use tokio::sync::oneshot; +/// use tokio::time::{interval, sleep, Duration}; +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let (send, mut recv) = oneshot::channel(); +/// let mut interval = interval(Duration::from_millis(100)); +/// +/// # let handle = +/// tokio::spawn(async move { +/// sleep(Duration::from_secs(1)).await; +/// send.send("shut down").unwrap(); +/// }); +/// +/// loop { +/// tokio::select! { +/// _ = interval.tick() => println!("Another 100ms"), +/// msg = &mut recv => { +/// println!("Got message: {}", msg.unwrap()); +/// break; +/// } +/// } +/// } +/// # handle.await.unwrap(); +/// } +/// ``` +#[derive(Debug)] +pub struct Receiver<T> { + inner: Option<Arc<Inner<T>>>, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span: tracing::Span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span: tracing::Span, +} + +pub mod error { + //! Oneshot error types. + + use std::fmt; + + /// Error returned by the `Future` implementation for `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub struct RecvError(pub(super) ()); + + /// Error returned by the `try_recv` function on `Receiver`. + #[derive(Debug, Eq, PartialEq)] + pub enum TryRecvError { + /// The send half of the channel has not yet sent a value. + Empty, + + /// The send half of the channel was dropped without sending a value. + Closed, + } + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} + + // ===== impl TryRecvError ===== + + impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(fmt, "channel empty"), + TryRecvError::Closed => write!(fmt, "channel closed"), + } + } + } + + impl std::error::Error for TryRecvError {} +} + +use self::error::*; + +struct Inner<T> { + /// Manages the state of the inner cell. + state: AtomicUsize, + + /// The value. This is set by `Sender` and read by `Receiver`. The state of + /// the cell is tracked by `state`. + value: UnsafeCell<Option<T>>, + + /// The task to notify when the receiver drops without consuming the value. + /// + /// ## Safety + /// + /// The `TX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. + tx_task: Task, + + /// The task to notify when the value is sent. + /// + /// ## Safety + /// + /// The `RX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. + rx_task: Task, +} + +struct Task(UnsafeCell<MaybeUninit<Waker>>); + +impl Task { + unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool { + self.with_task(|w| w.will_wake(cx.waker())) + } + + unsafe fn with_task<F, R>(&self, f: F) -> R + where + F: FnOnce(&Waker) -> R, + { + self.0.with(|ptr| { + let waker: *const Waker = (&*ptr).as_ptr(); + f(&*waker) + }) + } + + unsafe fn drop_task(&self) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.drop_in_place(); + }); + } + + unsafe fn set_task(&self, cx: &mut Context<'_>) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.write(cx.waker().clone()); + }); + } +} + +#[derive(Clone, Copy)] +struct State(usize); + +/// Creates a new one-shot channel for sending single values across asynchronous +/// tasks. +/// +/// The function returns separate "send" and "receive" handles. The `Sender` +/// handle is used by the producer to send the value. The `Receiver` handle is +/// used by the consumer to receive the value. +/// +/// Each handle can be used on separate tasks. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +#[track_caller] +pub fn channel<T>() -> (Sender<T>, Receiver<T>) { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Sender|Receiver", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = false, + tx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = false, + rx_dropped.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = false, + value_sent.op = "override", + ) + }); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = false, + value_received.op = "override", + ) + }); + + resource_span + }; + + let inner = Arc::new(Inner { + state: AtomicUsize::new(State::new().as_usize()), + value: UnsafeCell::new(None), + tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), + rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), + }); + + let tx = Sender { + inner: Some(inner.clone()), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_span = resource_span + .in_scope(|| tracing::trace_span!("runtime.resource.async_op", source = "Receiver::await")); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let async_op_poll_span = + async_op_span.in_scope(|| tracing::trace_span!("runtime.resource.async_op.poll")); + + let rx = Receiver { + inner: Some(inner), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: resource_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_span, + #[cfg(all(tokio_unstable, feature = "tracing"))] + async_op_poll_span, + }; + + (tx, rx) +} + +impl<T> Sender<T> { + /// Attempts to send a value on this channel, returning it back if it could + /// not be sent. + /// + /// This method consumes `self` as only one value may ever be sent on a oneshot + /// channel. It is not marked async because sending a message to an oneshot + /// channel never requires any form of waiting. Because of this, the `send` + /// method can be used in both synchronous and asynchronous code without + /// problems. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been deallocated. Note that a + /// return value of `Err` means that the data will never be received, but + /// a return value of `Ok` does *not* mean that the data will be received. + /// It is possible for the corresponding receiver to hang up immediately + /// after this function returns `Ok`. + /// + /// # Examples + /// + /// Send a value to another task + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// if let Err(_) = tx.send(3) { + /// println!("the receiver dropped"); + /// } + /// }); + /// + /// match rx.await { + /// Ok(v) => println!("got = {:?}", v), + /// Err(_) => println!("the sender dropped"), + /// } + /// } + /// ``` + pub fn send(mut self, t: T) -> Result<(), T> { + let inner = self.inner.take().unwrap(); + + inner.value.with_mut(|ptr| unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless the + // channel has been marked as "complete" (the `VALUE_SENT` state bit + // is set). + // That bit is only set by the sender later on in this method, and + // calling this method consumes `self`. Therefore, if it was possible to + // call this method, we know that the `VALUE_SENT` bit is unset, and + // the receiver is not currently accessing the `UnsafeCell`. + *ptr = Some(t); + }); + + if !inner.complete() { + unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless + // the channel has been marked as "complete". Calling + // `complete()` will return true if this bit is set, and false + // if it is not set. Thus, if `complete()` returned false, it is + // safe for us to access the value, because we know that the + // receiver will not. + return Err(inner.consume_value().unwrap()); + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_sent = true, + value_sent.op = "override", + ) + }); + + Ok(()) + } + + /// Waits for the associated [`Receiver`] handle to close. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// This function is useful when paired with `select!` to abort a + /// computation when the receiver is no longer interested in the result. + /// + /// # Return + /// + /// Returns a `Future` which must be awaited on. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// Basic usage + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel::<()>(); + /// + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// tx.closed().await; + /// println!("the receiver dropped"); + /// } + /// ``` + /// + /// Paired with select + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::time::{self, Duration}; + /// + /// async fn compute() -> String { + /// // Complex computation returning a `String` + /// # "hello".to_string() + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, rx) = oneshot::channel(); + /// + /// tokio::spawn(async move { + /// tokio::select! { + /// _ = tx.closed() => { + /// // The receiver dropped, no need to do any further work + /// } + /// value = compute() => { + /// // The send can fail if the channel was closed at the exact same + /// // time as when compute() finished, so just ignore the failure. + /// let _ = tx.send(value); + /// } + /// } + /// }); + /// + /// // Wait for up to 10 seconds + /// let _ = time::timeout(Duration::from_secs(10), rx).await; + /// } + /// ``` + pub async fn closed(&mut self) { + use crate::future::poll_fn; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let closed = trace::async_op( + || poll_fn(|cx| self.poll_closed(cx)), + resource_span, + "Sender::closed", + "poll_closed", + false, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let closed = poll_fn(|cx| self.poll_closed(cx)); + + closed.await + } + + /// Returns `true` if the associated [`Receiver`] handle has been dropped. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly or the + /// [`Receiver`] value is dropped. + /// + /// If `true` is returned, a call to `send` will always result in an error. + /// + /// [`Receiver`]: Receiver + /// [`close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// } + /// ``` + pub fn is_closed(&self) -> bool { + let inner = self.inner.as_ref().unwrap(); + + let state = State::load(&inner.state, Acquire); + state.is_closed() + } + + /// Checks whether the oneshot channel has been closed, and if not, schedules the + /// `Waker` in the provided `Context` to receive a notification when the channel is + /// closed. + /// + /// A [`Receiver`] is closed by either calling [`close`] explicitly, or when the + /// [`Receiver`] value is dropped. + /// + /// Note that on multiple calls to poll, only the `Waker` from the `Context` passed + /// to the most recent call will be scheduled to receive a wakeup. + /// + /// [`Receiver`]: struct@crate::sync::oneshot::Receiver + /// [`close`]: fn@crate::sync::oneshot::Receiver::close + /// + /// # Return value + /// + /// This function returns: + /// + /// * `Poll::Pending` if the channel is still open. + /// * `Poll::Ready(())` if the channel is closed. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() { + /// let (mut tx, mut rx) = oneshot::channel::<()>(); + /// + /// tokio::spawn(async move { + /// rx.close(); + /// }); + /// + /// poll_fn(|cx| tx.poll_closed(cx)).await; + /// + /// println!("the receiver dropped"); + /// } + /// ``` + pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + let inner = self.inner.as_ref().unwrap(); + + let mut state = State::load(&inner.state, Acquire); + + if state.is_closed() { + coop.made_progress(); + return Poll::Ready(()); + } + + if state.is_tx_task_set() { + let will_notify = unsafe { inner.tx_task.will_wake(cx) }; + + if !will_notify { + state = State::unset_tx_task(&inner.state); + + if state.is_closed() { + // Set the flag again so that the waker is released in drop + State::set_tx_task(&inner.state); + coop.made_progress(); + return Ready(()); + } else { + unsafe { inner.tx_task.drop_task() }; + } + } + } + + if !state.is_tx_task_set() { + // Attempt to set the task + unsafe { + inner.tx_task.set_task(cx); + } + + // Update the state + state = State::set_tx_task(&inner.state); + + if state.is_closed() { + coop.made_progress(); + return Ready(()); + } + } + + Pending + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.complete(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + tx_dropped = true, + tx_dropped.op = "override", + ) + }); + } + } +} + +impl<T> Receiver<T> { + /// Prevents the associated [`Sender`] handle from sending a value. + /// + /// Any `send` operation which happens after calling `close` is guaranteed + /// to fail. After calling `close`, [`try_recv`] should be called to + /// receive a value if one was sent **before** the call to `close` + /// completed. + /// + /// This function is useful to perform a graceful shutdown and ensure that a + /// value will not be sent into the channel and never received. + /// + /// `close` is no-op if a message is already received or the channel + /// is already closed. + /// + /// [`Sender`]: Sender + /// [`try_recv`]: Receiver::try_recv + /// + /// # Examples + /// + /// Prevent a value from being sent + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(!tx.is_closed()); + /// + /// rx.close(); + /// + /// assert!(tx.is_closed()); + /// assert!(tx.send("never received").is_err()); + /// + /// match rx.try_recv() { + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// Receive a value sent **before** calling `close` + /// + /// ``` + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// assert!(tx.send("will receive").is_ok()); + /// + /// rx.close(); + /// + /// let msg = rx.try_recv().unwrap(); + /// assert_eq!(msg, "will receive"); + /// } + /// ``` + pub fn close(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); + } + } + + /// Attempts to receive a value. + /// + /// If a pending value exists in the channel, it is returned. If no value + /// has been sent, the current task **will not** be registered for + /// future notification. + /// + /// This function is useful to call from outside the context of an + /// asynchronous task. + /// + /// # Return + /// + /// - `Ok(T)` if a value is pending in the channel. + /// - `Err(TryRecvError::Empty)` if no value has been sent yet. + /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending + /// a value. + /// + /// # Examples + /// + /// `try_recv` before a value is sent, then after. + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel(); + /// + /// match rx.try_recv() { + /// // The channel is currently empty + /// Err(TryRecvError::Empty) => {} + /// _ => unreachable!(), + /// } + /// + /// // Send a value + /// tx.send("hello").unwrap(); + /// + /// match rx.try_recv() { + /// Ok(value) => assert_eq!(value, "hello"), + /// _ => unreachable!(), + /// } + /// } + /// ``` + /// + /// `try_recv` when the sender dropped before sending a value + /// + /// ``` + /// use tokio::sync::oneshot; + /// use tokio::sync::oneshot::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = oneshot::channel::<()>(); + /// + /// drop(tx); + /// + /// match rx.try_recv() { + /// // The channel will never receive a value. + /// Err(TryRecvError::Closed) => {} + /// _ => unreachable!(), + /// } + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let result = if let Some(inner) = self.inner.as_ref() { + let state = State::load(&inner.state, Acquire); + + if state.is_complete() { + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. + match unsafe { inner.consume_value() } { + Some(value) => { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + value_received = true, + value_received.op = "override", + ) + }); + Ok(value) + } + None => Err(TryRecvError::Closed), + } + } else if state.is_closed() { + Err(TryRecvError::Closed) + } else { + // Not ready, this does not clear `inner` + return Err(TryRecvError::Empty); + } + } else { + Err(TryRecvError::Closed) + }; + + self.inner = None; + result + } + + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::sync::oneshot; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = oneshot::channel::<u8>(); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Ok(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(self) -> Result<T, RecvError> { + crate::future::block_on(self) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + if let Some(inner) = self.inner.as_ref() { + inner.close(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + rx_dropped = true, + rx_dropped.op = "override", + ) + }); + } + } +} + +impl<T> Future for Receiver<T> { + type Output = Result<T, RecvError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + // If `inner` is `None`, then `poll()` has already completed. + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _res_span = self.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_span = self.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_poll_span = self.async_op_poll_span.clone().entered(); + + let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let res = ready!(trace_poll_op!("poll_recv", inner.poll_recv(cx)))?; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let res = ready!(inner.poll_recv(cx))?; + + res + } else { + panic!("called after complete"); + }; + + self.inner = None; + Ready(Ok(ret)) + } +} + +impl<T> Inner<T> { + fn complete(&self) -> bool { + let prev = State::set_complete(&self.state); + + if prev.is_closed() { + return false; + } + + if prev.is_rx_task_set() { + // TODO: Consume waker? + unsafe { + self.rx_task.with_task(Waker::wake_by_ref); + } + } + + true + } + + fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + // Load the state + let mut state = State::load(&self.state, Acquire); + + if state.is_complete() { + coop.made_progress(); + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else if state.is_closed() { + coop.made_progress(); + Ready(Err(RecvError(()))) + } else { + if state.is_rx_task_set() { + let will_notify = unsafe { self.rx_task.will_wake(cx) }; + + // Check if the task is still the same + if !will_notify { + // Unset the task + state = State::unset_rx_task(&self.state); + if state.is_complete() { + // Set the flag again so that the waker is released in drop + State::set_rx_task(&self.state); + + coop.made_progress(); + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. + return match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + }; + } else { + unsafe { self.rx_task.drop_task() }; + } + } + } + + if !state.is_rx_task_set() { + // Attempt to set the task + unsafe { + self.rx_task.set_task(cx); + } + + // Update the state + state = State::set_rx_task(&self.state); + + if state.is_complete() { + coop.made_progress(); + match unsafe { self.consume_value() } { + Some(value) => Ready(Ok(value)), + None => Ready(Err(RecvError(()))), + } + } else { + Pending + } + } else { + Pending + } + } + } + + /// Called by `Receiver` to indicate that the value will never be received. + fn close(&self) { + let prev = State::set_closed(&self.state); + + if prev.is_tx_task_set() && !prev.is_complete() { + unsafe { + self.tx_task.with_task(Waker::wake_by_ref); + } + } + } + + /// Consumes the value. This function does not check `state`. + /// + /// # Safety + /// + /// Calling this method concurrently on multiple threads will result in a + /// data race. The `VALUE_SENT` state bit is used to ensure that only the + /// sender *or* the receiver will call this method at a given point in time. + /// If `VALUE_SENT` is not set, then only the sender may call this method; + /// if it is set, then only the receiver may call this method. + unsafe fn consume_value(&self) -> Option<T> { + self.value.with_mut(|ptr| (*ptr).take()) + } +} + +unsafe impl<T: Send> Send for Inner<T> {} +unsafe impl<T: Send> Sync for Inner<T> {} + +fn mut_load(this: &mut AtomicUsize) -> usize { + this.with_mut(|v| *v) +} + +impl<T> Drop for Inner<T> { + fn drop(&mut self) { + let state = State(mut_load(&mut self.state)); + + if state.is_rx_task_set() { + unsafe { + self.rx_task.drop_task(); + } + } + + if state.is_tx_task_set() { + unsafe { + self.tx_task.drop_task(); + } + } + } +} + +impl<T: fmt::Debug> fmt::Debug for Inner<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + use std::sync::atomic::Ordering::Relaxed; + + fmt.debug_struct("Inner") + .field("state", &State::load(&self.state, Relaxed)) + .finish() + } +} + +/// Indicates that a waker for the receiving task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `rx_task` field may be uninitialized. +const RX_TASK_SET: usize = 0b00001; +/// Indicates that a value has been stored in the channel's inner `UnsafeCell`. +/// +/// # Safety +/// +/// This bit controls which side of the channel is permitted to access the +/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the +/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by +/// the sender. +const VALUE_SENT: usize = 0b00010; +const CLOSED: usize = 0b00100; + +/// Indicates that a waker for the sending task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `tx_task` field may be uninitialized. +const TX_TASK_SET: usize = 0b01000; + +impl State { + fn new() -> State { + State(0) + } + + fn is_complete(self) -> bool { + self.0 & VALUE_SENT == VALUE_SENT + } + + fn set_complete(cell: &AtomicUsize) -> State { + // This method is a compare-and-swap loop rather than a fetch-or like + // other `set_$WHATEVER` methods on `State`. This is because we must + // check if the state has been closed before setting the `VALUE_SENT` + // bit. + // + // We don't want to set both the `VALUE_SENT` bit if the `CLOSED` + // bit is already set, because `VALUE_SENT` will tell the receiver that + // it's okay to access the inner `UnsafeCell`. Immediately after calling + // `set_complete`, if the channel was closed, the sender will _also_ + // access the `UnsafeCell` to take the value back out, so if a + // `poll_recv` or `try_recv` call is occurring concurrently, both + // threads may try to access the `UnsafeCell` if we were to set the + // `VALUE_SENT` bit on a closed channel. + let mut state = cell.load(Ordering::Relaxed); + loop { + if State(state).is_closed() { + break; + } + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. + match cell.compare_exchange_weak( + state, + state | VALUE_SENT, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(actual) => state = actual, + } + } + State(state) + } + + fn is_rx_task_set(self) -> bool { + self.0 & RX_TASK_SET == RX_TASK_SET + } + + fn set_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(RX_TASK_SET, AcqRel); + State(val | RX_TASK_SET) + } + + fn unset_rx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!RX_TASK_SET, AcqRel); + State(val & !RX_TASK_SET) + } + + fn is_closed(self) -> bool { + self.0 & CLOSED == CLOSED + } + + fn set_closed(cell: &AtomicUsize) -> State { + // Acquire because we want all later writes (attempting to poll) to be + // ordered after this. + let val = cell.fetch_or(CLOSED, Acquire); + State(val) + } + + fn set_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_or(TX_TASK_SET, AcqRel); + State(val | TX_TASK_SET) + } + + fn unset_tx_task(cell: &AtomicUsize) -> State { + let val = cell.fetch_and(!TX_TASK_SET, AcqRel); + State(val & !TX_TASK_SET) + } + + fn is_tx_task_set(self) -> bool { + self.0 & TX_TASK_SET == TX_TASK_SET + } + + fn as_usize(self) -> usize { + self.0 + } + + fn load(cell: &AtomicUsize, order: Ordering) -> State { + let val = cell.load(order); + State(val) + } +} + +impl fmt::Debug for State { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("State") + .field("is_complete", &self.is_complete()) + .field("is_closed", &self.is_closed()) + .field("is_rx_task_set", &self.is_rx_task_set()) + .field("is_tx_task_set", &self.is_tx_task_set()) + .finish() + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock.rs b/third_party/rust/tokio/src/sync/rwlock.rs new file mode 100644 index 0000000000..b856cfc856 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock.rs @@ -0,0 +1,1078 @@ +use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; +use crate::sync::mutex::TryLockError; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use std::cell::UnsafeCell; +use std::marker; +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::sync::Arc; + +pub(crate) mod owned_read_guard; +pub(crate) mod owned_write_guard; +pub(crate) mod owned_write_guard_mapped; +pub(crate) mod read_guard; +pub(crate) mod write_guard; +pub(crate) mod write_guard_mapped; +pub(crate) use owned_read_guard::OwnedRwLockReadGuard; +pub(crate) use owned_write_guard::OwnedRwLockWriteGuard; +pub(crate) use owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; +pub(crate) use read_guard::RwLockReadGuard; +pub(crate) use write_guard::RwLockWriteGuard; +pub(crate) use write_guard_mapped::RwLockMappedWriteGuard; + +#[cfg(not(loom))] +const MAX_READS: u32 = std::u32::MAX >> 3; + +#[cfg(loom)] +const MAX_READS: u32 = 10; + +/// An asynchronous reader-writer lock. +/// +/// This type of lock allows a number of readers or at most one writer at any +/// point in time. The write portion of this lock typically allows modification +/// of the underlying data (exclusive access) and the read portion of this lock +/// typically allows for read-only access (shared access). +/// +/// In comparison, a [`Mutex`] does not distinguish between readers or writers +/// that acquire the lock, therefore causing any tasks waiting for the lock to +/// become available to yield. An `RwLock` will allow any number of readers to +/// acquire the lock as long as a writer is not holding the lock. +/// +/// The priority policy of Tokio's read-write lock is _fair_ (or +/// [_write-preferring_]), in order to ensure that readers cannot starve +/// writers. Fairness is ensured using a first-in, first-out queue for the tasks +/// awaiting the lock; if a task that wishes to acquire the write lock is at the +/// head of the queue, read locks will not be given out until the write lock has +/// been released. This is in contrast to the Rust standard library's +/// `std::sync::RwLock`, where the priority policy is dependent on the +/// operating system's implementation. +/// +/// The type parameter `T` represents the data that this lock protects. It is +/// required that `T` satisfies [`Send`] to be shared across threads. The RAII guards +/// returned from the locking methods implement [`Deref`](trait@std::ops::Deref) +/// (and [`DerefMut`](trait@std::ops::DerefMut) +/// for the `write` methods) to allow access to the content of the lock. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::RwLock; +/// +/// #[tokio::main] +/// async fn main() { +/// let lock = RwLock::new(5); +/// +/// // many reader locks can be held at once +/// { +/// let r1 = lock.read().await; +/// let r2 = lock.read().await; +/// assert_eq!(*r1, 5); +/// assert_eq!(*r2, 5); +/// } // read locks are dropped at this point +/// +/// // only one write lock may be held, however +/// { +/// let mut w = lock.write().await; +/// *w += 1; +/// assert_eq!(*w, 6); +/// } // write lock is dropped here +/// } +/// ``` +/// +/// [`Mutex`]: struct@super::Mutex +/// [`RwLock`]: struct@RwLock +/// [`RwLockReadGuard`]: struct@RwLockReadGuard +/// [`RwLockWriteGuard`]: struct@RwLockWriteGuard +/// [`Send`]: trait@std::marker::Send +/// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies +#[derive(Debug)] +pub struct RwLock<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + + // maximum number of concurrent readers + mr: u32, + + //semaphore to coordinate read and write access to T + s: Semaphore, + + //inner data T + c: UnsafeCell<T>, +} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send<T: Send>() {} + fn check_sync<T: Sync>() {} + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + + check_send::<RwLock<u32>>(); + check_sync::<RwLock<u32>>(); + check_unpin::<RwLock<u32>>(); + + check_send::<RwLockReadGuard<'_, u32>>(); + check_sync::<RwLockReadGuard<'_, u32>>(); + check_unpin::<RwLockReadGuard<'_, u32>>(); + + check_send::<OwnedRwLockReadGuard<u32, i32>>(); + check_sync::<OwnedRwLockReadGuard<u32, i32>>(); + check_unpin::<OwnedRwLockReadGuard<u32, i32>>(); + + check_send::<RwLockWriteGuard<'_, u32>>(); + check_sync::<RwLockWriteGuard<'_, u32>>(); + check_unpin::<RwLockWriteGuard<'_, u32>>(); + + check_send::<RwLockMappedWriteGuard<'_, u32>>(); + check_sync::<RwLockMappedWriteGuard<'_, u32>>(); + check_unpin::<RwLockMappedWriteGuard<'_, u32>>(); + + check_send::<OwnedRwLockWriteGuard<u32>>(); + check_sync::<OwnedRwLockWriteGuard<u32>>(); + check_unpin::<OwnedRwLockWriteGuard<u32>>(); + + check_send::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + check_sync::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + check_unpin::<OwnedRwLockMappedWriteGuard<u32, i32>>(); + + let rwlock = Arc::new(RwLock::new(0)); + check_send_sync_val(rwlock.read()); + check_send_sync_val(Arc::clone(&rwlock).read_owned()); + check_send_sync_val(rwlock.write()); + check_send_sync_val(Arc::clone(&rwlock).write_owned()); +} + +// As long as T: Send + Sync, it's fine to send and share RwLock<T> between threads. +// If T were not Send, sending and sharing a RwLock<T> would be bad, since you can access T through +// RwLock<T>. +unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {} +unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {} +// NB: These impls need to be explicit since we're storing a raw pointer. +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send`. +unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {} +unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} +// T is required to be `Send` because an OwnedRwLockReadGuard can be used to drop the value held in +// the RwLock, unlike RwLockReadGuard. +unsafe impl<T, U> Send for OwnedRwLockReadGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Sync, +{ +} +unsafe impl<T, U> Sync for OwnedRwLockReadGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for OwnedRwLockWriteGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Sync for RwLockMappedWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T, U> Sync for OwnedRwLockMappedWriteGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send` - but since this is also provides mutable access, we need to +// make sure that `T` is `Send` since its value can be sent across thread +// boundaries. +unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Send for OwnedRwLockWriteGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<T> Send for RwLockMappedWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +unsafe impl<T, U> Send for OwnedRwLockMappedWriteGuard<T, U> +where + T: ?Sized + Send + Sync, + U: ?Sized + Send + Sync, +{ +} + +impl<T: ?Sized> RwLock<T> { + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// let lock = RwLock::new(5); + /// ``` + #[track_caller] + pub fn new(value: T) -> RwLock<T> + where + T: Sized, + { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = MAX_READS, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(MAX_READS as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(MAX_READS as usize); + + RwLock { + mr: MAX_READS, + c: UnsafeCell::new(value), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked + /// and allows a maximum of `max_reads` concurrent readers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// let lock = RwLock::with_max_readers(5, 1024); + /// ``` + /// + /// # Panics + /// + /// Panics if `max_reads` is more than `u32::MAX >> 3`. + #[track_caller] + pub fn with_max_readers(value: T, max_reads: u32) -> RwLock<T> + where + T: Sized, + { + assert!( + max_reads <= MAX_READS, + "a RwLock may not be created with more than {} readers", + MAX_READS + ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = max_reads, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(max_reads as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(max_reads as usize); + + RwLock { + mr: max_reads, + c: UnsafeCell::new(value), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(value: T) -> RwLock<T> + where + T: Sized, + { + RwLock { + mr: MAX_READS, + c: UnsafeCell::new(value), + s: Semaphore::const_new(MAX_READS as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Creates a new instance of an `RwLock<T>` which is unlocked + /// and allows a maximum of `max_reads` concurrent readers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_with_max_readers(5, 1024); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_with_max_readers(value: T, mut max_reads: u32) -> RwLock<T> + where + T: Sized, + { + max_reads &= MAX_READS; + RwLock { + mr: max_reads, + c: UnsafeCell::new(value), + s: Semaphore::const_new(max_reads as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), + } + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// Note that under the priority policy of [`RwLock`], read locks are not + /// granted until prior write locks, to prevent starvation. Therefore + /// deadlock may occur if a read lock is held by the current task, a write + /// lock attempt is made, and then a subsequent read lock attempt is made + /// by the current task. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read` makes you lose your place in + /// the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let r = c_lock.read().await; + /// assert_eq!(*r, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard after the spawned task finishes. + /// drop(n); + /// } + /// ``` + pub async fn read(&self) -> RwLockReadGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `RwLock` with shared read access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the read access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let mut write_lock = rwlock.write().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `write_lock` is released. + /// let read_lock = rwlock.blocking_read(); + /// assert_eq!(*read_lock, 0); + /// } + /// }); + /// + /// *write_lock -= 1; + /// drop(write_lock); // release the lock. + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// assert!(rwlock.try_write().is_ok()); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_read(&self) -> RwLockReadGuard<'_, T> { + crate::future::block_on(self.read()) + } + + /// Locks this `RwLock` with shared read access, causing the current task + /// to yield until the lock has been acquired. + /// + /// The calling task will yield until there are no writers which hold the + /// lock. There may be other readers inside the lock when the task resumes. + /// + /// This method is identical to [`RwLock::read`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + /// + /// Note that under the priority policy of [`RwLock`], read locks are not + /// granted until prior write locks, to prevent starvation. Therefore + /// deadlock may occur if a read lock is held by the current task, a write + /// lock attempt is made, and then a subsequent read lock attempt is made + /// by the current task. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let n = lock.read_owned().await; + /// assert_eq!(*n, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let r = c_lock.read_owned().await; + /// assert_eq!(*r, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard after the spawned task finishes. + /// drop(n); + ///} + /// ``` + pub async fn read_owned(self: Arc<Self>) -> OwnedRwLockReadGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + OwnedRwLockReadGuard { + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let v = lock.try_read().unwrap(); + /// assert_eq!(*v, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let n = c_lock.read().await; + /// assert_eq!(*n, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard when spawned task finishes. + /// drop(v); + /// } + /// ``` + pub fn try_read(&self) -> Result<RwLockReadGuard<'_, T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + Ok(RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_read`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let v = lock.try_read_owned().unwrap(); + /// assert_eq!(*v, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let n = c_lock.read_owned().await; + /// assert_eq!(*n, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard when spawned task finishes. + /// drop(v); + /// } + /// ``` + pub fn try_read_owned(self: Arc<Self>) -> Result<OwnedRwLockReadGuard<T>, TryLockError> { + match self.s.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedRwLockReadGuard { + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write` makes you lose your place + /// in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.write().await; + /// *n = 2; + ///} + /// ``` + pub async fn write(&self) -> RwLockWriteGuard<'_, T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + } + } + + /// Blockingly locks this `RwLock` with exclusive write access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::{sync::RwLock}; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let read_lock = rwlock.read().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `read_lock` is released. + /// let mut write_lock = rwlock.blocking_write(); + /// *write_lock = 2; + /// } + /// }); + /// + /// assert_eq!(*read_lock, 1); + /// // Release the last outstanding read lock. + /// drop(read_lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let read_lock = rwlock.try_read().unwrap(); + /// assert_eq!(*read_lock, 2); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_write(&self) -> RwLockWriteGuard<'_, T> { + crate::future::block_on(self.write()) + } + + /// Locks this `RwLock` with exclusive write access, causing the current + /// task to yield until the lock has been acquired. + /// + /// The calling task will yield while other writers or readers currently + /// have access to the lock. + /// + /// This method is identical to [`RwLock::write`], except that the returned + /// guard references the `RwLock` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `RwLock` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `RwLock` alive by holding an `Arc`. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` + /// when dropped. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let mut n = lock.write_owned().await; + /// *n = 2; + ///} + /// ``` + pub async fn write_owned(self: Arc<Self>) -> OwnedRwLockWriteGuard<T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + OwnedRwLockWriteGuard { + permits_acquired: self.mr, + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rw = RwLock::new(1); + /// + /// let v = rw.read().await; + /// assert_eq!(*v, 1); + /// + /// assert!(rw.try_write().is_err()); + /// } + /// ``` + pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, TryLockError> { + match self.s.try_acquire(self.mr) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + Ok(RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }) + } + + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// This method is identical to [`RwLock::try_write`], except that the + /// returned guard references the `RwLock` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `RwLock` must be wrapped in an `Arc` to + /// call this method, and the guard will live for the `'static` lifetime, + /// as it keeps the `RwLock` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rw = Arc::new(RwLock::new(1)); + /// + /// let v = Arc::clone(&rw).read_owned().await; + /// assert_eq!(*v, 1); + /// + /// assert!(rw.try_write_owned().is_err()); + /// } + /// ``` + pub fn try_write_owned(self: Arc<Self>) -> Result<OwnedRwLockWriteGuard<T>, TryLockError> { + match self.s.try_acquire(self.mr) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + + Ok(OwnedRwLockWriteGuard { + permits_acquired: self.mr, + data: self.c.get(), + lock: ManuallyDrop::new(self), + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `RwLock` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// fn main() { + /// let mut lock = RwLock::new(1); + /// + /// let n = lock.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + + /// Consumes the lock, returning the underlying data. + pub fn into_inner(self) -> T + where + T: Sized, + { + self.c.into_inner() + } +} + +impl<T> From<T> for RwLock<T> { + fn from(s: T) -> Self { + Self::new(s) + } +} + +impl<T: ?Sized> Default for RwLock<T> +where + T: Default, +{ + fn default() -> Self { + Self::new(T::default()) + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs b/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs new file mode 100644 index 0000000000..27b71bd988 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_read_guard.rs @@ -0,0 +1,170 @@ +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem; +use std::mem::ManuallyDrop; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read_owned`] method on +/// [`RwLock`]. +/// +/// [`read_owned`]: method@crate::sync::RwLock::read_owned +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *const U, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { + /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data. + /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = lock.read_owned().await; + /// let guard = OwnedRwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, V: ?Sized>(mut this: Self, f: F) -> OwnedRwLockReadGuard<T, V> + where + F: FnOnce(&U) -> &V, + { + let data = f(&*this) as *const V; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`OwnedRwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockReadGuard::try_map(..)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// let guard = lock.read_owned().await; + /// let guard = OwnedRwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, V: ?Sized>(mut this: Self, f: F) -> Result<OwnedRwLockReadGuard<T, V>, Self> + where + F: FnOnce(&U) -> Option<&V>, + { + let data = match f(&*this) { + Some(data) => data as *const V, + None => return Err(this), + }; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized, U: ?Sized> ops::Deref for OwnedRwLockReadGuard<T, U> { + type Target = U; + + fn deref(&self) -> &U { + unsafe { &*self.data } + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Debug for OwnedRwLockReadGuard<T, U> +where + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Display for OwnedRwLockReadGuard<T, U> +where + U: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockReadGuard<T, U> { + fn drop(&mut self) { + self.lock.s.release(1); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs new file mode 100644 index 0000000000..dbedab4cbb --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard.rs @@ -0,0 +1,279 @@ +use crate::sync::rwlock::owned_read_guard::OwnedRwLockReadGuard; +use crate::sync::rwlock::owned_write_guard_mapped::OwnedRwLockMappedWriteGuard; +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write_owned`] method +/// on [`RwLock`]. +/// +/// [`write_owned`]: method@crate::sync::RwLock::write_owned +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct OwnedRwLockWriteGuard<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *mut T, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized> OwnedRwLockWriteGuard<T> { + /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked + /// data. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let lock = Arc::clone(&lock); + /// let mut mapped = OwnedRwLockWriteGuard::map(lock.write_owned().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> OwnedRwLockMappedWriteGuard<T, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`OwnedRwLockMappedWriteGuard`] for a component + /// of the locked data. The original guard is returned if the closure + /// returns `None`. + /// + /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in + /// already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockWriteGuard::try_map(...)`. A method would interfere + /// with methods of the same name on the contents of the locked data. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let guard = Arc::clone(&lock).write_owned().await; + /// let mut guard = OwnedRwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<OwnedRwLockMappedWriteGuard<T, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let permits_acquired = this.permits_acquired; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Converts this `OwnedRwLockWriteGuard` into an + /// `OwnedRwLockMappedWriteGuard`. This method can be used to store a + /// non-mapped guard in a struct field that expects a mapped guard. + /// + /// This is equivalent to calling `OwnedRwLockWriteGuard::map(guard, |me| me)`. + #[inline] + pub fn into_mapped(this: Self) -> OwnedRwLockMappedWriteGuard<T> { + Self::map(this, |me| me) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.clone().write_owned().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write_owned().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + pub fn downgrade(mut self) -> OwnedRwLockReadGuard<T> { + let lock = unsafe { ManuallyDrop::take(&mut self.lock) }; + let data = self.data; + let to_release = (self.permits_acquired - 1) as usize; + + // Release all but one of the permits held by the write guard + lock.s.release(to_release); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(self); + + OwnedRwLockReadGuard { + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } +} + +impl<T: ?Sized> ops::Deref for OwnedRwLockWriteGuard<T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for OwnedRwLockWriteGuard<T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<T: ?Sized> fmt::Debug for OwnedRwLockWriteGuard<T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized> fmt::Display for OwnedRwLockWriteGuard<T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized> Drop for OwnedRwLockWriteGuard<T> { + fn drop(&mut self) { + self.lock.s.release(self.permits_acquired as usize); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs new file mode 100644 index 0000000000..55a24d96ac --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/owned_write_guard_mapped.rs @@ -0,0 +1,191 @@ +use crate::sync::rwlock::RwLock; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::ops; +use std::sync::Arc; + +/// Owned RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by [mapping] an [`OwnedRwLockWriteGuard`]. It is a +/// separate type from `OwnedRwLockWriteGuard` to disallow downgrading a mapped +/// guard, since doing so can cause undefined behavior. +/// +/// [mapping]: method@crate::sync::OwnedRwLockWriteGuard::map +/// [`OwnedRwLockWriteGuard`]: struct@crate::sync::OwnedRwLockWriteGuard +pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + // ManuallyDrop allows us to destructure into this field without running the destructor. + pub(super) lock: ManuallyDrop<Arc<RwLock<T>>>, + pub(super) data: *mut U, + pub(super) _p: PhantomData<T>, +} + +impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { + /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked + /// data. + /// + /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed + /// in already locked the data. + /// + /// This is an associated function that needs to be used as + /// `OwnedRwLockWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let lock = Arc::clone(&lock); + /// let mut mapped = OwnedRwLockWriteGuard::map(lock.write_owned().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, V: ?Sized>(mut this: Self, f: F) -> OwnedRwLockMappedWriteGuard<T, V> + where + F: FnOnce(&mut U) -> &mut V, + { + let data = f(&mut *this) as *mut V; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new `OwnedRwLockMappedWriteGuard` for a component + /// of the locked data. The original guard is returned if the closure + /// returns `None`. + /// + /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed + /// in already locked the data. + /// + /// This is an associated function that needs to be + /// used as `OwnedRwLockMappedWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{RwLock, OwnedRwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(Foo(1))); + /// + /// { + /// let guard = Arc::clone(&lock).write_owned().await; + /// let mut guard = OwnedRwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, V: ?Sized>( + mut this: Self, + f: F, + ) -> Result<OwnedRwLockMappedWriteGuard<T, V>, Self> + where + F: FnOnce(&mut U) -> Option<&mut V>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut V, + None => return Err(this), + }; + let lock = unsafe { ManuallyDrop::take(&mut this.lock) }; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(OwnedRwLockMappedWriteGuard { + permits_acquired, + lock: ManuallyDrop::new(lock), + data, + _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized, U: ?Sized> ops::Deref for OwnedRwLockMappedWriteGuard<T, U> { + type Target = U; + + fn deref(&self) -> &U { + unsafe { &*self.data } + } +} + +impl<T: ?Sized, U: ?Sized> ops::DerefMut for OwnedRwLockMappedWriteGuard<T, U> { + fn deref_mut(&mut self) -> &mut U { + unsafe { &mut *self.data } + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Debug for OwnedRwLockMappedWriteGuard<T, U> +where + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> fmt::Display for OwnedRwLockMappedWriteGuard<T, U> +where + U: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<T: ?Sized, U: ?Sized> Drop for OwnedRwLockMappedWriteGuard<T, U> { + fn drop(&mut self) { + self.lock.s.release(self.permits_acquired as usize); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + unsafe { ManuallyDrop::drop(&mut self.lock) }; + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/read_guard.rs b/third_party/rust/tokio/src/sync/rwlock/read_guard.rs new file mode 100644 index 0000000000..3692131992 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/read_guard.rs @@ -0,0 +1,177 @@ +use crate::sync::batch_semaphore::Semaphore; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read`] method on +/// [`RwLock`]. +/// +/// [`read`]: method@crate::sync::RwLock::read +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct RwLockReadGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) s: &'a Semaphore, + pub(super) data: *const T, + pub(super) marker: marker::PhantomData<&'a T>, +} + +impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { + /// Makes a new `RwLockReadGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let s = this.s; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the + /// same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let s = this.s; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "sub", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/write_guard.rs b/third_party/rust/tokio/src/sync/rwlock/write_guard.rs new file mode 100644 index 0000000000..7cadd74c60 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/write_guard.rs @@ -0,0 +1,282 @@ +use crate::sync::batch_semaphore::Semaphore; +use crate::sync::rwlock::read_guard::RwLockReadGuard; +use crate::sync::rwlock::write_guard_mapped::RwLockMappedWriteGuard; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write`] method +/// on [`RwLock`]. +/// +/// [`write`]: method@crate::sync::RwLock::write +/// [`RwLock`]: struct@crate::sync::RwLock +pub struct RwLockWriteGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + pub(super) s: &'a Semaphore, + pub(super) data: *mut T, + pub(super) marker: marker::PhantomData<&'a mut T>, +} + +impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { + /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockMappedWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockMappedWriteGuard`]: struct@crate::sync::RwLockMappedWriteGuard + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<RwLockMappedWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } + + /// Converts this `RwLockWriteGuard` into an `RwLockMappedWriteGuard`. This + /// method can be used to store a non-mapped guard in a struct field that + /// expects a mapped guard. + /// + /// This is equivalent to calling `RwLockWriteGuard::map(guard, |me| me)`. + #[inline] + pub fn into_mapped(this: Self) -> RwLockMappedWriteGuard<'a, T> { + RwLockWriteGuard::map(this, |me| me) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop this read access of the `RwLock` + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.write().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + /// + /// [`RwLock`]: struct@crate::sync::RwLock + pub fn downgrade(self) -> RwLockReadGuard<'a, T> { + let RwLockWriteGuard { s, data, .. } = self; + let to_release = (self.permits_acquired - 1) as usize; + // Release all but one of the permits held by the write guard + s.release(to_release); + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(self); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } +} + +impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { + fn drop(&mut self) { + self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs b/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs new file mode 100644 index 0000000000..b5c644a9e8 --- /dev/null +++ b/third_party/rust/tokio/src/sync/rwlock/write_guard_mapped.rs @@ -0,0 +1,197 @@ +use crate::sync::batch_semaphore::Semaphore; +use std::fmt; +use std::marker; +use std::mem; +use std::ops; + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by [mapping] an [`RwLockWriteGuard`]. It is a +/// separate type from `RwLockWriteGuard` to disallow downgrading a mapped +/// guard, since doing so can cause undefined behavior. +/// +/// [mapping]: method@crate::sync::RwLockWriteGuard::map +/// [`RwLockWriteGuard`]: struct@crate::sync::RwLockWriteGuard +pub struct RwLockMappedWriteGuard<'a, T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) resource_span: tracing::Span, + pub(super) permits_acquired: u32, + pub(super) s: &'a Semaphore, + pub(super) data: *mut T, + pub(super) marker: marker::PhantomData<&'a mut T>, +} + +impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { + /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockMappedWriteGuard::map(..)`. A method would interfere with methods + /// of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockMappedWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Attempts to make a new [`RwLockMappedWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockMappedWriteGuard::try_map(...)`. A method would interfere + /// with methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>( + mut this: Self, + f: F, + ) -> Result<RwLockMappedWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + let permits_acquired = this.permits_acquired; + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = this.resource_span.clone(); + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + + Ok(RwLockMappedWriteGuard { + permits_acquired, + s, + data, + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + }) + } +} + +impl<T: ?Sized> ops::Deref for RwLockMappedWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<T: ?Sized> ops::DerefMut for RwLockMappedWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockMappedWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockMappedWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockMappedWriteGuard<'a, T> { + fn drop(&mut self) { + self.s.release(self.permits_acquired as usize); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + write_locked.op = "override", + ) + }); + } +} diff --git a/third_party/rust/tokio/src/sync/semaphore.rs b/third_party/rust/tokio/src/sync/semaphore.rs new file mode 100644 index 0000000000..860f46f399 --- /dev/null +++ b/third_party/rust/tokio/src/sync/semaphore.rs @@ -0,0 +1,644 @@ +use super::batch_semaphore as ll; // low level implementation +use super::{AcquireError, TryAcquireError}; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; +use std::sync::Arc; + +/// Counting semaphore performing asynchronous permit acquisition. +/// +/// A semaphore maintains a set of permits. Permits are used to synchronize +/// access to a shared resource. A semaphore differs from a mutex in that it +/// can allow more than one concurrent caller to access the shared resource at a +/// time. +/// +/// When `acquire` is called and the semaphore has remaining permits, the +/// function immediately returns a permit. However, if no remaining permits are +/// available, `acquire` (asynchronously) waits until an outstanding permit is +/// dropped. At this point, the freed permit is assigned to the caller. +/// +/// This `Semaphore` is fair, which means that permits are given out in the order +/// they were requested. This fairness is also applied when `acquire_many` gets +/// involved, so if a call to `acquire_many` at the front of the queue requests +/// more permits than currently available, this can prevent a call to `acquire` +/// from completing, even if the semaphore has enough permits complete the call +/// to `acquire`. +/// +/// To use the `Semaphore` in a poll function, you can use the [`PollSemaphore`] +/// utility. +/// +/// # Examples +/// +/// Basic usage: +/// +/// ``` +/// use tokio::sync::{Semaphore, TryAcquireError}; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Semaphore::new(3); +/// +/// let a_permit = semaphore.acquire().await.unwrap(); +/// let two_permits = semaphore.acquire_many(2).await.unwrap(); +/// +/// assert_eq!(semaphore.available_permits(), 0); +/// +/// let permit_attempt = semaphore.try_acquire(); +/// assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits)); +/// } +/// ``` +/// +/// Use [`Semaphore::acquire_owned`] to move permits across tasks: +/// +/// ``` +/// use std::sync::Arc; +/// use tokio::sync::Semaphore; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Arc::new(Semaphore::new(3)); +/// let mut join_handles = Vec::new(); +/// +/// for _ in 0..5 { +/// let permit = semaphore.clone().acquire_owned().await.unwrap(); +/// join_handles.push(tokio::spawn(async move { +/// // perform task... +/// // explicitly own `permit` in the task +/// drop(permit); +/// })); +/// } +/// +/// for handle in join_handles { +/// handle.await.unwrap(); +/// } +/// } +/// ``` +/// +/// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html +/// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned +#[derive(Debug)] +pub struct Semaphore { + /// The low level semaphore + ll_sem: ll::Semaphore, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +/// A permit from the semaphore. +/// +/// This type is created by the [`acquire`] method. +/// +/// [`acquire`]: crate::sync::Semaphore::acquire() +#[must_use] +#[derive(Debug)] +pub struct SemaphorePermit<'a> { + sem: &'a Semaphore, + permits: u32, +} + +/// An owned permit from the semaphore. +/// +/// This type is created by the [`acquire_owned`] method. +/// +/// [`acquire_owned`]: crate::sync::Semaphore::acquire_owned() +#[must_use] +#[derive(Debug)] +pub struct OwnedSemaphorePermit { + sem: Arc<Semaphore>, + permits: u32, +} + +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_unpin<T: Unpin>() {} + // This has to take a value, since the async fn's return type is unnameable. + fn check_send_sync_val<T: Send + Sync>(_t: T) {} + fn check_send_sync<T: Send + Sync>() {} + check_unpin::<Semaphore>(); + check_unpin::<SemaphorePermit<'_>>(); + check_send_sync::<Semaphore>(); + + let semaphore = Semaphore::new(0); + check_send_sync_val(semaphore.acquire()); +} + +impl Semaphore { + /// Creates a new semaphore with the initial number of permits. + #[track_caller] + pub fn new(permits: usize) -> Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Semaphore", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + inherits_child_attrs = true, + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let ll_sem = resource_span.in_scope(|| ll::Semaphore::new(permits)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let ll_sem = ll::Semaphore::new(permits); + + Self { + ll_sem, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } + } + + /// Creates a new semaphore with the initial number of permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// static SEM: Semaphore = Semaphore::const_new(10); + /// ``` + /// + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(permits: usize) -> Self { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return Self { + ll_sem: ll::Semaphore::const_new(permits), + resource_span: tracing::Span::none(), + }; + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return Self { + ll_sem: ll::Semaphore::const_new(permits), + }; + } + + /// Returns the current number of available permits. + pub fn available_permits(&self) -> usize { + self.ll_sem.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. + pub fn add_permits(&self, n: usize) { + self.ll_sem.release(n); + } + + /// Acquires a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire` makes you lose your place + /// in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// drop(permit_1); + /// assert_eq!(semaphore.available_permits(), 1); + /// } + /// ``` + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; + Ok(SemaphorePermit { + sem: self, + permits: 1, + }) + } + + /// Acquires `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permits. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(5); + /// + /// let permit = semaphore.acquire_many(3).await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 2); + /// } + /// ``` + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many", + "poll", + true, + ) + .await?; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + self.ll_sem.acquire(n).await?; + + Ok(SemaphorePermit { + sem: self, + permits: n, + }) + } + + /// Tries to acquire a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, + /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(SemaphorePermit { + sem: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + /// Tries to acquire `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left. + /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(4); + /// + /// let permit_1 = semaphore.try_acquire_many(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> { + match self.ll_sem.try_acquire(n) { + Ok(_) => Ok(SemaphorePermit { + sem: self, + permits: n, + }), + Err(e) => Err(e), + } + } + + /// Acquires a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(3)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_owned().await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(1), + self.resource_span.clone(), + "Semaphore::acquire_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(1); + + inner.await?; + Ok(OwnedSemaphorePermit { + sem: self, + permits: 1, + }) + } + + /// Acquires `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many_owned` makes you lose + /// your place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(10)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_many_owned(2).await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_many_owned( + self: Arc<Self>, + n: u32, + ) -> Result<OwnedSemaphorePermit, AcquireError> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.ll_sem.acquire(n), + self.resource_span.clone(), + "Semaphore::acquire_many_owned", + "poll", + true, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.ll_sem.acquire(n); + + inner.await?; + Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }) + } + + /// Tries to acquire a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(2)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire_owned(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + /// Tries to acquire `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(4)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_many_owned(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many_owned(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub fn try_acquire_many_owned( + self: Arc<Self>, + n: u32, + ) -> Result<OwnedSemaphorePermit, TryAcquireError> { + match self.ll_sem.try_acquire(n) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }), + Err(e) => Err(e), + } + } + + /// Closes the semaphore. + /// + /// This prevents the semaphore from issuing new permits and notifies all pending waiters. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// use std::sync::Arc; + /// use tokio::sync::TryAcquireError; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(1)); + /// let semaphore2 = semaphore.clone(); + /// + /// tokio::spawn(async move { + /// let permit = semaphore.acquire_many(2).await; + /// assert!(permit.is_err()); + /// println!("waiter received error"); + /// }); + /// + /// println!("closing semaphore"); + /// semaphore2.close(); + /// + /// // Cannot obtain more permits + /// assert_eq!(semaphore2.try_acquire().err(), Some(TryAcquireError::Closed)) + /// } + /// ``` + pub fn close(&self) { + self.ll_sem.close(); + } + + /// Returns true if the semaphore is closed + pub fn is_closed(&self) -> bool { + self.ll_sem.is_closed() + } +} + +impl<'a> SemaphorePermit<'a> { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + +impl OwnedSemaphorePermit { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + +impl<'a> Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/third_party/rust/tokio/src/sync/task/atomic_waker.rs b/third_party/rust/tokio/src/sync/task/atomic_waker.rs new file mode 100644 index 0000000000..13aba35448 --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/atomic_waker.rs @@ -0,0 +1,382 @@ +#![cfg_attr(any(loom, not(feature = "sync")), allow(dead_code, unreachable_pub))] + +use crate::loom::cell::UnsafeCell; +use crate::loom::hint; +use crate::loom::sync::atomic::AtomicUsize; + +use std::fmt; +use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::Waker; + +/// A synchronization primitive for task waking. +/// +/// `AtomicWaker` will coordinate concurrent wakes with the consumer +/// potentially "waking" the underlying task. This is useful in scenarios +/// where a computation completes in another thread and wants to wake the +/// consumer, but the consumer is in the process of being migrated to a new +/// logical task. +/// +/// Consumers should call `register` before checking the result of a computation +/// and producers should call `wake` after producing the computation (this +/// differs from the usual `thread::park` pattern). It is also permitted for +/// `wake` to be called **before** `register`. This results in a no-op. +/// +/// A single `AtomicWaker` may be reused for any number of calls to `register` or +/// `wake`. +pub(crate) struct AtomicWaker { + state: AtomicUsize, + waker: UnsafeCell<Option<Waker>>, +} + +impl RefUnwindSafe for AtomicWaker {} +impl UnwindSafe for AtomicWaker {} + +// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell +// stores a `Waker` value produced by calls to `register` and many threads can +// race to take the waker by calling `wake`. +// +// If a new `Waker` instance is produced by calling `register` before an existing +// one is consumed, then the existing one is overwritten. +// +// While `AtomicWaker` is single-producer, the implementation ensures memory +// safety. In the event of concurrent calls to `register`, there will be a +// single winner whose waker will get stored in the cell. The losers will not +// have their tasks woken. As such, callers should ensure to add synchronization +// to calls to `register`. +// +// The implementation uses a single `AtomicUsize` value to coordinate access to +// the `Waker` cell. There are two bits that are operated on independently. These +// are represented by `REGISTERING` and `WAKING`. +// +// The `REGISTERING` bit is set when a producer enters the critical section. The +// `WAKING` bit is set when a consumer enters the critical section. Neither +// bit being set is represented by `WAITING`. +// +// A thread obtains an exclusive lock on the waker cell by transitioning the +// state from `WAITING` to `REGISTERING` or `WAKING`, depending on the +// operation the thread wishes to perform. When this transition is made, it is +// guaranteed that no other thread will access the waker cell. +// +// # Registering +// +// On a call to `register`, an attempt to transition the state from WAITING to +// REGISTERING is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread sets the waker cell to the waker +// provided as an argument. Then it attempts to transition the state back from +// `REGISTERING` -> `WAITING`. +// +// If this transition is successful, then the registering process is complete +// and the next call to `wake` will observe the waker. +// +// If the transition fails, then there was a concurrent call to `wake` that +// was unable to access the waker cell (due to the registering thread holding the +// lock). To handle this, the registering thread removes the waker it just set +// from the cell and calls `wake` on it. This call to wake represents the +// attempt to wake by the other thread (that set the `WAKING` bit). The +// state is then transitioned from `REGISTERING | WAKING` back to `WAITING`. +// This transition must succeed because, at this point, the state cannot be +// transitioned by another thread. +// +// # Waking +// +// On a call to `wake`, an attempt to transition the state from `WAITING` to +// `WAKING` is made. On success, the caller obtains a lock on the waker cell. +// +// If the lock is obtained, then the thread takes ownership of the current value +// in the waker cell, and calls `wake` on it. The state is then transitioned +// back to `WAITING`. This transition must succeed as, at this point, the state +// cannot be transitioned by another thread. +// +// If the thread is unable to obtain the lock, the `WAKING` bit is still set. +// This is because it has either been set by the current thread but the previous +// value included the `REGISTERING` bit **or** a concurrent thread is in the +// `WAKING` critical section. Either way, no action must be taken. +// +// If the current thread is the only concurrent call to `wake` and another +// thread is in the `register` critical section, when the other thread **exits** +// the `register` critical section, it will observe the `WAKING` bit and +// handle the waker itself. +// +// If another thread is in the `waker` critical section, then it will handle +// waking the caller task. +// +// # A potential race (is safely handled). +// +// Imagine the following situation: +// +// * Thread A obtains the `wake` lock and wakes a task. +// +// * Before thread A releases the `wake` lock, the woken task is scheduled. +// +// * Thread B attempts to wake the task. In theory this should result in the +// task being woken, but it cannot because thread A still holds the wake +// lock. +// +// This case is handled by requiring users of `AtomicWaker` to call `register` +// **before** attempting to observe the application state change that resulted +// in the task being woken. The wakers also change the application state +// before calling wake. +// +// Because of this, the task will do one of two things. +// +// 1) Observe the application state change that Thread B is waking on. In +// this case, it is OK for Thread B's wake to be lost. +// +// 2) Call register before attempting to observe the application state. Since +// Thread A still holds the `wake` lock, the call to `register` will result +// in the task waking itself and get scheduled again. + +/// Idle state. +const WAITING: usize = 0; + +/// A new waker value is being registered with the `AtomicWaker` cell. +const REGISTERING: usize = 0b01; + +/// The task currently registered with the `AtomicWaker` cell is being woken. +const WAKING: usize = 0b10; + +impl AtomicWaker { + /// Create an `AtomicWaker` + pub(crate) fn new() -> AtomicWaker { + AtomicWaker { + state: AtomicUsize::new(WAITING), + waker: UnsafeCell::new(None), + } + } + + /* + /// Registers the current waker to be notified on calls to `wake`. + pub(crate) fn register(&self, waker: Waker) { + self.do_register(waker); + } + */ + + /// Registers the provided waker to be notified on calls to `wake`. + /// + /// The new waker will take place of any previous wakers that were registered + /// by previous calls to `register`. Any calls to `wake` that happen after + /// a call to `register` (as defined by the memory ordering rules), will + /// wake the `register` caller's task. + /// + /// It is safe to call `register` with multiple other threads concurrently + /// calling `wake`. This will result in the `register` caller's current + /// task being woken once. + /// + /// This function is safe to call concurrently, but this is generally a bad + /// idea. Concurrent calls to `register` will attempt to register different + /// tasks to be woken. One of the callers will win and have its task set, + /// but there is no guarantee as to which caller will succeed. + pub(crate) fn register_by_ref(&self, waker: &Waker) { + self.do_register(waker); + } + + fn do_register<W>(&self, waker: W) + where + W: WakerRef, + { + fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> { + std::panic::catch_unwind(AssertUnwindSafe(f)) + } + + match self + .state + .compare_exchange(WAITING, REGISTERING, Acquire, Acquire) + .unwrap_or_else(|x| x) + { + WAITING => { + unsafe { + // If `into_waker` panics (because it's code outside of + // AtomicWaker) we need to prime a guard that is called on + // unwind to restore the waker to a WAITING state. Otherwise + // any future calls to register will incorrectly be stuck + // believing it's being updated by someone else. + let new_waker_or_panic = catch_unwind(move || waker.into_waker()); + + // Set the field to contain the new waker, or if + // `into_waker` panicked, leave the old value. + let mut maybe_panic = None; + let mut old_waker = None; + match new_waker_or_panic { + Ok(new_waker) => { + old_waker = self.waker.with_mut(|t| (*t).take()); + self.waker.with_mut(|t| *t = Some(new_waker)); + } + Err(panic) => maybe_panic = Some(panic), + } + + // Release the lock. If the state transitioned to include + // the `WAKING` bit, this means that a wake has been + // called concurrently, so we have to remove the waker and + // wake it.` + // + // Start by assuming that the state is `REGISTERING` as this + // is what we jut set it to. + let res = self + .state + .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); + + match res { + Ok(_) => { + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + let _ = catch_unwind(move || { + drop(old_waker); + }); + } + Err(actual) => { + // This branch can only be reached if a + // concurrent thread called `wake`. In this + // case, `actual` **must** be `REGISTERING | + // WAKING`. + debug_assert_eq!(actual, REGISTERING | WAKING); + + // Take the waker to wake once the atomic operation has + // completed. + let mut waker = self.waker.with_mut(|t| (*t).take()); + + // Just swap, because no one could change state + // while state == `Registering | `Waking` + self.state.swap(WAITING, AcqRel); + + // If `into_waker` panicked, then the waker in the + // waker slot is actually the old waker. + if maybe_panic.is_some() { + old_waker = waker.take(); + } + + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + if let Some(old_waker) = old_waker { + let _ = catch_unwind(move || { + old_waker.wake(); + }); + } + + // The atomic swap was complete, now wake the waker + // and return. + // + // If this panics, we end up in a consumed state and + // return the panic to the caller. + if let Some(waker) = waker { + debug_assert!(maybe_panic.is_none()); + waker.wake(); + } + } + } + + if let Some(panic) = maybe_panic { + // If `into_waker` panicked, return the panic to the caller. + resume_unwind(panic); + } + } + } + WAKING => { + // Currently in the process of waking the task, i.e., + // `wake` is currently being called on the old waker. + // So, we call wake on the new waker. + // + // If this panics, someone else is responsible for restoring the + // state of the waker. + waker.wake(); + + // This is equivalent to a spin lock, so use a spin hint. + hint::spin_loop(); + } + state => { + // In this case, a concurrent thread is holding the + // "registering" lock. This probably indicates a bug in the + // caller's code as racing to call `register` doesn't make much + // sense. + // + // We just want to maintain memory safety. It is ok to drop the + // call to `register`. + debug_assert!(state == REGISTERING || state == REGISTERING | WAKING); + } + } + } + + /// Wakes the task that last called `register`. + /// + /// If `register` has not been called yet, then this does nothing. + pub(crate) fn wake(&self) { + if let Some(waker) = self.take_waker() { + // If wake panics, we've consumed the waker which is a legitimate + // outcome. + waker.wake(); + } + } + + /// Attempts to take the `Waker` value out of the `AtomicWaker` with the + /// intention that the caller will wake the task later. + pub(crate) fn take_waker(&self) -> Option<Waker> { + // AcqRel ordering is used in order to acquire the value of the `waker` + // cell as well as to establish a `release` ordering with whatever + // memory the `AtomicWaker` is associated with. + match self.state.fetch_or(WAKING, AcqRel) { + WAITING => { + // The waking lock has been acquired. + let waker = unsafe { self.waker.with_mut(|t| (*t).take()) }; + + // Release the lock + self.state.fetch_and(!WAKING, Release); + + waker + } + state => { + // There is a concurrent thread currently updating the + // associated waker. + // + // Nothing more to do as the `WAKING` bit has been set. It + // doesn't matter if there are concurrent registering threads or + // not. + // + debug_assert!( + state == REGISTERING || state == REGISTERING | WAKING || state == WAKING + ); + None + } + } + } +} + +impl Default for AtomicWaker { + fn default() -> Self { + AtomicWaker::new() + } +} + +impl fmt::Debug for AtomicWaker { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "AtomicWaker") + } +} + +unsafe impl Send for AtomicWaker {} +unsafe impl Sync for AtomicWaker {} + +trait WakerRef { + fn wake(self); + fn into_waker(self) -> Waker; +} + +impl WakerRef for Waker { + fn wake(self) { + self.wake() + } + + fn into_waker(self) -> Waker { + self + } +} + +impl WakerRef for &Waker { + fn wake(self) { + self.wake_by_ref() + } + + fn into_waker(self) -> Waker { + self.clone() + } +} diff --git a/third_party/rust/tokio/src/sync/task/mod.rs b/third_party/rust/tokio/src/sync/task/mod.rs new file mode 100644 index 0000000000..a6bc6ed06e --- /dev/null +++ b/third_party/rust/tokio/src/sync/task/mod.rs @@ -0,0 +1,4 @@ +//! Thread-safe task notification primitives. + +mod atomic_waker; +pub(crate) use self::atomic_waker::AtomicWaker; diff --git a/third_party/rust/tokio/src/sync/tests/atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs new file mode 100644 index 0000000000..ec13cbd658 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/atomic_waker.rs @@ -0,0 +1,77 @@ +use crate::sync::AtomicWaker; +use tokio_test::task; + +use std::task::Waker; + +trait AssertSend: Send {} +trait AssertSync: Send {} + +impl AssertSend for AtomicWaker {} +impl AssertSync for AtomicWaker {} + +impl AssertSend for Waker {} +impl AssertSync for Waker {} + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn basic_usage() { + let mut waker = task::spawn(AtomicWaker::new()); + + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + waker.wake(); + + assert!(waker.is_woken()); +} + +#[test] +fn wake_without_register() { + let mut waker = task::spawn(AtomicWaker::new()); + waker.wake(); + + // Registering should not result in a notification + waker.enter(|cx, waker| waker.register_by_ref(cx.waker())); + + assert!(!waker.is_woken()); +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] // wasm currently doesn't support unwinding +fn atomic_waker_panic_safe() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| panic!("clone"), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) }; + + let atomic_waker = AtomicWaker::new(); + + let panicking = panic::AssertUnwindSafe(&panicking); + + let result = panic::catch_unwind(|| { + let panic::AssertUnwindSafe(panicking) = panicking; + atomic_waker.register_by_ref(panicking); + }); + + assert!(result.is_err()); + assert!(atomic_waker.take_waker().is_none()); + + atomic_waker.register_by_ref(&nonpanicking); + assert!(atomic_waker.take_waker().is_some()); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs new file mode 100644 index 0000000000..f8bae65d13 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_atomic_waker.rs @@ -0,0 +1,100 @@ +use crate::sync::task::AtomicWaker; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; + +struct Chan { + num: AtomicUsize, + task: AtomicWaker, +} + +#[test] +fn basic_notification() { + const NUM_NOTIFY: usize = 2; + + loom::model(|| { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} + +#[test] +fn test_panicky_waker() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = + RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ()); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + + // If you're working with this test (and I sure hope you never have to!), + // uncomment the following section because there will be a lot of panics + // which would otherwise log. + // + // We can't however leaved it uncommented, because it's global. + // panic::set_hook(Box::new(|_| ())); + + const NUM_NOTIFY: usize = 2; + + loom::model(move || { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + // Note: this panic should have no effect on the overall state of the + // waker and it should proceed as normal. + // + // A thread above might race to flag a wakeup, and a WAKING state will + // be preserved if this expected panic races with that so the below + // procedure should be allowed to continue uninterrupted. + let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking)); + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs new file mode 100644 index 0000000000..039b01bf43 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_broadcast.rs @@ -0,0 +1,207 @@ +use crate::sync::broadcast; +use crate::sync::broadcast::error::RecvError::{Closed, Lagged}; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; +use tokio_test::{assert_err, assert_ok}; + +#[test] +fn broadcast_send() { + loom::model(|| { + let (tx1, mut rx) = broadcast::channel(2); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + assert_ok!(tx1.send("one")); + assert_ok!(tx1.send("two")); + assert_ok!(tx1.send("three")); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + assert_ok!(tx2.send("eins")); + assert_ok!(tx2.send("zwei")); + assert_ok!(tx2.send("drei")); + }); + }); + + block_on(async { + let mut num = 0; + loop { + match rx.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + assert_eq!(num, 6); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +// An `Arc` is used as the value in order to detect memory leaks. +#[test] +fn broadcast_two() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel::<Arc<&'static str>>(16); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx2.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + assert_ok!(tx.send(Arc::new("hello"))); + assert_ok!(tx.send(Arc::new("world"))); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn broadcast_wrap() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(2); + let mut rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx1.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let mut num = 0; + + loop { + match rx2.recv().await { + Ok(_) => num += 1, + Err(Closed) => break, + Err(Lagged(n)) => num += n as usize, + } + } + + assert_eq!(num, 3); + }); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn drop_rx() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel(16); + let rx2 = tx.subscribe(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "one"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "two"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(v, "three"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + drop(rx2); + }); + + assert_ok!(tx.send("one")); + assert_ok!(tx.send("two")); + assert_ok!(tx.send("three")); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn drop_multiple_rx_with_overflow() { + loom::model(move || { + // It is essential to have multiple senders and receivers in this test case. + let (tx, mut rx) = broadcast::channel(1); + let _rx2 = tx.subscribe(); + + let _ = tx.send(()); + let tx2 = tx.clone(); + let th1 = thread::spawn(move || { + block_on(async { + for _ in 0..100 { + let _ = tx2.send(()); + } + }); + }); + let _ = tx.send(()); + + let th2 = thread::spawn(move || { + block_on(async { while let Ok(_) = rx.recv().await {} }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_list.rs b/third_party/rust/tokio/src/sync/tests/loom_list.rs new file mode 100644 index 0000000000..4067f865ce --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_list.rs @@ -0,0 +1,48 @@ +use crate::sync::mpsc::list; + +use loom::thread; +use std::sync::Arc; + +#[test] +fn smoke() { + use crate::sync::mpsc::block::Read::*; + + const NUM_TX: usize = 2; + const NUM_MSG: usize = 2; + + loom::model(|| { + let (tx, mut rx) = list::channel(); + let tx = Arc::new(tx); + + for th in 0..NUM_TX { + let tx = tx.clone(); + + thread::spawn(move || { + for i in 0..NUM_MSG { + tx.push((th, i)); + } + }); + } + + let mut next = vec![0; NUM_TX]; + + loop { + match rx.pop(&tx) { + Some(Value((th, v))) => { + assert_eq!(v, next[th]); + next[th] += 1; + + if next.iter().all(|&i| i == NUM_MSG) { + break; + } + } + Some(Closed) => { + panic!(); + } + None => { + thread::yield_now(); + } + } + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs new file mode 100644 index 0000000000..f165e7076e --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_mpsc.rs @@ -0,0 +1,190 @@ +use crate::sync::mpsc; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; +use tokio_test::assert_ok; + +#[test] +fn closing_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::channel(16); + + thread::spawn(move || { + tx.try_send(()).unwrap(); + drop(tx); + }); + + let v = block_on(rx.recv()); + assert!(v.is_some()); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn closing_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel(); + + thread::spawn(move || { + tx.send(()).unwrap(); + drop(tx); + }); + + let v = block_on(rx.recv()); + assert!(v.is_some()); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn closing_bounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::channel::<()>(16); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn closing_and_sending() { + loom::model(|| { + let (tx1, mut rx) = mpsc::channel::<()>(16); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + tx1.try_send(()).unwrap(); + }); + + let th2 = thread::spawn(move || { + block_on(tx2.closed()); + }); + + let th3 = thread::spawn(move || { + let v = block_on(rx.recv()); + assert!(v.is_some()); + drop(rx); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn closing_unbounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::unbounded_channel::<()>(); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn dropping_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::channel::<()>(16); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn dropping_unbounded_tx() { + loom::model(|| { + let (tx, mut rx) = mpsc::unbounded_channel::<()>(); + + for _ in 0..2 { + let tx = tx.clone(); + thread::spawn(move || { + drop(tx); + }); + } + drop(tx); + + let v = block_on(rx.recv()); + assert!(v.is_none()); + }); +} + +#[test] +fn try_recv() { + loom::model(|| { + use crate::sync::{mpsc, Semaphore}; + use loom::sync::{Arc, Mutex}; + + const PERMITS: usize = 2; + const TASKS: usize = 2; + const CYCLES: usize = 1; + + struct Context { + sem: Arc<Semaphore>, + tx: mpsc::Sender<()>, + rx: Mutex<mpsc::Receiver<()>>, + } + + fn run(ctx: &Context) { + block_on(async { + let permit = ctx.sem.acquire().await; + assert_ok!(ctx.rx.lock().unwrap().try_recv()); + crate::task::yield_now().await; + assert_ok!(ctx.tx.clone().try_send(())); + drop(permit); + }); + } + + let (tx, rx) = mpsc::channel(PERMITS); + let sem = Arc::new(Semaphore::new(PERMITS)); + let ctx = Arc::new(Context { + sem, + tx, + rx: Mutex::new(rx), + }); + + for _ in 0..PERMITS { + assert_ok!(ctx.tx.clone().try_send(())); + } + + let mut ths = Vec::new(); + + for _ in 0..TASKS { + let ctx = ctx.clone(); + + ths.push(thread::spawn(move || { + run(&ctx); + })); + } + + run(&ctx); + + for th in ths { + th.join().unwrap(); + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_notify.rs b/third_party/rust/tokio/src/sync/tests/loom_notify.rs new file mode 100644 index 0000000000..d484a75817 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_notify.rs @@ -0,0 +1,140 @@ +use crate::sync::Notify; + +use loom::future::block_on; +use loom::sync::Arc; +use loom::thread; + +#[test] +fn notify_one() { + loom::model(|| { + let tx = Arc::new(Notify::new()); + let rx = tx.clone(); + + let th = thread::spawn(move || { + block_on(async { + rx.notified().await; + }); + }); + + tx.notify_one(); + th.join().unwrap(); + }); +} + +#[test] +fn notify_waiters() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx = notify.clone(); + let notified1 = notify.notified(); + let notified2 = notify.notified(); + + let th = thread::spawn(move || { + tx.notify_waiters(); + }); + + block_on(async { + notified1.await; + notified2.await; + }); + + th.join().unwrap(); + }); +} + +#[test] +fn notify_waiters_and_one() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx1 = notify.clone(); + let tx2 = notify.clone(); + + let th1 = thread::spawn(move || { + tx1.notify_waiters(); + }); + + let th2 = thread::spawn(move || { + tx2.notify_one(); + }); + + let th3 = thread::spawn(move || { + let notified = notify.notified(); + + block_on(async { + notified.await; + }); + }); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); + }); +} + +#[test] +fn notify_multi() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + + let mut ths = vec![]; + + for _ in 0..2 { + let notify = notify.clone(); + + ths.push(thread::spawn(move || { + block_on(async { + notify.notified().await; + notify.notify_one(); + }) + })); + } + + notify.notify_one(); + + for th in ths.drain(..) { + th.join().unwrap(); + } + + block_on(async { + notify.notified().await; + }); + }); +} + +#[test] +fn notify_drop() { + use crate::future::poll_fn; + use std::future::Future; + use std::task::Poll; + + loom::model(|| { + let notify = Arc::new(Notify::new()); + let rx1 = notify.clone(); + let rx2 = notify.clone(); + + let th1 = thread::spawn(move || { + let mut recv = Box::pin(rx1.notified()); + + block_on(poll_fn(|cx| { + if recv.as_mut().poll(cx).is_ready() { + rx1.notify_one(); + } + Poll::Ready(()) + })); + }); + + let th2 = thread::spawn(move || { + block_on(async { + rx2.notified().await; + // Trigger second notification + rx2.notify_one(); + rx2.notified().await; + }); + }); + + notify.notify_one(); + + th1.join().unwrap(); + th2.join().unwrap(); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs new file mode 100644 index 0000000000..c5f7972079 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_oneshot.rs @@ -0,0 +1,140 @@ +use crate::sync::oneshot; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::thread; +use std::task::Poll::{Pending, Ready}; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + }); +} + +#[test] +fn changing_rx_task() { + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + tx.send(1).unwrap(); + }); + + let rx = thread::spawn(move || { + let ready = block_on(poll_fn(|cx| match Pin::new(&mut rx).poll(cx) { + Ready(Ok(value)) => { + assert_eq!(1, value); + Ready(true) + } + Ready(Err(_)) => unimplemented!(), + Pending => Ready(false), + })); + + if ready { + None + } else { + Some(rx) + } + }) + .join() + .unwrap(); + + if let Some(rx) = rx { + // Previous task parked, use a new task... + let value = block_on(rx).unwrap(); + assert_eq!(1, value); + } + }); +} + +#[test] +fn try_recv_close() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + thread::spawn(move || { + let _ = tx.send(()); + }); + + rx.close(); + let _ = rx.try_recv(); + }) +} + +#[test] +fn recv_closed() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + let _ = tx.send(1); + }); + + rx.close(); + let _ = block_on(rx); + }); +} + +// TODO: Move this into `oneshot` proper. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +struct OnClose<'a> { + tx: &'a mut oneshot::Sender<i32>, +} + +impl<'a> OnClose<'a> { + fn new(tx: &'a mut oneshot::Sender<i32>) -> Self { + OnClose { tx } + } +} + +impl Future for OnClose<'_> { + type Output = bool; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> { + let fut = self.get_mut().tx.closed(); + crate::pin!(fut); + + Ready(fut.poll(cx).is_ready()) + } +} + +#[test] +fn changing_tx_task() { + loom::model(|| { + let (mut tx, rx) = oneshot::channel::<i32>(); + + thread::spawn(move || { + drop(rx); + }); + + let tx = thread::spawn(move || { + let t1 = block_on(OnClose::new(&mut tx)); + + if t1 { + None + } else { + Some(tx) + } + }) + .join() + .unwrap(); + + if let Some(mut tx) = tx { + // Previous task parked, use a new task... + block_on(OnClose::new(&mut tx)); + } + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs new file mode 100644 index 0000000000..4b5cc7edc6 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_rwlock.rs @@ -0,0 +1,105 @@ +use crate::sync::rwlock::*; + +use loom::future::block_on; +use loom::thread; +use std::sync::Arc; + +#[test] +fn concurrent_write() { + let b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write_owned().await; + *guard += 5; + }); + }); + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + //when all threads have finished the value on the lock should be 10 + let guard = block_on(rwlock.read()); + assert_eq!(10, *guard); + }); +} + +#[test] +fn concurrent_read_write() { + let b = loom::model::Builder::new(); + + b.check(|| { + let rwlock = Arc::new(RwLock::<u32>::new(0)); + + let rwclone = rwlock.clone(); + let t1 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t2 = thread::spawn(move || { + block_on(async { + let mut guard = rwclone.write_owned().await; + *guard += 5; + }); + }); + + let rwclone = rwlock.clone(); + let t3 = thread::spawn(move || { + block_on(async { + let guard = rwclone.read().await; + //at this state the value on the lock may either be 0, 5, or 10 + assert!(*guard == 0 || *guard == 5 || *guard == 10); + }); + }); + + { + let guard = block_on(rwlock.clone().read_owned()); + //at this state the value on the lock may either be 0, 5, or 10 + assert!(*guard == 0 || *guard == 5 || *guard == 10); + } + + t1.join().expect("thread 1 write should not panic"); + t2.join().expect("thread 2 write should not panic"); + t3.join().expect("thread 3 read should not panic"); + + let guard = block_on(rwlock.read()); + //when all threads have finished the value on the lock should be 10 + assert_eq!(10, *guard); + }); +} +#[test] +fn downgrade() { + loom::model(|| { + let lock = Arc::new(RwLock::new(1)); + + let n = block_on(lock.write()); + + let cloned_lock = lock.clone(); + let handle = thread::spawn(move || { + let mut guard = block_on(cloned_lock.write()); + *guard = 2; + }); + + let n = n.downgrade(); + assert_eq!(*n, 1); + + drop(n); + handle.join().unwrap(); + assert_eq!(*block_on(lock.read()), 2); + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs new file mode 100644 index 0000000000..76a1bc0062 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -0,0 +1,215 @@ +use crate::sync::batch_semaphore::*; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::task::Poll::Ready; +use std::task::{Context, Poll}; + +#[test] +fn basic_usage() { + const NUM: usize = 2; + + struct Shared { + semaphore: Semaphore, + active: AtomicUsize, + } + + async fn actor(shared: Arc<Shared>) { + shared.semaphore.acquire(1).await.unwrap(); + let actual = shared.active.fetch_add(1, SeqCst); + assert!(actual <= NUM - 1); + + let actual = shared.active.fetch_sub(1, SeqCst); + assert!(actual <= NUM); + shared.semaphore.release(1); + } + + loom::model(|| { + let shared = Arc::new(Shared { + semaphore: Semaphore::new(NUM), + active: AtomicUsize::new(0), + }); + + for _ in 0..NUM { + let shared = shared.clone(); + + thread::spawn(move || { + block_on(actor(shared)); + }); + } + + block_on(actor(shared)); + }); +} + +#[test] +fn release() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + { + let semaphore = semaphore.clone(); + thread::spawn(move || { + block_on(semaphore.acquire(1)).unwrap(); + semaphore.release(1); + }); + } + + block_on(semaphore.acquire(1)).unwrap(); + + semaphore.release(1); + }); +} + +#[test] +fn basic_closing() { + const NUM: usize = 2; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + for _ in 0..2 { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + + semaphore.release(1); + } + + Ok::<(), ()>(()) + }); + } + + semaphore.close(); + }); +} + +#[test] +fn concurrent_close() { + const NUM: usize = 3; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + semaphore.release(1); + semaphore.close(); + + Ok::<(), ()>(()) + }); + } + }); +} + +#[test] +fn concurrent_cancel() { + async fn poll_and_cancel(semaphore: Arc<Semaphore>) { + let mut acquire1 = Some(semaphore.acquire(1)); + let mut acquire2 = Some(semaphore.acquire(1)); + poll_fn(|cx| { + // poll the acquire future once, and then immediately throw + // it away. this simulates a situation where a future is + // polled and then cancelled, such as by a timeout. + if let Some(acquire) = acquire1.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + if let Some(acquire) = acquire2.take() { + pin!(acquire); + let _ = acquire.poll(cx); + } + Poll::Ready(()) + }) + .await + } + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(0)); + let t1 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t2 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + let t3 = { + let semaphore = semaphore.clone(); + thread::spawn(move || block_on(poll_and_cancel(semaphore))) + }; + + t1.join().unwrap(); + semaphore.release(10); + t2.join().unwrap(); + t3.join().unwrap(); + }); +} + +#[test] +fn batch() { + let mut b = loom::model::Builder::new(); + b.preemption_bound = Some(1); + + b.check(|| { + let semaphore = Arc::new(Semaphore::new(10)); + let active = Arc::new(AtomicUsize::new(0)); + let mut ths = vec![]; + + for _ in 0..2 { + let semaphore = semaphore.clone(); + let active = active.clone(); + + ths.push(thread::spawn(move || { + for n in &[4, 10, 8] { + block_on(semaphore.acquire(*n)).unwrap(); + + active.fetch_add(*n as usize, SeqCst); + + let num_active = active.load(SeqCst); + assert!(num_active <= 10); + + thread::yield_now(); + + active.fetch_sub(*n as usize, SeqCst); + + semaphore.release(*n as usize); + } + })); + } + + for th in ths.into_iter() { + th.join().unwrap(); + } + + assert_eq!(10, semaphore.available_permits()); + }); +} + +#[test] +fn release_during_acquire() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(10)); + semaphore + .try_acquire(8) + .expect("try_acquire should succeed; semaphore uncontended"); + let semaphore2 = semaphore.clone(); + let thread = thread::spawn(move || block_on(semaphore2.acquire(4)).unwrap()); + + semaphore.release(8); + thread.join().unwrap(); + semaphore.release(4); + assert_eq!(10, semaphore.available_permits()); + }) +} diff --git a/third_party/rust/tokio/src/sync/tests/loom_watch.rs b/third_party/rust/tokio/src/sync/tests/loom_watch.rs new file mode 100644 index 0000000000..c575b5b66c --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/loom_watch.rs @@ -0,0 +1,36 @@ +use crate::sync::watch; + +use loom::future::block_on; +use loom::thread; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, mut rx1) = watch::channel(1); + let mut rx2 = rx1.clone(); + let mut rx3 = rx1.clone(); + let mut rx4 = rx1.clone(); + let mut rx5 = rx1.clone(); + + let th = thread::spawn(move || { + tx.send(2).unwrap(); + }); + + block_on(rx1.changed()).unwrap(); + assert_eq!(*rx1.borrow(), 2); + + block_on(rx2.changed()).unwrap(); + assert_eq!(*rx2.borrow(), 2); + + block_on(rx3.changed()).unwrap(); + assert_eq!(*rx3.borrow(), 2); + + block_on(rx4.changed()).unwrap(); + assert_eq!(*rx4.borrow(), 2); + + block_on(rx5.changed()).unwrap(); + assert_eq!(*rx5.borrow(), 2); + + th.join().unwrap(); + }) +} diff --git a/third_party/rust/tokio/src/sync/tests/mod.rs b/third_party/rust/tokio/src/sync/tests/mod.rs new file mode 100644 index 0000000000..ee76418ac5 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/mod.rs @@ -0,0 +1,17 @@ +cfg_not_loom! { + mod atomic_waker; + mod notify; + mod semaphore_batch; +} + +cfg_loom! { + mod loom_atomic_waker; + mod loom_broadcast; + mod loom_list; + mod loom_mpsc; + mod loom_notify; + mod loom_oneshot; + mod loom_semaphore_batch; + mod loom_watch; + mod loom_rwlock; +} diff --git a/third_party/rust/tokio/src/sync/tests/notify.rs b/third_party/rust/tokio/src/sync/tests/notify.rs new file mode 100644 index 0000000000..20153b7a5a --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/notify.rs @@ -0,0 +1,81 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::<Notify>::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. + let _ = future.poll(&mut cx); +} + +#[test] +fn notify_simple() { + let notify = Notify::new(); + + let mut fut1 = tokio_test::task::spawn(notify.notified()); + assert!(fut1.poll().is_pending()); + + let mut fut2 = tokio_test::task::spawn(notify.notified()); + assert!(fut2.poll().is_pending()); + + notify.notify_waiters(); + + assert!(fut1.poll().is_ready()); + assert!(fut2.poll().is_ready()); +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] +fn watch_test() { + let rt = crate::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + rt.block_on(async { + let (tx, mut rx) = crate::sync::watch::channel(()); + + crate::spawn(async move { + let _ = tx.send(()); + }); + + let _ = rx.changed().await; + }); +} diff --git a/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs new file mode 100644 index 0000000000..d529a9e886 --- /dev/null +++ b/third_party/rust/tokio/src/sync/tests/semaphore_batch.rs @@ -0,0 +1,254 @@ +use crate::sync::batch_semaphore::Semaphore; +use tokio_test::*; + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_test::wasm_bindgen_test as test; + +#[test] +fn poll_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 99); +} + +#[test] +fn poll_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 95); + + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn try_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 99); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 98); +} + +#[test] +fn try_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 95); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn poll_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 0); + + let mut acquire_2 = task::spawn(s.acquire(1)); + // Try to acquire the second permit + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn poll_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 4); + + // Try to acquire the second permit + let mut acquire_2 = task::spawn(s.acquire(5)); + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the third permit + let mut acquire_3 = task::spawn(s.acquire(3)); + assert_pending!(acquire_3.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(2); + assert!(acquire_3.is_woken()); + + assert_ready_ok!(acquire_3.poll()); +} + +#[test] +fn try_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 0); + + assert_err!(s.try_acquire(1)); + + s.release(1); + + assert_eq!(s.available_permits(), 1); + assert_ok!(s.try_acquire(1)); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn try_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 4); + + assert_err!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 5); + + assert_ok!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 1); + + s.release(1); + assert_eq!(s.available_permits(), 2); +} + +#[test] +fn poll_acquire_one_zero_permits() { + let s = Semaphore::new(0); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the permit + let mut acquire = task::spawn(s.acquire(1)); + assert_pending!(acquire.poll()); + + s.release(1); + + assert!(acquire.is_woken()); + assert_ready_ok!(acquire.poll()); +} + +#[test] +#[should_panic] +#[cfg(not(target_arch = "wasm32"))] // wasm currently doesn't support unwinding +fn validates_max_permits() { + use std::usize; + Semaphore::new((usize::MAX >> 2) + 1); +} + +#[test] +fn close_semaphore_prevents_acquire() { + let s = Semaphore::new(5); + s.close(); + + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); +} + +#[test] +fn close_semaphore_notifies_permit1() { + let s = Semaphore::new(0); + let mut acquire = task::spawn(s.acquire(1)); + + assert_pending!(acquire.poll()); + + s.close(); + + assert!(acquire.is_woken()); + assert_ready_err!(acquire.poll()); +} + +#[test] +fn close_semaphore_notifies_permit2() { + let s = Semaphore::new(2); + + // Acquire a couple of permits + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + + let mut acquire3 = task::spawn(s.acquire(1)); + let mut acquire4 = task::spawn(s.acquire(1)); + assert_pending!(acquire3.poll()); + assert_pending!(acquire4.poll()); + + s.close(); + + assert!(acquire3.is_woken()); + assert!(acquire4.is_woken()); + + assert_ready_err!(acquire3.poll()); + assert_ready_err!(acquire4.poll()); + + assert_eq!(0, s.available_permits()); + + s.release(1); + + assert_eq!(1, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + + s.release(1); + + assert_eq!(2, s.available_permits()); +} + +#[test] +fn cancel_acquire_releases_permits() { + let s = Semaphore::new(10); + s.try_acquire(4).expect("uncontended try_acquire succeeds"); + assert_eq!(6, s.available_permits()); + + let mut acquire = task::spawn(s.acquire(8)); + assert_pending!(acquire.poll()); + + assert_eq!(0, s.available_permits()); + drop(acquire); + + assert_eq!(6, s.available_permits()); + assert_ok!(s.try_acquire(6)); +} diff --git a/third_party/rust/tokio/src/sync/watch.rs b/third_party/rust/tokio/src/sync/watch.rs new file mode 100644 index 0000000000..5673e0fca7 --- /dev/null +++ b/third_party/rust/tokio/src/sync/watch.rs @@ -0,0 +1,834 @@ +#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] + +//! A single-producer, multi-consumer channel that only retains the *last* sent +//! value. +//! +//! This channel is useful for watching for changes to a value from multiple +//! points in the code base, for example, changes to configuration values. +//! +//! # Usage +//! +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and sender halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to sent by the [`Sender`] half. +//! +//! # Examples +//! +//! ``` +//! use tokio::sync::watch; +//! +//! # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +//! let (tx, mut rx) = watch::channel("hello"); +//! +//! tokio::spawn(async move { +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); +//! } +//! }); +//! +//! tx.send("world")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Closing +//! +//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect +//! when all [`Receiver`] handles have been dropped. This indicates that there +//! is no further interest in the values being produced and work can be stopped. +//! +//! # Thread safety +//! +//! Both [`Sender`] and [`Receiver`] are thread safe. They can be moved to other +//! threads and can be used in a concurrent environment. Clones of [`Receiver`] +//! handles may be moved to separate threads and also used concurrently. +//! +//! [`Sender`]: crate::sync::watch::Sender +//! [`Receiver`]: crate::sync::watch::Receiver +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow +//! [`channel`]: crate::sync::watch::channel +//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed +//! [`Sender::closed`]: crate::sync::watch::Sender::closed + +use crate::sync::notify::Notify; + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::atomic::Ordering::Relaxed; +use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use std::mem; +use std::ops; + +/// Receives values from the associated [`Sender`](struct@Sender). +/// +/// Instances are created by the [`channel`](fn@channel) function. +/// +/// To turn this receiver into a `Stream`, you can use the [`WatchStream`] +/// wrapper. +/// +/// [`WatchStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.WatchStream.html +#[derive(Debug)] +pub struct Receiver<T> { + /// Pointer to the shared state + shared: Arc<Shared<T>>, + + /// Last observed version + version: Version, +} + +/// Sends values to the associated [`Receiver`](struct@Receiver). +/// +/// Instances are created by the [`channel`](fn@channel) function. +#[derive(Debug)] +pub struct Sender<T> { + shared: Arc<Shared<T>>, +} + +/// Returns a reference to the inner value. +/// +/// Outstanding borrows hold a read lock on the inner value. This means that +/// long lived borrows could cause the produce half to block. It is recommended +/// to keep the borrow as short lived as possible. +/// +/// The priority policy of the lock is dependent on the underlying lock +/// implementation, and this type does not guarantee that any particular policy +/// will be used. In particular, a producer which is waiting to acquire the lock +/// in `send` might or might not block concurrent calls to `borrow`, e.g.: +/// +/// <details><summary>Potential deadlock example</summary> +/// +/// ```text +/// // Task 1 (on thread A) | // Task 2 (on thread B) +/// let _ref1 = rx.borrow(); | +/// | // will block +/// | let _ = tx.send(()); +/// // may deadlock | +/// let _ref2 = rx.borrow(); | +/// ``` +/// </details> +#[derive(Debug)] +pub struct Ref<'a, T> { + inner: RwLockReadGuard<'a, T>, +} + +#[derive(Debug)] +struct Shared<T> { + /// The most recent value. + value: RwLock<T>, + + /// The current version. + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + state: AtomicState, + + /// Tracks the number of `Receiver` instances. + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events. + notify_tx: Notify, +} + +pub mod error { + //! Watch error types. + + use std::fmt; + + /// Error produced when sending a value fails. + #[derive(Debug)] + pub struct SendError<T>(pub T); + + // ===== impl SendError ===== + + impl<T: fmt::Debug> fmt::Display for SendError<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// Error produced when receiving a change notification. + #[derive(Debug)] + pub struct RecvError(pub(super) ()); + + // ===== impl RecvError ===== + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl std::error::Error for RecvError {} +} + +use self::state::{AtomicState, Version}; +mod state { + use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::Ordering::SeqCst; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. + #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, SeqCst); + } + + /// Set the closed bit in the state. + pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, SeqCst); + } + } +} + +/// Creates a new watch channel, returning the "send" and "receive" handles. +/// +/// All values sent by [`Sender`] will become visible to the [`Receiver`] handles. +/// Only the last value sent is made available to the [`Receiver`] half. All +/// intermediate values are dropped. +/// +/// # Examples +/// +/// ``` +/// use tokio::sync::watch; +/// +/// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { +/// let (tx, mut rx) = watch::channel("hello"); +/// +/// tokio::spawn(async move { +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); +/// } +/// }); +/// +/// tx.send("world")?; +/// # Ok(()) +/// # } +/// ``` +/// +/// [`Sender`]: struct@Sender +/// [`Receiver`]: struct@Receiver +pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { + let shared = Arc::new(Shared { + value: RwLock::new(init), + state: AtomicState::new(), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), + }); + + let tx = Sender { + shared: shared.clone(), + }; + + let rx = Receiver { + shared, + version: Version::initial(), + }; + + (tx, rx) +} + +impl<T> Receiver<T> { + fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self { + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Relaxed); + + Self { shared, version } + } + + /// Returns a reference to the most recently sent value. + /// + /// This method does not mark the returned value as seen, so future calls to + /// [`changed`] may return immediately even if you have already seen the + /// value with a call to `borrow`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx.borrow(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx.borrow(); | + /// ``` + /// </details> + /// + /// [`changed`]: Receiver::changed + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (_, rx) = watch::channel("hello"); + /// assert_eq!(*rx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + /// Returns a reference to the most recently sent value and mark that value + /// as seen. + /// + /// This method marks the value as seen, so [`changed`] will not return + /// immediately if the newest value is one previously returned by + /// `borrow_and_update`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// The priority policy of the lock is dependent on the underlying lock + /// implementation, and this type does not guarantee that any particular policy + /// will be used. In particular, a producer which is waiting to acquire the lock + /// in `send` might or might not block concurrent calls to `borrow`, e.g.: + /// + /// <details><summary>Potential deadlock example</summary> + /// + /// ```text + /// // Task 1 (on thread A) | // Task 2 (on thread B) + /// let _ref1 = rx1.borrow_and_update(); | + /// | // will block + /// | let _ = tx.send(()); + /// // may deadlock | + /// let _ref2 = rx2.borrow_and_update(); | + /// ``` + /// </details> + /// + /// [`changed`]: Receiver::changed + pub fn borrow_and_update(&mut self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + self.version = self.shared.state.load().version(); + Ref { inner } + } + + /// Checks if this channel contains a message that this receiver has not yet + /// seen. The new value is not marked as seen. + /// + /// Although this method is called `has_changed`, it does not check new + /// messages for equality, so this call will return true even if the new + /// message is equal to the old message. + /// + /// Returns an error if the channel has been closed. + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// assert!(rx.has_changed().unwrap()); + /// assert_eq!(*rx.borrow_and_update(), "goodbye"); + /// + /// // The value has been marked as seen + /// assert!(!rx.has_changed().unwrap()); + /// + /// drop(tx); + /// // The `tx` handle has been dropped + /// assert!(rx.has_changed().is_err()); + /// } + /// ``` + pub fn has_changed(&self) -> Result<bool, error::RecvError> { + // Load the version from the state + let state = self.shared.state.load(); + if state.is_closed() { + // The sender has dropped. + return Err(error::RecvError(())); + } + let new_version = state.version(); + + Ok(self.version != new_version) + } + + /// Waits for a change notification, then marks the newest value as seen. + /// + /// If the newest value in the channel has not yet been marked seen when + /// this method is called, the method marks that value seen and returns + /// immediately. If the newest value has already been marked seen, then the + /// method sleeps until a new message is sent by the [`Sender`] connected to + /// this `Receiver`, or until the [`Sender`] is dropped. + /// + /// This method returns an error if and only if the [`Sender`] is dropped. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no values have been marked + /// seen by this call to `changed`. + /// + /// [`Sender`]: struct@Sender + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// tx.send("goodbye").unwrap(); + /// }); + /// + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); + /// + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); + /// } + /// ``` + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = self.shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious + } + } + + cfg_process_driver! { + pub(crate) fn try_has_changed(&mut self) -> Option<Result<(), error::RecvError>> { + maybe_changed(&self.shared, &mut self.version) + } + } +} + +fn maybe_changed<T>( + shared: &Shared<T>, + version: &mut Version, +) -> Option<Result<(), error::RecvError>> { + // Load the version from the state + let state = shared.state.load(); + let new_version = state.version(); + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } + + if state.is_closed() { + // All receivers have dropped. + return Some(Err(error::RecvError(()))); + } + + None +} + +impl<T> Clone for Receiver<T> { + fn clone(&self) -> Self { + let version = self.version; + let shared = self.shared.clone(); + + Self::from_shared(version, shared) + } +} + +impl<T> Drop for Receiver<T> { + fn drop(&mut self) { + // No synchronization necessary as this is only used as a counter and + // not memory access. + if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } + } +} + +impl<T> Sender<T> { + /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel has been closed, which happens when + /// every receiver has been dropped. + pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.receiver_count() { + return Err(error::SendError(value)); + } + + self.send_replace(value); + Ok(()) + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, value: T) -> T { + let old = { + // Acquire the write lock and update the value. + let mut lock = self.shared.value.write().unwrap(); + let old = mem::replace(&mut *lock, value); + + self.shared.state.increment_version(); + + // Release the write lock. + // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + + old + }; + + // Notify all watchers + self.shared.notify_rx.notify_waiters(); + + old + } + + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _) = watch::channel("hello"); + /// assert_eq!(*tx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + + /// Checks if the channel has been closed. This happens when all receivers + /// have dropped. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(()); + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.receiver_count() == 0 + } + + /// Completes when all receivers have dropped. + /// + /// This allows the producer to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); + + if self.receiver_count() == 0 { + return; + } + + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } + } + + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. + /// + /// ``` + /// use tokio::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + let version = shared.state.load().version(); + + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) + } + + /// Returns the number of receivers that currently exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx1) = watch::channel("hello"); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = rx1.clone(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + self.shared.ref_count_rx.load(Relaxed) + } +} + +impl<T> Drop for Sender<T> { + fn drop(&mut self) { + self.shared.state.set_closed(); + self.shared.notify_rx.notify_waiters(); + } +} + +// ===== impl Ref ===== + +impl<T> ops::Deref for Ref<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.inner.deref() + } +} + +#[cfg(all(test, loom))] +mod tests { + use futures::future::FutureExt; + use loom::thread; + + // test for https://github.com/tokio-rs/tokio/issues/3168 + #[test] + fn watch_spurious_wakeup() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + send.send(1).unwrap(); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let mut recv = recv_thread.join().unwrap(); + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + }); + + recv.changed().now_or_never(); + + send_thread.join().unwrap(); + }); + } + + #[test] + fn watch_borrow() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + assert!(send.borrow().eq(&0)); + assert!(recv.borrow().eq(&0)); + + send.send(1).unwrap(); + assert!(send.borrow().eq(&1)); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let recv = recv_thread.join().unwrap(); + assert!(recv.borrow().eq(&3)); + assert!(send.borrow().eq(&3)); + + send.send(2).unwrap(); + + thread::spawn(move || { + assert!(recv.borrow().eq(&2)); + }); + assert!(send.borrow().eq(&2)); + }); + } +} diff --git a/third_party/rust/tokio/src/task/blocking.rs b/third_party/rust/tokio/src/task/blocking.rs new file mode 100644 index 0000000000..5fe358f3e5 --- /dev/null +++ b/third_party/rust/tokio/src/task/blocking.rs @@ -0,0 +1,199 @@ +use crate::task::JoinHandle; + +cfg_rt_multi_thread! { + /// Runs the provided blocking function on the current thread without + /// blocking the executor. + /// + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is problematic, as it may prevent the executor + /// from driving other tasks forward. Calling this function informs the + /// executor that the currently executing task is about to block the thread, + /// so the executor is able to hand off any other tasks it has to a new + /// worker thread before that happens. See the [CPU-bound tasks and blocking + /// code][blocking] section for more information. + /// + /// Be aware that although this function avoids starving other independently + /// spawned tasks, any other code running concurrently in the same task will + /// be suspended during the call to `block_in_place`. This can happen e.g. + /// when using the [`join!`] macro. To avoid this issue, use + /// [`spawn_blocking`] instead of `block_in_place`. + /// + /// Note that this function cannot be used within a [`current_thread`] runtime + /// because in this case there are no other worker threads to hand off tasks + /// to. On the other hand, calling the function outside a runtime is + /// allowed. In this case, `block_in_place` just calls the provided closure + /// normally. + /// + /// Code running behind `block_in_place` cannot be cancelled. When you shut + /// down the executor, it will wait indefinitely for all blocking operations + /// to finish. You can use [`shutdown_timeout`] to stop waiting for them + /// after a certain timeout. Be aware that this will still not cancel the + /// tasks — they are simply allowed to keep running after the method + /// returns. + /// + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking + /// [`join!`]: macro@join + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// + /// # Examples + /// + /// ``` + /// use tokio::task; + /// + /// # async fn docs() { + /// task::block_in_place(move || { + /// // do some compute-heavy work or call synchronous code + /// }); + /// # } + /// ``` + /// + /// Code running inside `block_in_place` may use `block_on` to reenter the + /// async context. + /// + /// ``` + /// use tokio::task; + /// use tokio::runtime::Handle; + /// + /// # async fn docs() { + /// task::block_in_place(move || { + /// Handle::current().block_on(async move { + /// // do something async + /// }); + /// }); + /// # } + /// ``` + /// + /// # Panics + /// + /// This function panics if called from a [`current_thread`] runtime. + /// + /// [`current_thread`]: fn@crate::runtime::Builder::new_current_thread + pub fn block_in_place<F, R>(f: F) -> R + where + F: FnOnce() -> R, + { + crate::runtime::thread_pool::block_in_place(f) + } +} + +cfg_rt! { + /// Runs the provided closure on a thread where blocking is acceptable. + /// + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is problematic, as it may prevent the executor from + /// driving other futures forward. This function runs the provided closure on a + /// thread dedicated to blocking operations. See the [CPU-bound tasks and + /// blocking code][blocking] section for more information. + /// + /// Tokio will spawn more blocking threads when they are requested through this + /// function until the upper limit configured on the [`Builder`] is reached. + /// After reaching the upper limit, the tasks are put in a queue. + /// The thread limit is very large by default, because `spawn_blocking` is often + /// used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind. When running many CPU-bound + /// computations, a semaphore or some other synchronization primitive should be + /// used to limit the number of computation executed in parallel. Specialized + /// CPU-bound executors, such as [rayon], may also be a good fit. + /// + /// This function is intended for non-async operations that eventually finish on + /// their own. If you want to spawn an ordinary thread, you should use + /// [`thread::spawn`] instead. + /// + /// Closures spawned using `spawn_blocking` cannot be cancelled. When you shut + /// down the executor, it will wait indefinitely for all blocking operations to + /// finish. You can use [`shutdown_timeout`] to stop waiting for them after a + /// certain timeout. Be aware that this will still not cancel the tasks — they + /// are simply allowed to keep running after the method returns. + /// + /// Note that if you are using the single threaded runtime, this function will + /// still spawn additional threads for blocking operations. The basic + /// scheduler's single thread is only used for asynchronous code. + /// + /// # Related APIs and patterns for bridging asynchronous and blocking code + /// + /// In simple cases, it is sufficient to have the closure accept input + /// parameters at creation time and return a single value (or struct/tuple, etc.). + /// + /// For more complex situations in which it is desirable to stream data to or from + /// the synchronous context, the [`mpsc channel`] has `blocking_send` and + /// `blocking_recv` methods for use in non-async code such as the thread created + /// by `spawn_blocking`. + /// + /// Another option is [`SyncIoBridge`] for cases where the synchronous context + /// is operating on byte streams. For example, you might use an asynchronous + /// HTTP client such as [hyper] to fetch data, but perform complex parsing + /// of the payload body using a library written for synchronous I/O. + /// + /// Finally, see also [Bridging with sync code][bridgesync] for discussions + /// around the opposite case of using Tokio as part of a larger synchronous + /// codebase. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [rayon]: https://docs.rs/rayon + /// [`mpsc channel`]: crate::sync::mpsc + /// [`SyncIoBridge`]: https://docs.rs/tokio-util/0.6/tokio_util/io/struct.SyncIoBridge.html + /// [hyper]: https://docs.rs/hyper + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// [bridgesync]: https://tokio.rs/tokio/topics/bridging + /// + /// # Examples + /// + /// Pass an input value and receive result of computation: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ + /// // Initial input + /// let mut v = "Hello, ".to_string(); + /// let res = task::spawn_blocking(move || { + /// // Stand-in for compute-heavy work or using synchronous APIs + /// v.push_str("world"); + /// // Pass ownership of the value back to the asynchronous context + /// v + /// }).await?; + /// + /// // `res` is the value returned from the thread + /// assert_eq!(res.as_str(), "Hello, world"); + /// # Ok(()) + /// # } + /// ``` + /// + /// Use a channel: + /// + /// ``` + /// use tokio::task; + /// use tokio::sync::mpsc; + /// + /// # async fn docs() { + /// let (tx, mut rx) = mpsc::channel(2); + /// let start = 5; + /// let worker = task::spawn_blocking(move || { + /// for x in 0..10 { + /// // Stand in for complex computation + /// tx.blocking_send(start + x).unwrap(); + /// } + /// }); + /// + /// let mut acc = 0; + /// while let Some(v) = rx.recv().await { + /// acc += v; + /// } + /// assert_eq!(acc, 95); + /// worker.await.unwrap(); + /// # } + /// ``` + #[track_caller] + pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + crate::runtime::spawn_blocking(f) + } +} diff --git a/third_party/rust/tokio/src/task/builder.rs b/third_party/rust/tokio/src/task/builder.rs new file mode 100644 index 0000000000..2086302fb9 --- /dev/null +++ b/third_party/rust/tokio/src/task/builder.rs @@ -0,0 +1,115 @@ +#![allow(unreachable_pub)] +use crate::{runtime::context, task::JoinHandle}; +use std::future::Future; + +/// Factory which is used to configure the properties of a new task. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// Methods can be chained in order to configure it. +/// +/// Currently, there is only one configuration option: +/// +/// - [`name`], which specifies an associated name for +/// the task +/// +/// There are three types of task that can be spawned from a Builder: +/// - [`spawn_local`] for executing futures on the current thread +/// - [`spawn`] for executing [`Send`] futures on the runtime +/// - [`spawn_blocking`] for executing blocking code in the +/// blocking thread pool. +/// +/// ## Example +/// +/// ```no_run +/// use tokio::net::{TcpListener, TcpStream}; +/// +/// use std::io; +/// +/// async fn process(socket: TcpStream) { +/// // ... +/// # drop(socket); +/// } +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let listener = TcpListener::bind("127.0.0.1:8080").await?; +/// +/// loop { +/// let (socket, _) = listener.accept().await?; +/// +/// tokio::task::Builder::new() +/// .name("tcp connection handler") +/// .spawn(async move { +/// // Process each socket concurrently. +/// process(socket).await +/// }); +/// } +/// } +/// ``` +/// [unstable]: crate#unstable-features +/// [`name`]: Builder::name +/// [`spawn_local`]: Builder::spawn_local +/// [`spawn`]: Builder::spawn +/// [`spawn_blocking`]: Builder::spawn_blocking +#[derive(Default, Debug)] +#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] +pub struct Builder<'a> { + name: Option<&'a str>, +} + +impl<'a> Builder<'a> { + /// Creates a new task builder. + pub fn new() -> Self { + Self::default() + } + + /// Assigns a name to the task which will be spawned. + pub fn name(&self, name: &'a str) -> Self { + Self { name: Some(name) } + } + + /// Spawns a task on the executor. + /// + /// See [`task::spawn`](crate::task::spawn) for + /// more details. + #[track_caller] + pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + super::spawn::spawn_inner(future, self.name) + } + + /// Spawns a task on the current thread. + /// + /// See [`task::spawn_local`](crate::task::spawn_local) + /// for more details. + #[track_caller] + pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + 'static, + Fut::Output: 'static, + { + super::local::spawn_local_inner(future, self.name) + } + + /// Spawns blocking code on the blocking threadpool. + /// + /// See [`task::spawn_blocking`](crate::task::spawn_blocking) + /// for more details. + #[track_caller] + pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output> + where + Function: FnOnce() -> Output + Send + 'static, + Output: Send + 'static, + { + use crate::runtime::Mandatory; + let (join_handle, _was_spawned) = + context::current().spawn_blocking_inner(function, Mandatory::NonMandatory, self.name); + join_handle + } +} diff --git a/third_party/rust/tokio/src/task/join_set.rs b/third_party/rust/tokio/src/task/join_set.rs new file mode 100644 index 0000000000..8e8f74f66d --- /dev/null +++ b/third_party/rust/tokio/src/task/join_set.rs @@ -0,0 +1,248 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Handle; +use crate::task::{JoinError, JoinHandle, LocalSet}; +use crate::util::IdleNotifiedSet; + +/// A collection of tasks spawned on a Tokio runtime. +/// +/// A `JoinSet` can be used to await the completion of some or all of the tasks +/// in the set. The set is not ordered, and the tasks will be returned in the +/// order they complete. +/// +/// All of the tasks must have the same return type `T`. +/// +/// When the `JoinSet` is dropped, all tasks in the `JoinSet` are immediately aborted. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// # Examples +/// +/// Spawn multiple tasks and wait for them. +/// +/// ``` +/// use tokio::task::JoinSet; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut set = JoinSet::new(); +/// +/// for i in 0..10 { +/// set.spawn(async move { i }); +/// } +/// +/// let mut seen = [false; 10]; +/// while let Some(res) = set.join_one().await.unwrap() { +/// seen[res] = true; +/// } +/// +/// for i in 0..10 { +/// assert!(seen[i]); +/// } +/// } +/// ``` +/// +/// [unstable]: crate#unstable-features +pub struct JoinSet<T> { + inner: IdleNotifiedSet<JoinHandle<T>>, +} + +impl<T> JoinSet<T> { + /// Create a new `JoinSet`. + pub fn new() -> Self { + Self { + inner: IdleNotifiedSet::new(), + } + } + + /// Returns the number of tasks currently in the `JoinSet`. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns whether the `JoinSet` is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl<T: 'static> JoinSet<T> { + /// Spawn the provided task on the `JoinSet`. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + pub fn spawn<F>(&mut self, task: F) + where + F: Future<Output = T>, + F: Send + 'static, + T: Send, + { + self.insert(crate::spawn(task)); + } + + /// Spawn the provided task on the provided runtime and store it in this `JoinSet`. + pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) + where + F: Future<Output = T>, + F: Send + 'static, + T: Send, + { + self.insert(handle.spawn(task)); + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn spawn_local<F>(&mut self, task: F) + where + F: Future<Output = T>, + F: 'static, + { + self.insert(crate::task::spawn_local(task)); + } + + /// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`. + /// + /// [`LocalSet`]: crate::task::LocalSet + pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) + where + F: Future<Output = T>, + F: 'static, + { + self.insert(local_set.spawn_local(task)); + } + + fn insert(&mut self, jh: JoinHandle<T>) { + let mut entry = self.inner.insert_idle(jh); + + // Set the waker that is notified when the task completes. + entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker())); + } + + /// Waits until one of the tasks in the set completes and returns its output. + /// + /// Returns `None` if the set is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinSet`. + pub async fn join_one(&mut self) -> Result<Option<T>, JoinError> { + crate::future::poll_fn(|cx| self.poll_join_one(cx)).await + } + + /// Aborts all tasks and waits for them to finish shutting down. + /// + /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_one`] in + /// a loop until it returns `Ok(None)`. + /// + /// This method ignores any panics in the tasks shutting down. When this call returns, the + /// `JoinSet` will be empty. + /// + /// [`abort_all`]: fn@Self::abort_all + /// [`join_one`]: fn@Self::join_one + pub async fn shutdown(&mut self) { + self.abort_all(); + while self.join_one().await.transpose().is_some() {} + } + + /// Aborts all tasks on this `JoinSet`. + /// + /// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete + /// cancellation, you should call `join_one` in a loop until the `JoinSet` is empty. + pub fn abort_all(&mut self) { + self.inner.for_each(|jh| jh.abort()); + } + + /// Removes all tasks from this `JoinSet` without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the `JoinSet` + /// is dropped. + pub fn detach_all(&mut self) { + self.inner.drain(drop); + } + + /// Polls for one of the tasks in the set to complete. + /// + /// If this returns `Poll::Ready(Ok(Some(_)))` or `Poll::Ready(Err(_))`, then the task that + /// completed is removed from the set. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled + /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to + /// `poll_join_one`, only the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is + /// available right now. + /// * `Poll::Ready(Ok(Some(value)))` if one of the tasks in this `JoinSet` has completed. The + /// `value` is the return value of one of the tasks that completed. + /// * `Poll::Ready(Err(err))` if one of the tasks in this `JoinSet` has panicked or been + /// aborted. + /// * `Poll::Ready(Ok(None))` if the `JoinSet` is empty. + /// + /// Note that this method may return `Poll::Pending` even if one of the tasks has completed. + /// This can happen if the [coop budget] is reached. + /// + /// [coop budget]: crate::task#cooperative-scheduling + fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<T>, JoinError>> { + // The call to `pop_notified` moves the entry to the `idle` list. It is moved back to + // the `notified` list if the waker is notified in the `poll` call below. + let mut entry = match self.inner.pop_notified(cx.waker()) { + Some(entry) => entry, + None => { + if self.is_empty() { + return Poll::Ready(Ok(None)); + } else { + // The waker was set by `pop_notified`. + return Poll::Pending; + } + } + }; + + let res = entry.with_value_and_context(|jh, ctx| Pin::new(jh).poll(ctx)); + + if let Poll::Ready(res) = res { + entry.remove(); + Poll::Ready(Some(res).transpose()) + } else { + // A JoinHandle generally won't emit a wakeup without being ready unless + // the coop limit has been reached. We yield to the executor in this + // case. + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +impl<T> Drop for JoinSet<T> { + fn drop(&mut self) { + self.inner.drain(|join_handle| join_handle.abort()); + } +} + +impl<T> fmt::Debug for JoinSet<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinSet").field("len", &self.len()).finish() + } +} + +impl<T> Default for JoinSet<T> { + fn default() -> Self { + Self::new() + } +} diff --git a/third_party/rust/tokio/src/task/local.rs b/third_party/rust/tokio/src/task/local.rs new file mode 100644 index 0000000000..2dbd970604 --- /dev/null +++ b/third_party/rust/tokio/src/task/local.rs @@ -0,0 +1,698 @@ +//! Runs `!Send` futures on the current thread. +use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; +use crate::sync::AtomicWaker; +use crate::util::VecDequeCell; + +use std::cell::Cell; +use std::collections::VecDeque; +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::Poll; + +use pin_project_lite::pin_project; + +cfg_rt! { + /// A set of tasks which are executed on the same thread. + /// + /// In some cases, it is necessary to run one or more futures that do not + /// implement [`Send`] and thus are unsafe to send between threads. In these + /// cases, a [local task set] may be used to schedule one or more `!Send` + /// futures to run together on the same thread. + /// + /// For example, the following code will not compile: + /// + /// ```rust,compile_fail + /// use std::rc::Rc; + /// + /// #[tokio::main] + /// async fn main() { + /// // `Rc` does not implement `Send`, and thus may not be sent between + /// // threads safely. + /// let unsend_data = Rc::new("my unsend data..."); + /// + /// let unsend_data = unsend_data.clone(); + /// // Because the `async` block here moves `unsend_data`, the future is `!Send`. + /// // Since `tokio::spawn` requires the spawned future to implement `Send`, this + /// // will not compile. + /// tokio::spawn(async move { + /// println!("{}", unsend_data); + /// // ... + /// }).await.unwrap(); + /// } + /// ``` + /// + /// # Use with `run_until` + /// + /// To spawn `!Send` futures, we can use a local task set to schedule them + /// on the thread calling [`Runtime::block_on`]. When running inside of the + /// local task set, we can use [`task::spawn_local`], which can spawn + /// `!Send` futures. For example: + /// + /// ```rust + /// use std::rc::Rc; + /// use tokio::task; + /// + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("my unsend data..."); + /// + /// // Construct a local task set that can run `!Send` futures. + /// let local = task::LocalSet::new(); + /// + /// // Run the local task set. + /// local.run_until(async move { + /// let unsend_data = unsend_data.clone(); + /// // `spawn_local` ensures that the future is spawned on the local + /// // task set. + /// task::spawn_local(async move { + /// println!("{}", unsend_data); + /// // ... + /// }).await.unwrap(); + /// }).await; + /// } + /// ``` + /// **Note:** The `run_until` method can only be used in `#[tokio::main]`, + /// `#[tokio::test]` or directly inside a call to [`Runtime::block_on`]. It + /// cannot be used inside a task spawned with `tokio::spawn`. + /// + /// ## Awaiting a `LocalSet` + /// + /// Additionally, a `LocalSet` itself implements `Future`, completing when + /// *all* tasks spawned on the `LocalSet` complete. This can be used to run + /// several futures on a `LocalSet` and drive the whole set until they + /// complete. For example, + /// + /// ```rust + /// use tokio::{task, time}; + /// use std::rc::Rc; + /// + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("world"); + /// let local = task::LocalSet::new(); + /// + /// let unsend_data2 = unsend_data.clone(); + /// local.spawn_local(async move { + /// // ... + /// println!("hello {}", unsend_data2) + /// }); + /// + /// local.spawn_local(async move { + /// time::sleep(time::Duration::from_millis(100)).await; + /// println!("goodbye {}", unsend_data) + /// }); + /// + /// // ... + /// + /// local.await; + /// } + /// ``` + /// **Note:** Awaiting a `LocalSet` can only be done inside + /// `#[tokio::main]`, `#[tokio::test]` or directly inside a call to + /// [`Runtime::block_on`]. It cannot be used inside a task spawned with + /// `tokio::spawn`. + /// + /// ## Use inside `tokio::spawn` + /// + /// The two methods mentioned above cannot be used inside `tokio::spawn`, so + /// to spawn `!Send` futures from inside `tokio::spawn`, we need to do + /// something else. The solution is to create the `LocalSet` somewhere else, + /// and communicate with it using an [`mpsc`] channel. + /// + /// The following example puts the `LocalSet` inside a new thread. + /// ``` + /// use tokio::runtime::Builder; + /// use tokio::sync::{mpsc, oneshot}; + /// use tokio::task::LocalSet; + /// + /// // This struct describes the task you want to spawn. Here we include + /// // some simple examples. The oneshot channel allows sending a response + /// // to the spawner. + /// #[derive(Debug)] + /// enum Task { + /// PrintNumber(u32), + /// AddOne(u32, oneshot::Sender<u32>), + /// } + /// + /// #[derive(Clone)] + /// struct LocalSpawner { + /// send: mpsc::UnboundedSender<Task>, + /// } + /// + /// impl LocalSpawner { + /// pub fn new() -> Self { + /// let (send, mut recv) = mpsc::unbounded_channel(); + /// + /// let rt = Builder::new_current_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// std::thread::spawn(move || { + /// let local = LocalSet::new(); + /// + /// local.spawn_local(async move { + /// while let Some(new_task) = recv.recv().await { + /// tokio::task::spawn_local(run_task(new_task)); + /// } + /// // If the while loop returns, then all the LocalSpawner + /// // objects have have been dropped. + /// }); + /// + /// // This will return once all senders are dropped and all + /// // spawned tasks have returned. + /// rt.block_on(local); + /// }); + /// + /// Self { + /// send, + /// } + /// } + /// + /// pub fn spawn(&self, task: Task) { + /// self.send.send(task).expect("Thread with LocalSet has shut down."); + /// } + /// } + /// + /// // This task may do !Send stuff. We use printing a number as an example, + /// // but it could be anything. + /// // + /// // The Task struct is an enum to support spawning many different kinds + /// // of operations. + /// async fn run_task(task: Task) { + /// match task { + /// Task::PrintNumber(n) => { + /// println!("{}", n); + /// }, + /// Task::AddOne(n, response) => { + /// // We ignore failures to send the response. + /// let _ = response.send(n + 1); + /// }, + /// } + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let spawner = LocalSpawner::new(); + /// + /// let (send, response) = oneshot::channel(); + /// spawner.spawn(Task::AddOne(10, send)); + /// let eleven = response.await.unwrap(); + /// assert_eq!(eleven, 11); + /// } + /// ``` + /// + /// [`Send`]: trait@std::marker::Send + /// [local task set]: struct@LocalSet + /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on + /// [`task::spawn_local`]: fn@spawn_local + /// [`mpsc`]: mod@crate::sync::mpsc + pub struct LocalSet { + /// Current scheduler tick. + tick: Cell<u8>, + + /// State available from thread-local. + context: Context, + + /// This type should not be Send. + _not_send: PhantomData<*const ()>, + } +} + +/// State available from the thread-local. +struct Context { + /// Collection of all active tasks spawned onto this executor. + owned: LocalOwnedTasks<Arc<Shared>>, + + /// Local run queue sender and receiver. + queue: VecDequeCell<task::Notified<Arc<Shared>>>, + + /// State shared between threads. + shared: Arc<Shared>, +} + +/// LocalSet state shared between threads. +struct Shared { + /// Remote run queue sender. + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, + + /// Wake the `LocalSet` task. + waker: AtomicWaker, +} + +pin_project! { + #[derive(Debug)] + struct RunUntil<'a, F> { + local_set: &'a LocalSet, + #[pin] + future: F, + } +} + +scoped_thread_local!(static CURRENT: Context); + +cfg_rt! { + /// Spawns a `!Send` future on the local task set. + /// + /// The spawned future will be run on the same thread that called `spawn_local.` + /// This may only be called from the context of a local task set. + /// + /// # Panics + /// + /// - This function panics if called outside of a local task set. + /// + /// # Examples + /// + /// ```rust + /// use std::rc::Rc; + /// use tokio::task; + /// + /// #[tokio::main] + /// async fn main() { + /// let unsend_data = Rc::new("my unsend data..."); + /// + /// let local = task::LocalSet::new(); + /// + /// // Run the local task set. + /// local.run_until(async move { + /// let unsend_data = unsend_data.clone(); + /// task::spawn_local(async move { + /// println!("{}", unsend_data); + /// // ... + /// }).await.unwrap(); + /// }).await; + /// } + /// ``` + #[track_caller] + pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output> + where + F: Future + 'static, + F::Output: 'static, + { + spawn_local_inner(future, None) + } + + + #[track_caller] + pub(super) fn spawn_local_inner<F>(future: F, name: Option<&str>) -> JoinHandle<F::Output> + where F: Future + 'static, + F::Output: 'static + { + let future = crate::util::trace::task(future, "local", name); + CURRENT.with(|maybe_cx| { + let cx = maybe_cx + .expect("`spawn_local` called from outside of a `task::LocalSet`"); + + let (handle, notified) = cx.owned.bind(future, cx.shared.clone()); + + if let Some(notified) = notified { + cx.shared.schedule(notified); + } + + handle + }) + } +} + +/// Initial queue capacity. +const INITIAL_CAPACITY: usize = 64; + +/// Max number of tasks to poll per tick. +const MAX_TASKS_PER_TICK: usize = 61; + +/// How often it check the remote queue first. +const REMOTE_FIRST_INTERVAL: u8 = 31; + +impl LocalSet { + /// Returns a new local task set. + pub fn new() -> LocalSet { + LocalSet { + tick: Cell::new(0), + context: Context { + owned: LocalOwnedTasks::new(), + queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), + shared: Arc::new(Shared { + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), + waker: AtomicWaker::new(), + }), + }, + _not_send: PhantomData, + } + } + + /// Spawns a `!Send` task onto the local task set. + /// + /// This task is guaranteed to be run on the current thread. + /// + /// Unlike the free function [`spawn_local`], this method may be used to + /// spawn local tasks when the task set is _not_ running. For example: + /// ```rust + /// use tokio::task; + /// + /// #[tokio::main] + /// async fn main() { + /// let local = task::LocalSet::new(); + /// + /// // Spawn a future on the local set. This future will be run when + /// // we call `run_until` to drive the task set. + /// local.spawn_local(async { + /// // ... + /// }); + /// + /// // Run the local task set. + /// local.run_until(async move { + /// // ... + /// }).await; + /// + /// // When `run` finishes, we can spawn _more_ futures, which will + /// // run in subsequent calls to `run_until`. + /// local.spawn_local(async { + /// // ... + /// }); + /// + /// local.run_until(async move { + /// // ... + /// }).await; + /// } + /// ``` + /// [`spawn_local`]: fn@spawn_local + #[track_caller] + pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + 'static, + F::Output: 'static, + { + let future = crate::util::trace::task(future, "local", None); + + let (handle, notified) = self.context.owned.bind(future, self.context.shared.clone()); + + if let Some(notified) = notified { + self.context.shared.schedule(notified); + } + + self.context.shared.waker.wake(); + handle + } + + /// Runs a future to completion on the provided runtime, driving any local + /// futures spawned on this task set on the current thread. + /// + /// This runs the given future on the runtime, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers which + /// the future spawns internally will be executed on the runtime. The future + /// may also call [`spawn_local`] to spawn_local additional local futures on the + /// current thread. + /// + /// This method should not be called from an asynchronous context. + /// + /// # Panics + /// + /// This function panics if the executor is at capacity, if the provided + /// future panics, or if called within an asynchronous execution context. + /// + /// # Notes + /// + /// Since this function internally calls [`Runtime::block_on`], and drives + /// futures in the local task set inside that call to `block_on`, the local + /// futures may not use [in-place blocking]. If a blocking call needs to be + /// issued from a local task, the [`spawn_blocking`] API may be used instead. + /// + /// For example, this will panic: + /// ```should_panic + /// use tokio::runtime::Runtime; + /// use tokio::task; + /// + /// let rt = Runtime::new().unwrap(); + /// let local = task::LocalSet::new(); + /// local.block_on(&rt, async { + /// let join = task::spawn_local(async { + /// let blocking_result = task::block_in_place(|| { + /// // ... + /// }); + /// // ... + /// }); + /// join.await.unwrap(); + /// }) + /// ``` + /// This, however, will not panic: + /// ``` + /// use tokio::runtime::Runtime; + /// use tokio::task; + /// + /// let rt = Runtime::new().unwrap(); + /// let local = task::LocalSet::new(); + /// local.block_on(&rt, async { + /// let join = task::spawn_local(async { + /// let blocking_result = task::spawn_blocking(|| { + /// // ... + /// }).await; + /// // ... + /// }); + /// join.await.unwrap(); + /// }) + /// ``` + /// + /// [`spawn_local`]: fn@spawn_local + /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on + /// [in-place blocking]: fn@crate::task::block_in_place + /// [`spawn_blocking`]: fn@crate::task::spawn_blocking + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output + where + F: Future, + { + rt.block_on(self.run_until(future)) + } + + /// Runs a future to completion on the local set, returning its output. + /// + /// This returns a future that runs the given future with a local set, + /// allowing it to call [`spawn_local`] to spawn additional `!Send` futures. + /// Any local futures spawned on the local set will be driven in the + /// background until the future passed to `run_until` completes. When the future + /// passed to `run` finishes, any local futures which have not completed + /// will remain on the local set, and will be driven on subsequent calls to + /// `run_until` or when [awaiting the local set] itself. + /// + /// # Examples + /// + /// ```rust + /// use tokio::task; + /// + /// #[tokio::main] + /// async fn main() { + /// task::LocalSet::new().run_until(async { + /// task::spawn_local(async move { + /// // ... + /// }).await.unwrap(); + /// // ... + /// }).await; + /// } + /// ``` + /// + /// [`spawn_local`]: fn@spawn_local + /// [awaiting the local set]: #awaiting-a-localset + pub async fn run_until<F>(&self, future: F) -> F::Output + where + F: Future, + { + let run_until = RunUntil { + future, + local_set: self, + }; + run_until.await + } + + /// Ticks the scheduler, returning whether the local future needs to be + /// notified again. + fn tick(&self) -> bool { + for _ in 0..MAX_TASKS_PER_TICK { + match self.next_task() { + // Run the task + // + // Safety: As spawned tasks are `!Send`, `run_unchecked` must be + // used. We are responsible for maintaining the invariant that + // `run_unchecked` is only called on threads that spawned the + // task initially. Because `LocalSet` itself is `!Send`, and + // `spawn_local` spawns into the `LocalSet` on the current + // thread, the invariant is maintained. + Some(task) => crate::coop::budget(|| task.run()), + // We have fully drained the queue of notified tasks, so the + // local future doesn't need to be notified again — it can wait + // until something else wakes a task in the local set. + None => return false, + } + } + + true + } + + fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> { + let tick = self.tick.get(); + self.tick.set(tick.wrapping_add(1)); + + let task = if tick % REMOTE_FIRST_INTERVAL == 0 { + self.context + .shared + .queue + .lock() + .as_mut() + .and_then(|queue| queue.pop_front()) + .or_else(|| self.context.queue.pop_front()) + } else { + self.context.queue.pop_front().or_else(|| { + self.context + .shared + .queue + .lock() + .as_mut() + .and_then(|queue| queue.pop_front()) + }) + }; + + task.map(|task| self.context.owned.assert_owner(task)) + } + + fn with<T>(&self, f: impl FnOnce() -> T) -> T { + CURRENT.set(&self.context, f) + } +} + +impl fmt::Debug for LocalSet { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("LocalSet").finish() + } +} + +impl Future for LocalSet { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { + // Register the waker before starting to work + self.context.shared.waker.register_by_ref(cx.waker()); + + if self.with(|| self.tick()) { + // If `tick` returns true, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + Poll::Pending + } else if self.context.owned.is_empty() { + // If the scheduler has no remaining futures, we're done! + Poll::Ready(()) + } else { + // There are still futures in the local set, but we've polled all the + // futures in the run queue. Therefore, we can just return Pending + // since the remaining futures will be woken from somewhere else. + Poll::Pending + } + } +} + +impl Default for LocalSet { + fn default() -> LocalSet { + LocalSet::new() + } +} + +impl Drop for LocalSet { + fn drop(&mut self) { + self.with(|| { + // Shut down all tasks in the LocalOwnedTasks and close it to + // prevent new tasks from ever being added. + self.context.owned.close_and_shutdown_all(); + + // We already called shutdown on all tasks above, so there is no + // need to call shutdown. + for task in self.context.queue.take() { + drop(task); + } + + // Take the queue from the Shared object to prevent pushing + // notifications to it in the future. + let queue = self.context.shared.queue.lock().take().unwrap(); + for task in queue { + drop(task); + } + + assert!(self.context.owned.is_empty()); + }); + } +} + +// === impl LocalFuture === + +impl<T: Future> Future for RunUntil<'_, T> { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + me.local_set.with(|| { + me.local_set + .context + .shared + .waker + .register_by_ref(cx.waker()); + + let _no_blocking = crate::runtime::enter::disallow_blocking(); + let f = me.future; + + if let Poll::Ready(output) = crate::coop::budget(|| f.poll(cx)) { + return Poll::Ready(output); + } + + if me.local_set.tick() { + // If `tick` returns `true`, we need to notify the local future again: + // there are still tasks remaining in the run queue. + cx.waker().wake_by_ref(); + } + + Poll::Pending + }) + } +} + +impl Shared { + /// Schedule the provided task on the scheduler. + fn schedule(&self, task: task::Notified<Arc<Self>>) { + CURRENT.with(|maybe_cx| match maybe_cx { + Some(cx) if cx.shared.ptr_eq(self) => { + cx.queue.push_back(task); + } + _ => { + // First check whether the queue is still there (if not, the + // LocalSet is dropped). Then push to it if so, and if not, + // do nothing. + let mut lock = self.queue.lock(); + + if let Some(queue) = lock.as_mut() { + queue.push_back(task); + drop(lock); + self.waker.wake(); + } + } + }); + } + + fn ptr_eq(&self, other: &Shared) -> bool { + std::ptr::eq(self, other) + } +} + +impl task::Schedule for Arc<Shared> { + fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { + CURRENT.with(|maybe_cx| { + let cx = maybe_cx.expect("scheduler context missing"); + assert!(cx.shared.ptr_eq(self)); + cx.owned.remove(task) + }) + } + + fn schedule(&self, task: task::Notified<Self>) { + Shared::schedule(self, task); + } +} diff --git a/third_party/rust/tokio/src/task/mod.rs b/third_party/rust/tokio/src/task/mod.rs new file mode 100644 index 0000000000..d532155a1f --- /dev/null +++ b/third_party/rust/tokio/src/task/mod.rs @@ -0,0 +1,317 @@ +//! Asynchronous green-threads. +//! +//! ## What are Tasks? +//! +//! A _task_ is a light weight, non-blocking unit of execution. A task is similar +//! to an OS thread, but rather than being managed by the OS scheduler, they are +//! managed by the [Tokio runtime][rt]. Another name for this general pattern is +//! [green threads]. If you are familiar with [Go's goroutines], [Kotlin's +//! coroutines], or [Erlang's processes], you can think of Tokio's tasks as +//! something similar. +//! +//! Key points about tasks include: +//! +//! * Tasks are **light weight**. Because tasks are scheduled by the Tokio +//! runtime rather than the operating system, creating new tasks or switching +//! between tasks does not require a context switch and has fairly low +//! overhead. Creating, running, and destroying large numbers of tasks is +//! quite cheap, especially compared to OS threads. +//! +//! * Tasks are scheduled **cooperatively**. Most operating systems implement +//! _preemptive multitasking_. This is a scheduling technique where the +//! operating system allows each thread to run for a period of time, and then +//! _preempts_ it, temporarily pausing that thread and switching to another. +//! Tasks, on the other hand, implement _cooperative multitasking_. In +//! cooperative multitasking, a task is allowed to run until it _yields_, +//! indicating to the Tokio runtime's scheduler that it cannot currently +//! continue executing. When a task yields, the Tokio runtime switches to +//! executing the next task. +//! +//! * Tasks are **non-blocking**. Typically, when an OS thread performs I/O or +//! must synchronize with another thread, it _blocks_, allowing the OS to +//! schedule another thread. When a task cannot continue executing, it must +//! yield instead, allowing the Tokio runtime to schedule another task. Tasks +//! should generally not perform system calls or other operations that could +//! block a thread, as this would prevent other tasks running on the same +//! thread from executing as well. Instead, this module provides APIs for +//! running blocking operations in an asynchronous context. +//! +//! [rt]: crate::runtime +//! [green threads]: https://en.wikipedia.org/wiki/Green_threads +//! [Go's goroutines]: https://tour.golang.org/concurrency/1 +//! [Kotlin's coroutines]: https://kotlinlang.org/docs/reference/coroutines-overview.html +//! [Erlang's processes]: http://erlang.org/doc/getting_started/conc_prog.html#processes +//! +//! ## Working with Tasks +//! +//! This module provides the following APIs for working with tasks: +//! +//! ### Spawning +//! +//! Perhaps the most important function in this module is [`task::spawn`]. This +//! function can be thought of as an async equivalent to the standard library's +//! [`thread::spawn`][`std::thread::spawn`]. It takes an `async` block or other +//! [future], and creates a new task to run that work concurrently: +//! +//! ``` +//! use tokio::task; +//! +//! # async fn doc() { +//! task::spawn(async { +//! // perform some work here... +//! }); +//! # } +//! ``` +//! +//! Like [`std::thread::spawn`], `task::spawn` returns a [`JoinHandle`] struct. +//! A `JoinHandle` is itself a future which may be used to await the output of +//! the spawned task. For example: +//! +//! ``` +//! use tokio::task; +//! +//! # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let join = task::spawn(async { +//! // ... +//! "hello world!" +//! }); +//! +//! // ... +//! +//! // Await the result of the spawned task. +//! let result = join.await?; +//! assert_eq!(result, "hello world!"); +//! # Ok(()) +//! # } +//! ``` +//! +//! Again, like `std::thread`'s [`JoinHandle` type][thread_join], if the spawned +//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]. For +//! example: +//! +//! ``` +//! use tokio::task; +//! +//! # #[tokio::main] async fn main() { +//! let join = task::spawn(async { +//! panic!("something bad happened!") +//! }); +//! +//! // The returned result indicates that the task failed. +//! assert!(join.await.is_err()); +//! # } +//! ``` +//! +//! `spawn`, `JoinHandle`, and `JoinError` are present when the "rt" +//! feature flag is enabled. +//! +//! [`task::spawn`]: crate::task::spawn() +//! [future]: std::future::Future +//! [`std::thread::spawn`]: std::thread::spawn +//! [`JoinHandle`]: crate::task::JoinHandle +//! [thread_join]: std::thread::JoinHandle +//! [`JoinError`]: crate::task::JoinError +//! +//! ### Blocking and Yielding +//! +//! As we discussed above, code running in asynchronous tasks should not perform +//! operations that can block. A blocking operation performed in a task running +//! on a thread that is also running other tasks would block the entire thread, +//! preventing other tasks from running. +//! +//! Instead, Tokio provides two APIs for running blocking operations in an +//! asynchronous context: [`task::spawn_blocking`] and [`task::block_in_place`]. +//! +//! Be aware that if you call a non-async method from async code, that non-async +//! method is still inside the asynchronous context, so you should also avoid +//! blocking operations there. This includes destructors of objects destroyed in +//! async code. +//! +//! #### spawn_blocking +//! +//! The `task::spawn_blocking` function is similar to the `task::spawn` function +//! discussed in the previous section, but rather than spawning an +//! _non-blocking_ future on the Tokio runtime, it instead spawns a +//! _blocking_ function on a dedicated thread pool for blocking tasks. For +//! example: +//! +//! ``` +//! use tokio::task; +//! +//! # async fn docs() { +//! task::spawn_blocking(|| { +//! // do some compute-heavy work or call synchronous code +//! }); +//! # } +//! ``` +//! +//! Just like `task::spawn`, `task::spawn_blocking` returns a `JoinHandle` +//! which we can use to await the result of the blocking operation: +//! +//! ```rust +//! # use tokio::task; +//! # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ +//! let join = task::spawn_blocking(|| { +//! // do some compute-heavy work or call synchronous code +//! "blocking completed" +//! }); +//! +//! let result = join.await?; +//! assert_eq!(result, "blocking completed"); +//! # Ok(()) +//! # } +//! ``` +//! +//! #### block_in_place +//! +//! When using the [multi-threaded runtime][rt-multi-thread], the [`task::block_in_place`] +//! function is also available. Like `task::spawn_blocking`, this function +//! allows running a blocking operation from an asynchronous context. Unlike +//! `spawn_blocking`, however, `block_in_place` works by transitioning the +//! _current_ worker thread to a blocking thread, moving other tasks running on +//! that thread to another worker thread. This can improve performance by avoiding +//! context switches. +//! +//! For example: +//! +//! ``` +//! use tokio::task; +//! +//! # async fn docs() { +//! let result = task::block_in_place(|| { +//! // do some compute-heavy work or call synchronous code +//! "blocking completed" +//! }); +//! +//! assert_eq!(result, "blocking completed"); +//! # } +//! ``` +//! +//! #### yield_now +//! +//! In addition, this module provides a [`task::yield_now`] async function +//! that is analogous to the standard library's [`thread::yield_now`]. Calling +//! and `await`ing this function will cause the current task to yield to the +//! Tokio runtime's scheduler, allowing other tasks to be +//! scheduled. Eventually, the yielding task will be polled again, allowing it +//! to execute. For example: +//! +//! ```rust +//! use tokio::task; +//! +//! # #[tokio::main] async fn main() { +//! async { +//! task::spawn(async { +//! // ... +//! println!("spawned task done!") +//! }); +//! +//! // Yield, allowing the newly-spawned task to execute first. +//! task::yield_now().await; +//! println!("main task done!"); +//! } +//! # .await; +//! # } +//! ``` +//! +//! ### Cooperative scheduling +//! +//! A single call to [`poll`] on a top-level task may potentially do a lot of +//! work before it returns `Poll::Pending`. If a task runs for a long period of +//! time without yielding back to the executor, it can starve other tasks +//! waiting on that executor to execute them, or drive underlying resources. +//! Since Rust does not have a runtime, it is difficult to forcibly preempt a +//! long-running task. Instead, this module provides an opt-in mechanism for +//! futures to collaborate with the executor to avoid starvation. +//! +//! Consider a future like this one: +//! +//! ``` +//! # use tokio_stream::{Stream, StreamExt}; +//! async fn drop_all<I: Stream + Unpin>(mut input: I) { +//! while let Some(_) = input.next().await {} +//! } +//! ``` +//! +//! It may look harmless, but consider what happens under heavy load if the +//! input stream is _always_ ready. If we spawn `drop_all`, the task will never +//! yield, and will starve other tasks and resources on the same executor. +//! +//! To account for this, Tokio has explicit yield points in a number of library +//! functions, which force tasks to return to the executor periodically. +//! +//! +//! #### unconstrained +//! +//! If necessary, [`task::unconstrained`] lets you opt out a future of Tokio's cooperative +//! scheduling. When a future is wrapped with `unconstrained`, it will never be forced to yield to +//! Tokio. For example: +//! +//! ``` +//! # #[tokio::main] +//! # async fn main() { +//! use tokio::{task, sync::mpsc}; +//! +//! let fut = async { +//! let (tx, mut rx) = mpsc::unbounded_channel(); +//! +//! for i in 0..1000 { +//! let _ = tx.send(()); +//! // This will always be ready. If coop was in effect, this code would be forced to yield +//! // periodically. However, if left unconstrained, then this code will never yield. +//! rx.recv().await; +//! } +//! }; +//! +//! task::unconstrained(fut).await; +//! # } +//! ``` +//! +//! [`task::spawn_blocking`]: crate::task::spawn_blocking +//! [`task::block_in_place`]: crate::task::block_in_place +//! [rt-multi-thread]: ../runtime/index.html#threaded-scheduler +//! [`task::yield_now`]: crate::task::yield_now() +//! [`thread::yield_now`]: std::thread::yield_now +//! [`task::unconstrained`]: crate::task::unconstrained() +//! [`poll`]: method@std::future::Future::poll + +cfg_rt! { + pub use crate::runtime::task::{JoinError, JoinHandle}; + + mod blocking; + pub use blocking::spawn_blocking; + + mod spawn; + pub use spawn::spawn; + + cfg_rt_multi_thread! { + pub use blocking::block_in_place; + } + + mod yield_now; + pub use yield_now::yield_now; + + mod local; + pub use local::{spawn_local, LocalSet}; + + mod task_local; + pub use task_local::LocalKey; + + mod unconstrained; + pub use unconstrained::{unconstrained, Unconstrained}; + + cfg_unstable! { + mod join_set; + pub use join_set::JoinSet; + } + + cfg_trace! { + mod builder; + pub use builder::Builder; + } + + /// Task-related futures. + pub mod futures { + pub use super::task_local::TaskLocalFuture; + } +} diff --git a/third_party/rust/tokio/src/task/spawn.rs b/third_party/rust/tokio/src/task/spawn.rs new file mode 100644 index 0000000000..a9d736674c --- /dev/null +++ b/third_party/rust/tokio/src/task/spawn.rs @@ -0,0 +1,149 @@ +use crate::{task::JoinHandle, util::error::CONTEXT_MISSING_ERROR}; + +use std::future::Future; + +cfg_rt! { + /// Spawns a new asynchronous task, returning a + /// [`JoinHandle`](super::JoinHandle) for it. + /// + /// Spawning a task enables the task to execute concurrently to other tasks. The + /// spawned task may execute on the current thread, or it may be sent to a + /// different thread to be executed. The specifics depend on the current + /// [`Runtime`](crate::runtime::Runtime) configuration. + /// + /// There is no guarantee that a spawned task will execute to completion. + /// When a runtime is shutdown, all outstanding tasks are dropped, + /// regardless of the lifecycle of that task. + /// + /// This function must be called from the context of a Tokio runtime. Tasks running on + /// the Tokio runtime are always inside its context, but you can also enter the context + /// using the [`Runtime::enter`](crate::runtime::Runtime::enter()) method. + /// + /// # Examples + /// + /// In this example, a server is started and `spawn` is used to start a new task + /// that processes each received connection. + /// + /// ```no_run + /// use tokio::net::{TcpListener, TcpStream}; + /// + /// use std::io; + /// + /// async fn process(socket: TcpStream) { + /// // ... + /// # drop(socket); + /// } + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// loop { + /// let (socket, _) = listener.accept().await?; + /// + /// tokio::spawn(async move { + /// // Process each socket concurrently. + /// process(socket).await + /// }); + /// } + /// } + /// ``` + /// + /// # Panics + /// + /// Panics if called from **outside** of the Tokio runtime. + /// + /// # Using `!Send` values from a task + /// + /// The task supplied to `spawn` must implement `Send`. However, it is + /// possible to **use** `!Send` values from the task as long as they only + /// exist between calls to `.await`. + /// + /// For example, this will work: + /// + /// ``` + /// use tokio::task; + /// + /// use std::rc::Rc; + /// + /// fn use_rc(rc: Rc<()>) { + /// // Do stuff w/ rc + /// # drop(rc); + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// tokio::spawn(async { + /// // Force the `Rc` to stay in a scope with no `.await` + /// { + /// let rc = Rc::new(()); + /// use_rc(rc.clone()); + /// } + /// + /// task::yield_now().await; + /// }).await.unwrap(); + /// } + /// ``` + /// + /// This will **not** work: + /// + /// ```compile_fail + /// use tokio::task; + /// + /// use std::rc::Rc; + /// + /// fn use_rc(rc: Rc<()>) { + /// // Do stuff w/ rc + /// # drop(rc); + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// tokio::spawn(async { + /// let rc = Rc::new(()); + /// + /// task::yield_now().await; + /// + /// use_rc(rc.clone()); + /// }).await.unwrap(); + /// } + /// ``` + /// + /// Holding on to a `!Send` value across calls to `.await` will result in + /// an unfriendly compile error message similar to: + /// + /// ```text + /// `[... some type ...]` cannot be sent between threads safely + /// ``` + /// + /// or: + /// + /// ```text + /// error[E0391]: cycle detected when processing `main` + /// ``` + #[track_caller] + pub fn spawn<T>(future: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + // preventing stack overflows on debug mode, by quickly sending the + // task to the heap. + if cfg!(debug_assertions) && std::mem::size_of::<T>() > 2048 { + spawn_inner(Box::pin(future), None) + } else { + spawn_inner(future, None) + } + } + + #[track_caller] + pub(super) fn spawn_inner<T>(future: T, name: Option<&str>) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let spawn_handle = crate::runtime::context::spawn_handle().expect(CONTEXT_MISSING_ERROR); + let task = crate::util::trace::task(future, "task", name); + spawn_handle.spawn(task) + } +} diff --git a/third_party/rust/tokio/src/task/task_local.rs b/third_party/rust/tokio/src/task/task_local.rs new file mode 100644 index 0000000000..949bbca3ee --- /dev/null +++ b/third_party/rust/tokio/src/task/task_local.rs @@ -0,0 +1,302 @@ +use pin_project_lite::pin_project; +use std::cell::RefCell; +use std::error::Error; +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, thread}; + +/// Declares a new task-local key of type [`tokio::task::LocalKey`]. +/// +/// # Syntax +/// +/// The macro wraps any number of static declarations and makes them local to the current task. +/// Publicity and attributes for each static is preserved. For example: +/// +/// # Examples +/// +/// ``` +/// # use tokio::task_local; +/// task_local! { +/// pub static ONE: u32; +/// +/// #[allow(unused)] +/// static TWO: f32; +/// } +/// # fn main() {} +/// ``` +/// +/// See [LocalKey documentation][`tokio::task::LocalKey`] for more +/// information. +/// +/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +macro_rules! task_local { + // empty (base case for the recursion) + () => {}; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + $crate::task_local!($($rest)*); + }; + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => { + $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __task_local_inner { + ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => { + $vis static $name: $crate::task::LocalKey<$t> = { + std::thread_local! { + static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None); + } + + $crate::task::LocalKey { inner: __KEY } + }; + }; +} + +/// A key for task-local data. +/// +/// This type is generated by the `task_local!` macro. +/// +/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will +/// _not_ lazily initialize the value on first access. Instead, the +/// value is first initialized when the future containing +/// the task-local is first polled by a futures executor, like Tokio. +/// +/// # Examples +/// +/// ``` +/// # async fn dox() { +/// tokio::task_local! { +/// static NUMBER: u32; +/// } +/// +/// NUMBER.scope(1, async move { +/// assert_eq!(NUMBER.get(), 1); +/// }).await; +/// +/// NUMBER.scope(2, async move { +/// assert_eq!(NUMBER.get(), 2); +/// +/// NUMBER.scope(3, async move { +/// assert_eq!(NUMBER.get(), 3); +/// }).await; +/// }).await; +/// # } +/// ``` +/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub struct LocalKey<T: 'static> { + #[doc(hidden)] + pub inner: thread::LocalKey<RefCell<Option<T>>>, +} + +impl<T: 'static> LocalKey<T> { + /// Sets a value `T` as the task-local value for the future `F`. + /// + /// On completion of `scope`, the task-local will be dropped. + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> + where + F: Future, + { + TaskLocalFuture { + local: self, + slot: Some(value), + future: f, + _pinned: PhantomPinned, + } + } + + /// Sets a value `T` as the task-local value for the closure `F`. + /// + /// On completion of `scope`, the task-local will be dropped. + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.sync_scope(1, || { + /// println!("task local value: {}", NUMBER.get()); + /// }); + /// # } + /// ``` + pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R + where + F: FnOnce() -> R, + { + let scope = TaskLocalFuture { + local: self, + slot: Some(value), + future: (), + _pinned: PhantomPinned, + }; + crate::pin!(scope); + scope.with_task(|_| f()) + } + + /// Accesses the current task-local and runs the provided closure. + /// + /// # Panics + /// + /// This function will panic if not called within the context + /// of a future containing a task-local with the corresponding key. + pub fn with<F, R>(&'static self, f: F) -> R + where + F: FnOnce(&T) -> R, + { + self.try_with(f).expect( + "cannot access a Task Local Storage value \ + without setting it via `LocalKey::set`", + ) + } + + /// Accesses the current task-local and runs the provided closure. + /// + /// If the task-local with the associated key is not present, this + /// method will return an `AccessError`. For a panicking variant, + /// see `with`. + pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> + where + F: FnOnce(&T) -> R, + { + self.inner.with(|v| { + if let Some(val) = v.borrow().as_ref() { + Ok(f(val)) + } else { + Err(AccessError { _private: () }) + } + }) + } +} + +impl<T: Copy + 'static> LocalKey<T> { + /// Returns a copy of the task-local value + /// if the task-local value implements `Copy`. + pub fn get(&'static self) -> T { + self.with(|v| *v) + } +} + +impl<T: 'static> fmt::Debug for LocalKey<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("LocalKey { .. }") + } +} + +pin_project! { + /// A future that sets a value `T` of a task local for the future `F` during + /// its execution. + /// + /// The value of the task-local must be `'static` and will be dropped on the + /// completion of the future. + /// + /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub struct TaskLocalFuture<T, F> + where + T: 'static + { + local: &'static LocalKey<T>, + slot: Option<T>, + #[pin] + future: F, + #[pin] + _pinned: PhantomPinned, + } +} + +impl<T: 'static, F> TaskLocalFuture<T, F> { + fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { + struct Guard<'a, T: 'static> { + local: &'static LocalKey<T>, + slot: &'a mut Option<T>, + prev: Option<T>, + } + + impl<T> Drop for Guard<'_, T> { + fn drop(&mut self) { + let value = self.local.inner.with(|c| c.replace(self.prev.take())); + *self.slot = value; + } + } + + let project = self.project(); + let val = project.slot.take(); + + let prev = project.local.inner.with(|c| c.replace(val)); + + let _guard = Guard { + prev, + slot: project.slot, + local: *project.local, + }; + + f(project.future) + } +} + +impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.with_task(|f| f.poll(cx)) + } +} + +/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct AccessError { + _private: (), +} + +impl fmt::Debug for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AccessError").finish() + } +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt("task-local value not set", f) + } +} + +impl Error for AccessError {} diff --git a/third_party/rust/tokio/src/task/unconstrained.rs b/third_party/rust/tokio/src/task/unconstrained.rs new file mode 100644 index 0000000000..31c732bfc9 --- /dev/null +++ b/third_party/rust/tokio/src/task/unconstrained.rs @@ -0,0 +1,45 @@ +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`unconstrained`](unconstrained) method. + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + #[must_use = "Unconstrained does nothing unless polled"] + pub struct Unconstrained<F> { + #[pin] + inner: F, + } +} + +impl<F> Future for Unconstrained<F> +where + F: Future, +{ + type Output = <F as Future>::Output; + + cfg_coop! { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let inner = self.project().inner; + crate::coop::with_unconstrained(|| inner.poll(cx)) + } + } + + cfg_not_coop! { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let inner = self.project().inner; + inner.poll(cx) + } + } +} + +/// Turn off cooperative scheduling for a future. The future will never be forced to yield by +/// Tokio. Using this exposes your service to starvation if the unconstrained future never yields +/// otherwise. +/// +/// See also the usage example in the [task module](index.html#unconstrained). +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub fn unconstrained<F>(inner: F) -> Unconstrained<F> { + Unconstrained { inner } +} diff --git a/third_party/rust/tokio/src/task/yield_now.rs b/third_party/rust/tokio/src/task/yield_now.rs new file mode 100644 index 0000000000..148e3dc0c8 --- /dev/null +++ b/third_party/rust/tokio/src/task/yield_now.rs @@ -0,0 +1,58 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Yields execution back to the Tokio runtime. +/// +/// A task yields by awaiting on `yield_now()`, and may resume when that future +/// completes (with no output.) The current task will be re-added as a pending +/// task at the _back_ of the pending queue. Any other pending tasks will be +/// scheduled. No other waking is required for the task to continue. +/// +/// See also the usage example in the [task module](index.html#yield_now). +/// +/// ## Non-guarantees +/// +/// This function may not yield all the way up to the executor if there are any +/// special combinators above it in the call stack. For example, if a +/// [`tokio::select!`] has another branch complete during the same poll as the +/// `yield_now()`, then the yield is not propagated all the way up to the +/// runtime. +/// +/// It is generally not guaranteed that the runtime behaves like you expect it +/// to when deciding which task to schedule next after a call to `yield_now()`. +/// In particular, the runtime may choose to poll the task that just ran +/// `yield_now()` again immediately without polling any other tasks first. For +/// example, the runtime will not drive the IO driver between every poll of a +/// task, and this could result in the runtime polling the current task again +/// immediately even if there is another task that could make progress if that +/// other task is waiting for a notification from the IO driver. +/// +/// In general, changes to the order in which the runtime polls tasks is not +/// considered a breaking change, and your program should be correct no matter +/// which order the runtime polls your tasks in. +/// +/// [`tokio::select!`]: macro@crate::select +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub async fn yield_now() { + /// Yield implementation + struct YieldNow { + yielded: bool, + } + + impl Future for YieldNow { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.yielded { + return Poll::Ready(()); + } + + self.yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + YieldNow { yielded: false }.await +} diff --git a/third_party/rust/tokio/src/time/clock.rs b/third_party/rust/tokio/src/time/clock.rs new file mode 100644 index 0000000000..41be9bac48 --- /dev/null +++ b/third_party/rust/tokio/src/time/clock.rs @@ -0,0 +1,233 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Source of time abstraction. +//! +//! By default, `std::time::Instant::now()` is used. However, when the +//! `test-util` feature flag is enabled, the values returned for `now()` are +//! configurable. + +cfg_not_test_util! { + use crate::time::{Instant}; + + #[derive(Debug, Clone)] + pub(crate) struct Clock {} + + pub(crate) fn now() -> Instant { + Instant::from_std(std::time::Instant::now()) + } + + impl Clock { + pub(crate) fn new(_enable_pausing: bool, _start_paused: bool) -> Clock { + Clock {} + } + + pub(crate) fn now(&self) -> Instant { + now() + } + } +} + +cfg_test_util! { + use crate::time::{Duration, Instant}; + use crate::loom::sync::{Arc, Mutex}; + + cfg_rt! { + fn clock() -> Option<Clock> { + crate::runtime::context::clock() + } + } + + cfg_not_rt! { + fn clock() -> Option<Clock> { + None + } + } + + /// A handle to a source of time. + #[derive(Debug, Clone)] + pub(crate) struct Clock { + inner: Arc<Mutex<Inner>>, + } + + #[derive(Debug)] + struct Inner { + /// True if the ability to pause time is enabled. + enable_pausing: bool, + + /// Instant to use as the clock's base instant. + base: std::time::Instant, + + /// Instant at which the clock was last unfrozen. + unfrozen: Option<std::time::Instant>, + } + + /// Pauses time. + /// + /// The current value of `Instant::now()` is saved and all subsequent calls + /// to `Instant::now()` will return the saved value. The saved value can be + /// changed by [`advance`] or by the time auto-advancing once the runtime + /// has no work to do. This only affects the `Instant` type in Tokio, and + /// the `Instant` in std continues to work as normal. + /// + /// Pausing time requires the `current_thread` Tokio runtime. This is the + /// default runtime used by `#[tokio::test]`. The runtime can be initialized + /// with time in a paused state using the `Builder::start_paused` method. + /// + /// For cases where time is immediately paused, it is better to pause + /// the time using the `main` or `test` macro: + /// ``` + /// #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// println!("Hello world"); + /// } + /// ``` + /// + /// # Panics + /// + /// Panics if time is already frozen or if called from outside of a + /// `current_thread` Tokio runtime. + /// + /// # Auto-advance + /// + /// If time is paused and the runtime has no work to do, the clock is + /// auto-advanced to the next pending timer. This means that [`Sleep`] or + /// other timer-backed primitives can cause the runtime to advance the + /// current time when awaited. + /// + /// [`Sleep`]: crate::time::Sleep + /// [`advance`]: crate::time::advance + pub fn pause() { + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); + clock.pause(); + } + + /// Resumes time. + /// + /// Clears the saved `Instant::now()` value. Subsequent calls to + /// `Instant::now()` will return the value returned by the system call. + /// + /// # Panics + /// + /// Panics if time is not frozen or if called from outside of the Tokio + /// runtime. + pub fn resume() { + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); + let mut inner = clock.inner.lock(); + + if inner.unfrozen.is_some() { + panic!("time is not frozen"); + } + + inner.unfrozen = Some(std::time::Instant::now()); + } + + /// Advances time. + /// + /// Increments the saved `Instant::now()` value by `duration`. Subsequent + /// calls to `Instant::now()` will return the result of the increment. + /// + /// This function will make the current time jump forward by the given + /// duration in one jump. This means that all `sleep` calls with a deadline + /// before the new time will immediately complete "at the same time", and + /// the runtime is free to poll them in any order. Additionally, this + /// method will not wait for the `sleep` calls it advanced past to complete. + /// If you want to do that, you should instead call [`sleep`] and rely on + /// the runtime's auto-advance feature. + /// + /// Note that calls to `sleep` are not guaranteed to complete the first time + /// they are polled after a call to `advance`. For example, this can happen + /// if the runtime has not yet touched the timer driver after the call to + /// `advance`. However if they don't, the runtime will poll the task again + /// shortly. + /// + /// # Panics + /// + /// Panics if time is not frozen or if called from outside of the Tokio + /// runtime. + /// + /// # Auto-advance + /// + /// If the time is paused and there is no work to do, the runtime advances + /// time to the next timer. See [`pause`](pause#auto-advance) for more + /// details. + /// + /// [`sleep`]: fn@crate::time::sleep + pub async fn advance(duration: Duration) { + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); + clock.advance(duration); + + crate::task::yield_now().await; + } + + /// Returns the current instant, factoring in frozen time. + pub(crate) fn now() -> Instant { + if let Some(clock) = clock() { + clock.now() + } else { + Instant::from_std(std::time::Instant::now()) + } + } + + impl Clock { + /// Returns a new `Clock` instance that uses the current execution context's + /// source of time. + pub(crate) fn new(enable_pausing: bool, start_paused: bool) -> Clock { + let now = std::time::Instant::now(); + + let clock = Clock { + inner: Arc::new(Mutex::new(Inner { + enable_pausing, + base: now, + unfrozen: Some(now), + })), + }; + + if start_paused { + clock.pause(); + } + + clock + } + + pub(crate) fn pause(&self) { + let mut inner = self.inner.lock(); + + if !inner.enable_pausing { + drop(inner); // avoid poisoning the lock + panic!("`time::pause()` requires the `current_thread` Tokio runtime. \ + This is the default Runtime used by `#[tokio::test]."); + } + + let elapsed = inner.unfrozen.as_ref().expect("time is already frozen").elapsed(); + inner.base += elapsed; + inner.unfrozen = None; + } + + pub(crate) fn is_paused(&self) -> bool { + let inner = self.inner.lock(); + inner.unfrozen.is_none() + } + + pub(crate) fn advance(&self, duration: Duration) { + let mut inner = self.inner.lock(); + + if inner.unfrozen.is_some() { + panic!("time is not frozen"); + } + + inner.base += duration; + } + + pub(crate) fn now(&self) -> Instant { + let inner = self.inner.lock(); + + let mut ret = inner.base; + + if let Some(unfrozen) = inner.unfrozen { + ret += unfrozen.elapsed(); + } + + Instant::from_std(ret) + } + } +} diff --git a/third_party/rust/tokio/src/time/driver/entry.rs b/third_party/rust/tokio/src/time/driver/entry.rs new file mode 100644 index 0000000000..f0ea898e12 --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/entry.rs @@ -0,0 +1,633 @@ +//! Timer state structures. +//! +//! This module contains the heart of the intrusive timer implementation, and as +//! such the structures inside are full of tricky concurrency and unsafe code. +//! +//! # Ground rules +//! +//! The heart of the timer implementation here is the [`TimerShared`] structure, +//! shared between the [`TimerEntry`] and the driver. Generally, we permit access +//! to [`TimerShared`] ONLY via either 1) a mutable reference to [`TimerEntry`] or +//! 2) a held driver lock. +//! +//! It follows from this that any changes made while holding BOTH 1 and 2 will +//! be reliably visible, regardless of ordering. This is because of the acq/rel +//! fences on the driver lock ensuring ordering with 2, and rust mutable +//! reference rules for 1 (a mutable reference to an object can't be passed +//! between threads without an acq/rel barrier, and same-thread we have local +//! happens-before ordering). +//! +//! # State field +//! +//! Each timer has a state field associated with it. This field contains either +//! the current scheduled time, or a special flag value indicating its state. +//! This state can either indicate that the timer is on the 'pending' queue (and +//! thus will be fired with an `Ok(())` result soon) or that it has already been +//! fired/deregistered. +//! +//! This single state field allows for code that is firing the timer to +//! synchronize with any racing `reset` calls reliably. +//! +//! # Cached vs true timeouts +//! +//! To allow for the use case of a timeout that is periodically reset before +//! expiration to be as lightweight as possible, we support optimistically +//! lock-free timer resets, in the case where a timer is rescheduled to a later +//! point than it was originally scheduled for. +//! +//! This is accomplished by lazily rescheduling timers. That is, we update the +//! state field field with the true expiration of the timer from the holder of +//! the [`TimerEntry`]. When the driver services timers (ie, whenever it's +//! walking lists of timers), it checks this "true when" value, and reschedules +//! based on it. +//! +//! We do, however, also need to track what the expiration time was when we +//! originally registered the timer; this is used to locate the right linked +//! list when the timer is being cancelled. This is referred to as the "cached +//! when" internally. +//! +//! There is of course a race condition between timer reset and timer +//! expiration. If the driver fails to observe the updated expiration time, it +//! could trigger expiration of the timer too early. However, because +//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and +//! refuse to mark the timer as pending. +//! +//! [mark_pending]: TimerHandle::mark_pending + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::AtomicU64; +use crate::loom::sync::atomic::Ordering; + +use crate::sync::AtomicWaker; +use crate::time::Instant; +use crate::util::linked_list; + +use super::Handle; + +use std::cell::UnsafeCell as StdUnsafeCell; +use std::task::{Context, Poll, Waker}; +use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; + +type TimerResult = Result<(), crate::time::error::Error>; + +const STATE_DEREGISTERED: u64 = u64::MAX; +const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; +const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; + +/// This structure holds the current shared state of the timer - its scheduled +/// time (if registered), or otherwise the result of the timer completing, as +/// well as the registered waker. +/// +/// Generally, the StateCell is only permitted to be accessed from two contexts: +/// Either a thread holding the corresponding &mut TimerEntry, or a thread +/// holding the timer driver lock. The write actions on the StateCell amount to +/// passing "ownership" of the StateCell between these contexts; moving a timer +/// from the TimerEntry to the driver requires _both_ holding the &mut +/// TimerEntry and the driver lock, while moving it back (firing the timer) +/// requires only the driver lock. +pub(super) struct StateCell { + /// Holds either the scheduled expiration time for this timer, or (if the + /// timer has been fired and is unregistered), `u64::MAX`. + state: AtomicU64, + /// If the timer is fired (an Acquire order read on state shows + /// `u64::MAX`), holds the result that should be returned from + /// polling the timer. Otherwise, the contents are unspecified and reading + /// without holding the driver lock is undefined behavior. + result: UnsafeCell<TimerResult>, + /// The currently-registered waker + waker: CachePadded<AtomicWaker>, +} + +impl Default for StateCell { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for StateCell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "StateCell({:?})", self.read_state()) + } +} + +impl StateCell { + fn new() -> Self { + Self { + state: AtomicU64::new(STATE_DEREGISTERED), + result: UnsafeCell::new(Ok(())), + waker: CachePadded(AtomicWaker::new()), + } + } + + fn is_pending(&self) -> bool { + self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE + } + + /// Returns the current expiration time, or None if not currently scheduled. + fn when(&self) -> Option<u64> { + let cur_state = self.state.load(Ordering::Relaxed); + + if cur_state == u64::MAX { + None + } else { + Some(cur_state) + } + } + + /// If the timer is completed, returns the result of the timer. Otherwise, + /// returns None and registers the waker. + fn poll(&self, waker: &Waker) -> Poll<TimerResult> { + // We must register first. This ensures that either `fire` will + // observe the new waker, or we will observe a racing fire to have set + // the state, or both. + self.waker.0.register_by_ref(waker); + + self.read_state() + } + + fn read_state(&self) -> Poll<TimerResult> { + let cur_state = self.state.load(Ordering::Acquire); + + if cur_state == STATE_DEREGISTERED { + // SAFETY: The driver has fired this timer; this involves writing + // the result, and then writing (with release ordering) the state + // field. + Poll::Ready(unsafe { self.result.with(|p| *p) }) + } else { + Poll::Pending + } + } + + /// Marks this timer as being moved to the pending list, if its scheduled + /// time is not after `not_after`. + /// + /// If the timer is scheduled for a time after not_after, returns an Err + /// containing the current scheduled time. + /// + /// SAFETY: Must hold the driver lock. + unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + // Quick initial debug check to see if the timer is already fired. Since + // firing the timer can only happen with the driver lock held, we know + // we shouldn't be able to "miss" a transition to a fired state, even + // with relaxed ordering. + let mut cur_state = self.state.load(Ordering::Relaxed); + + loop { + debug_assert!(cur_state < STATE_MIN_VALUE); + + if cur_state > not_after { + break Err(cur_state); + } + + match self.state.compare_exchange( + cur_state, + STATE_PENDING_FIRE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + break Ok(()); + } + Err(actual_state) => { + cur_state = actual_state; + } + } + } + } + + /// Fires the timer, setting the result to the provided result. + /// + /// Returns: + /// * `Some(waker) - if fired and a waker needs to be invoked once the + /// driver lock is released + /// * `None` - if fired and a waker does not need to be invoked, or if + /// already fired + /// + /// SAFETY: The driver lock must be held. + unsafe fn fire(&self, result: TimerResult) -> Option<Waker> { + // Quick initial check to see if the timer is already fired. Since + // firing the timer can only happen with the driver lock held, we know + // we shouldn't be able to "miss" a transition to a fired state, even + // with relaxed ordering. + let cur_state = self.state.load(Ordering::Relaxed); + if cur_state == STATE_DEREGISTERED { + return None; + } + + // SAFETY: We assume the driver lock is held and the timer is not + // fired, so only the driver is accessing this field. + // + // We perform a release-ordered store to state below, to ensure this + // write is visible before the state update is visible. + unsafe { self.result.with_mut(|p| *p = result) }; + + self.state.store(STATE_DEREGISTERED, Ordering::Release); + + self.waker.0.take_waker() + } + + /// Marks the timer as registered (poll will return None) and sets the + /// expiration time. + /// + /// While this function is memory-safe, it should only be called from a + /// context holding both `&mut TimerEntry` and the driver lock. + fn set_expiration(&self, timestamp: u64) { + debug_assert!(timestamp < STATE_MIN_VALUE); + + // We can use relaxed ordering because we hold the driver lock and will + // fence when we release the lock. + self.state.store(timestamp, Ordering::Relaxed); + } + + /// Attempts to adjust the timer to a new timestamp. + /// + /// If the timer has already been fired, is pending firing, or the new + /// timestamp is earlier than the old timestamp, (or occasionally + /// spuriously) returns Err without changing the timer's state. In this + /// case, the timer must be deregistered and re-registered. + fn extend_expiration(&self, new_timestamp: u64) -> Result<(), ()> { + let mut prior = self.state.load(Ordering::Relaxed); + loop { + if new_timestamp < prior || prior >= STATE_MIN_VALUE { + return Err(()); + } + + match self.state.compare_exchange_weak( + prior, + new_timestamp, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + return Ok(()); + } + Err(true_prior) => { + prior = true_prior; + } + } + } + } + + /// Returns true if the state of this timer indicates that the timer might + /// be registered with the driver. This check is performed with relaxed + /// ordering, but is conservative - if it returns false, the timer is + /// definitely _not_ registered. + pub(super) fn might_be_registered(&self) -> bool { + self.state.load(Ordering::Relaxed) != u64::MAX + } +} + +/// A timer entry. +/// +/// This is the handle to a timer that is controlled by the requester of the +/// timer. As this participates in intrusive data structures, it must be pinned +/// before polling. +#[derive(Debug)] +pub(super) struct TimerEntry { + /// Arc reference to the driver. We can only free the driver after + /// deregistering everything from their respective timer wheels. + driver: Handle, + /// Shared inner structure; this is part of an intrusive linked list, and + /// therefore other references can exist to it while mutable references to + /// Entry exist. + /// + /// This is manipulated only under the inner mutex. TODO: Can we use loom + /// cells for this? + inner: StdUnsafeCell<TimerShared>, + /// Initial deadline for the timer. This is used to register on the first + /// poll, as we can't register prior to being pinned. + initial_deadline: Option<Instant>, + /// Ensure the type is !Unpin + _m: std::marker::PhantomPinned, +} + +unsafe impl Send for TimerEntry {} +unsafe impl Sync for TimerEntry {} + +/// An TimerHandle is the (non-enforced) "unique" pointer from the driver to the +/// timer entry. Generally, at most one TimerHandle exists for a timer at a time +/// (enforced by the timer state machine). +/// +/// SAFETY: An TimerHandle is essentially a raw pointer, and the usual caveats +/// of pointer safety apply. In particular, TimerHandle does not itself enforce +/// that the timer does still exist; however, normally an TimerHandle is created +/// immediately before registering the timer, and is consumed when firing the +/// timer, to help minimize mistakes. Still, because TimerHandle cannot enforce +/// memory safety, all operations are unsafe. +#[derive(Debug)] +pub(crate) struct TimerHandle { + inner: NonNull<TimerShared>, +} + +pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, TimerShared>; + +/// The shared state structure of a timer. This structure is shared between the +/// frontend (`Entry`) and driver backend. +/// +/// Note that this structure is located inside the `TimerEntry` structure. +#[derive(Debug)] +#[repr(C)] // required by `link_list::Link` impl +pub(crate) struct TimerShared { + /// Data manipulated by the driver thread itself, only. + driver_state: CachePadded<TimerSharedPadded>, + + /// Current state. This records whether the timer entry is currently under + /// the ownership of the driver, and if not, its current state (not + /// complete, fired, error, etc). + state: StateCell, + + _p: PhantomPinned, +} + +impl TimerShared { + pub(super) fn new() -> Self { + Self { + state: StateCell::default(), + driver_state: CachePadded(TimerSharedPadded::new()), + _p: PhantomPinned, + } + } + + /// Gets the cached time-of-expiration value. + pub(super) fn cached_when(&self) -> u64 { + // Cached-when is only accessed under the driver lock, so we can use relaxed + self.driver_state.0.cached_when.load(Ordering::Relaxed) + } + + /// Gets the true time-of-expiration value, and copies it into the cached + /// time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn sync_when(&self) -> u64 { + let true_when = self.true_when(); + + self.driver_state + .0 + .cached_when + .store(true_when, Ordering::Relaxed); + + true_when + } + + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + unsafe fn set_cached_when(&self, when: u64) { + self.driver_state + .0 + .cached_when + .store(when, Ordering::Relaxed); + } + + /// Returns the true time-of-expiration value, with relaxed memory ordering. + pub(super) fn true_when(&self) -> u64 { + self.state.when().expect("Timer already fired") + } + + /// Sets the true time-of-expiration value, even if it is less than the + /// current expiration or the timer is deregistered. + /// + /// SAFETY: Must only be called with the driver lock held and the entry not + /// in the timer wheel. + pub(super) unsafe fn set_expiration(&self, t: u64) { + self.state.set_expiration(t); + self.driver_state.0.cached_when.store(t, Ordering::Relaxed); + } + + /// Sets the true time-of-expiration only if it is after the current. + pub(super) fn extend_expiration(&self, t: u64) -> Result<(), ()> { + self.state.extend_expiration(t) + } + + /// Returns a TimerHandle for this timer. + pub(super) fn handle(&self) -> TimerHandle { + TimerHandle { + inner: NonNull::from(self), + } + } + + /// Returns true if the state of this timer indicates that the timer might + /// be registered with the driver. This check is performed with relaxed + /// ordering, but is conservative - if it returns false, the timer is + /// definitely _not_ registered. + pub(super) fn might_be_registered(&self) -> bool { + self.state.might_be_registered() + } +} + +/// Additional shared state between the driver and the timer which is cache +/// padded. This contains the information that the driver thread accesses most +/// frequently to minimize contention. In particular, we move it away from the +/// waker, as the waker is updated on every poll. +#[repr(C)] // required by `link_list::Link` impl +struct TimerSharedPadded { + /// A link within the doubly-linked list of timers on a particular level and + /// slot. Valid only if state is equal to Registered. + /// + /// Only accessed under the entry lock. + pointers: linked_list::Pointers<TimerShared>, + + /// The expiration time for which this entry is currently registered. + /// Generally owned by the driver, but is accessed by the entry when not + /// registered. + cached_when: AtomicU64, + + /// The true expiration time. Set by the timer future, read by the driver. + true_when: AtomicU64, +} + +impl std::fmt::Debug for TimerSharedPadded { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TimerSharedPadded") + .field("when", &self.true_when.load(Ordering::Relaxed)) + .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) + .finish() + } +} + +impl TimerSharedPadded { + fn new() -> Self { + Self { + cached_when: AtomicU64::new(0), + true_when: AtomicU64::new(0), + pointers: linked_list::Pointers::new(), + } + } +} + +unsafe impl Send for TimerShared {} +unsafe impl Sync for TimerShared {} + +unsafe impl linked_list::Link for TimerShared { + type Handle = TimerHandle; + + type Target = TimerShared; + + fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> { + handle.inner + } + + unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle { + TimerHandle { inner: ptr } + } + + unsafe fn pointers( + target: NonNull<Self::Target>, + ) -> NonNull<linked_list::Pointers<Self::Target>> { + target.cast() + } +} + +// ===== impl Entry ===== + +impl TimerEntry { + pub(crate) fn new(handle: &Handle, deadline: Instant) -> Self { + let driver = handle.clone(); + + Self { + driver, + inner: StdUnsafeCell::new(TimerShared::new()), + initial_deadline: Some(deadline), + _m: std::marker::PhantomPinned, + } + } + + fn inner(&self) -> &TimerShared { + unsafe { &*self.inner.get() } + } + + pub(crate) fn is_elapsed(&self) -> bool { + !self.inner().state.might_be_registered() && self.initial_deadline.is_none() + } + + /// Cancels and deregisters the timer. This operation is irreversible. + pub(crate) fn cancel(self: Pin<&mut Self>) { + // We need to perform an acq/rel fence with the driver thread, and the + // simplest way to do so is to grab the driver lock. + // + // Why is this necessary? We're about to release this timer's memory for + // some other non-timer use. However, we've been doing a bunch of + // relaxed (or even non-atomic) writes from the driver thread, and we'll + // be doing more from _this thread_ (as this memory is interpreted as + // something else). + // + // It is critical to ensure that, from the point of view of the driver, + // those future non-timer writes happen-after the timer is fully fired, + // and from the purpose of this thread, the driver's writes all + // happen-before we drop the timer. This in turn requires us to perform + // an acquire-release barrier in _both_ directions between the driver + // and dropping thread. + // + // The lock acquisition in clear_entry serves this purpose. All of the + // driver manipulations happen with the lock held, so we can just take + // the lock and be sure that this drop happens-after everything the + // driver did so far and happens-before everything the driver does in + // the future. While we have the lock held, we also go ahead and + // deregister the entry if necessary. + unsafe { self.driver.clear_entry(NonNull::from(self.inner())) }; + } + + pub(crate) fn reset(mut self: Pin<&mut Self>, new_time: Instant) { + unsafe { self.as_mut().get_unchecked_mut() }.initial_deadline = None; + + let tick = self.driver.time_source().deadline_to_tick(new_time); + + if self.inner().extend_expiration(tick).is_ok() { + return; + } + + unsafe { + self.driver.reregister(tick, self.inner().into()); + } + } + + pub(crate) fn poll_elapsed( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), super::Error>> { + if self.driver.is_shutdown() { + panic!("{}", crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR); + } + + if let Some(deadline) = self.initial_deadline { + self.as_mut().reset(deadline); + } + + let this = unsafe { self.get_unchecked_mut() }; + + this.inner().state.poll(cx.waker()) + } +} + +impl TimerHandle { + pub(super) unsafe fn cached_when(&self) -> u64 { + unsafe { self.inner.as_ref().cached_when() } + } + + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } + } + + pub(super) unsafe fn is_pending(&self) -> bool { + unsafe { self.inner.as_ref().state.is_pending() } + } + + /// Forcibly sets the true and cached expiration times to the given tick. + /// + /// SAFETY: The caller must ensure that the handle remains valid, the driver + /// lock is held, and that the timer is not in any wheel linked lists. + pub(super) unsafe fn set_expiration(&self, tick: u64) { + self.inner.as_ref().set_expiration(tick); + } + + /// Attempts to mark this entry as pending. If the expiration time is after + /// `not_after`, however, returns an Err with the current expiration time. + /// + /// If an `Err` is returned, the `cached_when` value will be updated to this + /// new expiration time. + /// + /// SAFETY: The caller must ensure that the handle remains valid, the driver + /// lock is held, and that the timer is not in any wheel linked lists. + /// After returning Ok, the entry must be added to the pending list. + pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + match self.inner.as_ref().state.mark_pending(not_after) { + Ok(()) => { + // mark this as being on the pending queue in cached_when + self.inner.as_ref().set_cached_when(u64::MAX); + Ok(()) + } + Err(tick) => { + self.inner.as_ref().set_cached_when(tick); + Err(tick) + } + } + } + + /// Attempts to transition to a terminal state. If the state is already a + /// terminal state, does nothing. + /// + /// Because the entry might be dropped after the state is moved to a + /// terminal state, this function consumes the handle to ensure we don't + /// access the entry afterwards. + /// + /// Returns the last-registered waker, if any. + /// + /// SAFETY: The driver lock must be held while invoking this function, and + /// the entry must not be in any wheel linked lists. + pub(super) unsafe fn fire(self, completed_state: TimerResult) -> Option<Waker> { + self.inner.as_ref().state.fire(completed_state) + } +} + +impl Drop for TimerEntry { + fn drop(&mut self) { + unsafe { Pin::new_unchecked(self) }.as_mut().cancel() + } +} + +#[cfg_attr(target_arch = "x86_64", repr(align(128)))] +#[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))] +#[derive(Debug, Default)] +struct CachePadded<T>(T); diff --git a/third_party/rust/tokio/src/time/driver/handle.rs b/third_party/rust/tokio/src/time/driver/handle.rs new file mode 100644 index 0000000000..b61c0476e1 --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/handle.rs @@ -0,0 +1,94 @@ +use crate::loom::sync::Arc; +use crate::time::driver::ClockTime; +use std::fmt; + +/// Handle to time driver instance. +#[derive(Clone)] +pub(crate) struct Handle { + time_source: ClockTime, + inner: Arc<super::Inner>, +} + +impl Handle { + /// Creates a new timer `Handle` from a shared `Inner` timer state. + pub(super) fn new(inner: Arc<super::Inner>) -> Self { + let time_source = inner.state.lock().time_source.clone(); + Handle { time_source, inner } + } + + /// Returns the time source associated with this handle. + pub(super) fn time_source(&self) -> &ClockTime { + &self.time_source + } + + /// Access the driver's inner structure. + pub(super) fn get(&self) -> &super::Inner { + &*self.inner + } + + /// Checks whether the driver has been shutdown. + pub(super) fn is_shutdown(&self) -> bool { + self.inner.is_shutdown() + } +} + +cfg_rt! { + impl Handle { + /// Tries to get a handle to the current timer. + /// + /// # Panics + /// + /// This function panics if there is no current timer set. + /// + /// It can be triggered when [`Builder::enable_time`] or + /// [`Builder::enable_all`] are not included in the builder. + /// + /// It can also panic whenever a timer is created outside of a + /// Tokio runtime. That is why `rt.block_on(sleep(...))` will panic, + /// since the function is executed outside of the runtime. + /// Whereas `rt.block_on(async {sleep(...).await})` doesn't panic. + /// And this is because wrapping the function on an async makes it lazy, + /// and so gets executed inside the runtime successfully without + /// panicking. + /// + /// [`Builder::enable_time`]: crate::runtime::Builder::enable_time + /// [`Builder::enable_all`]: crate::runtime::Builder::enable_all + pub(crate) fn current() -> Self { + crate::runtime::context::time_handle() + .expect("A Tokio 1.x context was found, but timers are disabled. Call `enable_time` on the runtime builder to enable timers.") + } + } +} + +cfg_not_rt! { + impl Handle { + /// Tries to get a handle to the current timer. + /// + /// # Panics + /// + /// This function panics if there is no current timer set. + /// + /// It can be triggered when [`Builder::enable_time`] or + /// [`Builder::enable_all`] are not included in the builder. + /// + /// It can also panic whenever a timer is created outside of a + /// Tokio runtime. That is why `rt.block_on(sleep(...))` will panic, + /// since the function is executed outside of the runtime. + /// Whereas `rt.block_on(async {sleep(...).await})` doesn't panic. + /// And this is because wrapping the function on an async makes it lazy, + /// and so gets executed inside the runtime successfully without + /// panicking. + /// + /// [`Builder::enable_time`]: crate::runtime::Builder::enable_time + /// [`Builder::enable_all`]: crate::runtime::Builder::enable_all + pub(crate) fn current() -> Self { + panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) + } + } +} + +impl fmt::Debug for Handle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Handle") + } +} diff --git a/third_party/rust/tokio/src/time/driver/mod.rs b/third_party/rust/tokio/src/time/driver/mod.rs new file mode 100644 index 0000000000..9971877479 --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/mod.rs @@ -0,0 +1,528 @@ +// Currently, rust warns when an unsafe fn contains an unsafe {} block. However, +// in the future, this will change to the reverse. For now, suppress this +// warning and generally stick with being explicit about unsafety. +#![allow(unused_unsafe)] +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Time driver. + +mod entry; +pub(self) use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared}; + +mod handle; +pub(crate) use self::handle::Handle; + +mod wheel; + +pub(super) mod sleep; + +use crate::loom::sync::atomic::{AtomicBool, Ordering}; +use crate::loom::sync::{Arc, Mutex}; +use crate::park::{Park, Unpark}; +use crate::time::error::Error; +use crate::time::{Clock, Duration, Instant}; + +use std::convert::TryInto; +use std::fmt; +use std::{num::NonZeroU64, ptr::NonNull, task::Waker}; + +/// Time implementation that drives [`Sleep`][sleep], [`Interval`][interval], and [`Timeout`][timeout]. +/// +/// A `Driver` instance tracks the state necessary for managing time and +/// notifying the [`Sleep`][sleep] instances once their deadlines are reached. +/// +/// It is expected that a single instance manages many individual [`Sleep`][sleep] +/// instances. The `Driver` implementation is thread-safe and, as such, is able +/// to handle callers from across threads. +/// +/// After creating the `Driver` instance, the caller must repeatedly call `park` +/// or `park_timeout`. The time driver will perform no work unless `park` or +/// `park_timeout` is called repeatedly. +/// +/// The driver has a resolution of one millisecond. Any unit of time that falls +/// between milliseconds are rounded up to the next millisecond. +/// +/// When an instance is dropped, any outstanding [`Sleep`][sleep] instance that has not +/// elapsed will be notified with an error. At this point, calling `poll` on the +/// [`Sleep`][sleep] instance will result in panic. +/// +/// # Implementation +/// +/// The time driver is based on the [paper by Varghese and Lauck][paper]. +/// +/// A hashed timing wheel is a vector of slots, where each slot handles a time +/// slice. As time progresses, the timer walks over the slot for the current +/// instant, and processes each entry for that slot. When the timer reaches the +/// end of the wheel, it starts again at the beginning. +/// +/// The implementation maintains six wheels arranged in a set of levels. As the +/// levels go up, the slots of the associated wheel represent larger intervals +/// of time. At each level, the wheel has 64 slots. Each slot covers a range of +/// time equal to the wheel at the lower level. At level zero, each slot +/// represents one millisecond of time. +/// +/// The wheels are: +/// +/// * Level 0: 64 x 1 millisecond slots. +/// * Level 1: 64 x 64 millisecond slots. +/// * Level 2: 64 x ~4 second slots. +/// * Level 3: 64 x ~4 minute slots. +/// * Level 4: 64 x ~4 hour slots. +/// * Level 5: 64 x ~12 day slots. +/// +/// When the timer processes entries at level zero, it will notify all the +/// `Sleep` instances as their deadlines have been reached. For all higher +/// levels, all entries will be redistributed across the wheel at the next level +/// down. Eventually, as time progresses, entries with [`Sleep`][sleep] instances will +/// either be canceled (dropped) or their associated entries will reach level +/// zero and be notified. +/// +/// [paper]: http://www.cs.columbia.edu/~nahum/w6998/papers/ton97-timing-wheels.pdf +/// [sleep]: crate::time::Sleep +/// [timeout]: crate::time::Timeout +/// [interval]: crate::time::Interval +#[derive(Debug)] +pub(crate) struct Driver<P: Park + 'static> { + /// Timing backend in use. + time_source: ClockTime, + + /// Shared state. + handle: Handle, + + /// Parker to delegate to. + park: P, + + // When `true`, a call to `park_timeout` should immediately return and time + // should not advance. One reason for this to be `true` is if the task + // passed to `Runtime::block_on` called `task::yield_now()`. + // + // While it may look racy, it only has any effect when the clock is paused + // and pausing the clock is restricted to a single-threaded runtime. + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, +} + +/// A structure which handles conversion from Instants to u64 timestamps. +#[derive(Debug, Clone)] +pub(self) struct ClockTime { + clock: super::clock::Clock, + start_time: Instant, +} + +impl ClockTime { + pub(self) fn new(clock: Clock) -> Self { + Self { + start_time: clock.now(), + clock, + } + } + + pub(self) fn deadline_to_tick(&self, t: Instant) -> u64 { + // Round up to the end of a ms + self.instant_to_tick(t + Duration::from_nanos(999_999)) + } + + pub(self) fn instant_to_tick(&self, t: Instant) -> u64 { + // round up + let dur: Duration = t + .checked_duration_since(self.start_time) + .unwrap_or_else(|| Duration::from_secs(0)); + let ms = dur.as_millis(); + + ms.try_into().unwrap_or(u64::MAX) + } + + pub(self) fn tick_to_duration(&self, t: u64) -> Duration { + Duration::from_millis(t) + } + + pub(self) fn now(&self) -> u64 { + self.instant_to_tick(self.clock.now()) + } +} + +/// Timer state shared between `Driver`, `Handle`, and `Registration`. +struct Inner { + // The state is split like this so `Handle` can access `is_shutdown` without locking the mutex + pub(super) state: Mutex<InnerState>, + + /// True if the driver is being shutdown. + pub(super) is_shutdown: AtomicBool, +} + +/// Time state shared which must be protected by a `Mutex` +struct InnerState { + /// Timing backend in use. + time_source: ClockTime, + + /// The last published timer `elapsed` value. + elapsed: u64, + + /// The earliest time at which we promise to wake up without unparking. + next_wake: Option<NonZeroU64>, + + /// Timer wheel. + wheel: wheel::Wheel, + + /// Unparker that can be used to wake the time driver. + unpark: Box<dyn Unpark>, +} + +// ===== impl Driver ===== + +impl<P> Driver<P> +where + P: Park + 'static, +{ + /// Creates a new `Driver` instance that uses `park` to block the current + /// thread and `time_source` to get the current time and convert to ticks. + /// + /// Specifying the source of time is useful when testing. + pub(crate) fn new(park: P, clock: Clock) -> Driver<P> { + let time_source = ClockTime::new(clock); + + let inner = Inner::new(time_source.clone(), Box::new(park.unpark())); + + Driver { + time_source, + handle: Handle::new(Arc::new(inner)), + park, + #[cfg(feature = "test-util")] + did_wake: Arc::new(AtomicBool::new(false)), + } + } + + /// Returns a handle to the timer. + /// + /// The `Handle` is how `Sleep` instances are created. The `Sleep` instances + /// can either be created directly or the `Handle` instance can be passed to + /// `with_default`, setting the timer as the default timer for the execution + /// context. + pub(crate) fn handle(&self) -> Handle { + self.handle.clone() + } + + fn park_internal(&mut self, limit: Option<Duration>) -> Result<(), P::Error> { + let mut lock = self.handle.get().state.lock(); + + assert!(!self.handle.is_shutdown()); + + let next_wake = lock.wheel.next_expiration_time(); + lock.next_wake = + next_wake.map(|t| NonZeroU64::new(t).unwrap_or_else(|| NonZeroU64::new(1).unwrap())); + + drop(lock); + + match next_wake { + Some(when) => { + let now = self.time_source.now(); + // Note that we effectively round up to 1ms here - this avoids + // very short-duration microsecond-resolution sleeps that the OS + // might treat as zero-length. + let mut duration = self.time_source.tick_to_duration(when.saturating_sub(now)); + + if duration > Duration::from_millis(0) { + if let Some(limit) = limit { + duration = std::cmp::min(limit, duration); + } + + self.park_timeout(duration)?; + } else { + self.park.park_timeout(Duration::from_secs(0))?; + } + } + None => { + if let Some(duration) = limit { + self.park_timeout(duration)?; + } else { + self.park.park()?; + } + } + } + + // Process pending timers after waking up + self.handle.process(); + + Ok(()) + } + + cfg_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + let clock = &self.time_source.clock; + + if clock.is_paused() { + self.park.park_timeout(Duration::from_secs(0))?; + + // If the time driver was woken, then the park completed + // before the "duration" elapsed (usually caused by a + // yield in `Runtime::block_on`). In this case, we don't + // advance the clock. + if !self.did_wake() { + // Simulate advancing time + clock.advance(duration); + } + } else { + self.park.park_timeout(duration)?; + } + + Ok(()) + } + + fn did_wake(&self) -> bool { + self.did_wake.swap(false, Ordering::SeqCst) + } + } + + cfg_not_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + self.park.park_timeout(duration) + } + } +} + +impl Handle { + /// Runs timer related logic, and returns the next wakeup time + pub(self) fn process(&self) { + let now = self.time_source().now(); + + self.process_at_time(now) + } + + pub(self) fn process_at_time(&self, mut now: u64) { + let mut waker_list: [Option<Waker>; 32] = Default::default(); + let mut waker_idx = 0; + + let mut lock = self.get().lock(); + + if now < lock.elapsed { + // Time went backwards! This normally shouldn't happen as the Rust language + // guarantees that an Instant is monotonic, but can happen when running + // Linux in a VM on a Windows host due to std incorrectly trusting the + // hardware clock to be monotonic. + // + // See <https://github.com/tokio-rs/tokio/issues/3619> for more information. + now = lock.elapsed; + } + + while let Some(entry) = lock.wheel.poll(now) { + debug_assert!(unsafe { entry.is_pending() }); + + // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. + if let Some(waker) = unsafe { entry.fire(Ok(())) } { + waker_list[waker_idx] = Some(waker); + + waker_idx += 1; + + if waker_idx == waker_list.len() { + // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped. + drop(lock); + + for waker in waker_list.iter_mut() { + waker.take().unwrap().wake(); + } + + waker_idx = 0; + + lock = self.get().lock(); + } + } + } + + // Update the elapsed cache + lock.elapsed = lock.wheel.elapsed(); + lock.next_wake = lock + .wheel + .poll_at() + .map(|t| NonZeroU64::new(t).unwrap_or_else(|| NonZeroU64::new(1).unwrap())); + + drop(lock); + + for waker in waker_list[0..waker_idx].iter_mut() { + waker.take().unwrap().wake(); + } + } + + /// Removes a registered timer from the driver. + /// + /// The timer will be moved to the cancelled state. Wakers will _not_ be + /// invoked. If the timer is already completed, this function is a no-op. + /// + /// This function always acquires the driver lock, even if the entry does + /// not appear to be registered. + /// + /// SAFETY: The timer must not be registered with some other driver, and + /// `add_entry` must not be called concurrently. + pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) { + unsafe { + let mut lock = self.get().lock(); + + if entry.as_ref().might_be_registered() { + lock.wheel.remove(entry); + } + + entry.as_ref().handle().fire(Ok(())); + } + } + + /// Removes and re-adds an entry to the driver. + /// + /// SAFETY: The timer must be either unregistered, or registered with this + /// driver. No other threads are allowed to concurrently manipulate the + /// timer at all (the current thread should hold an exclusive reference to + /// the `TimerEntry`) + pub(self) unsafe fn reregister(&self, new_tick: u64, entry: NonNull<TimerShared>) { + let waker = unsafe { + let mut lock = self.get().lock(); + + // We may have raced with a firing/deregistration, so check before + // deregistering. + if unsafe { entry.as_ref().might_be_registered() } { + lock.wheel.remove(entry); + } + + // Now that we have exclusive control of this entry, mint a handle to reinsert it. + let entry = entry.as_ref().handle(); + + if self.is_shutdown() { + unsafe { entry.fire(Err(crate::time::error::Error::shutdown())) } + } else { + entry.set_expiration(new_tick); + + // Note: We don't have to worry about racing with some other resetting + // thread, because add_entry and reregister require exclusive control of + // the timer entry. + match unsafe { lock.wheel.insert(entry) } { + Ok(when) => { + if lock + .next_wake + .map(|next_wake| when < next_wake.get()) + .unwrap_or(true) + { + lock.unpark.unpark(); + } + + None + } + Err((entry, super::error::InsertError::Elapsed)) => unsafe { + entry.fire(Ok(())) + }, + } + } + + // Must release lock before invoking waker to avoid the risk of deadlock. + }; + + // The timer was fired synchronously as a result of the reregistration. + // Wake the waker; this is needed because we might reset _after_ a poll, + // and otherwise the task won't be awoken to poll again. + if let Some(waker) = waker { + waker.wake(); + } + } +} + +impl<P> Park for Driver<P> +where + P: Park + 'static, +{ + type Unpark = TimerUnpark<P>; + type Error = P::Error; + + fn unpark(&self) -> Self::Unpark { + TimerUnpark::new(self) + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park_internal(None) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park_internal(Some(duration)) + } + + fn shutdown(&mut self) { + if self.handle.is_shutdown() { + return; + } + + self.handle.get().is_shutdown.store(true, Ordering::SeqCst); + + // Advance time forward to the end of time. + + self.handle.process_at_time(u64::MAX); + + self.park.shutdown(); + } +} + +impl<P> Drop for Driver<P> +where + P: Park + 'static, +{ + fn drop(&mut self) { + self.shutdown(); + } +} + +pub(crate) struct TimerUnpark<P: Park + 'static> { + inner: P::Unpark, + + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, +} + +impl<P: Park + 'static> TimerUnpark<P> { + fn new(driver: &Driver<P>) -> TimerUnpark<P> { + TimerUnpark { + inner: driver.park.unpark(), + + #[cfg(feature = "test-util")] + did_wake: driver.did_wake.clone(), + } + } +} + +impl<P: Park + 'static> Unpark for TimerUnpark<P> { + fn unpark(&self) { + #[cfg(feature = "test-util")] + self.did_wake.store(true, Ordering::SeqCst); + + self.inner.unpark(); + } +} + +// ===== impl Inner ===== + +impl Inner { + pub(self) fn new(time_source: ClockTime, unpark: Box<dyn Unpark>) -> Self { + Inner { + state: Mutex::new(InnerState { + time_source, + elapsed: 0, + next_wake: None, + unpark, + wheel: wheel::Wheel::new(), + }), + is_shutdown: AtomicBool::new(false), + } + } + + /// Locks the driver's inner structure + pub(super) fn lock(&self) -> crate::loom::sync::MutexGuard<'_, InnerState> { + self.state.lock() + } + + // Check whether the driver has been shutdown + pub(super) fn is_shutdown(&self) -> bool { + self.is_shutdown.load(Ordering::SeqCst) + } +} + +impl fmt::Debug for Inner { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Inner").finish() + } +} + +#[cfg(test)] +mod tests; diff --git a/third_party/rust/tokio/src/time/driver/sleep.rs b/third_party/rust/tokio/src/time/driver/sleep.rs new file mode 100644 index 0000000000..7f27ef201f --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/sleep.rs @@ -0,0 +1,438 @@ +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::time::driver::ClockTime; +use crate::time::driver::{Handle, TimerEntry}; +use crate::time::{error::Error, Duration, Instant}; +use crate::util::trace; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::panic::Location; +use std::pin::Pin; +use std::task::{self, Poll}; + +/// Waits until `deadline` is reached. +/// +/// No work is performed while awaiting on the sleep future to complete. `Sleep` +/// operates at millisecond granularity and should not be used for tasks that +/// require high-resolution timers. +/// +/// To run something regularly on a schedule, see [`interval`]. +/// +/// # Cancellation +/// +/// Canceling a sleep instance is done by dropping the returned future. No additional +/// cleanup work is required. +/// +/// # Examples +/// +/// Wait 100ms and print "100 ms have elapsed". +/// +/// ``` +/// use tokio::time::{sleep_until, Instant, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// sleep_until(Instant::now() + Duration::from_millis(100)).await; +/// println!("100 ms have elapsed"); +/// } +/// ``` +/// +/// See the documentation for the [`Sleep`] type for more examples. +/// +/// # Panics +/// +/// This function panics if there is no current timer set. +/// +/// It can be triggered when [`Builder::enable_time`] or +/// [`Builder::enable_all`] are not included in the builder. +/// +/// It can also panic whenever a timer is created outside of a +/// Tokio runtime. That is why `rt.block_on(sleep(...))` will panic, +/// since the function is executed outside of the runtime. +/// Whereas `rt.block_on(async {sleep(...).await})` doesn't panic. +/// And this is because wrapping the function on an async makes it lazy, +/// and so gets executed inside the runtime successfully without +/// panicking. +/// +/// [`Sleep`]: struct@crate::time::Sleep +/// [`interval`]: crate::time::interval() +/// [`Builder::enable_time`]: crate::runtime::Builder::enable_time +/// [`Builder::enable_all`]: crate::runtime::Builder::enable_all +// Alias for old name in 0.x +#[cfg_attr(docsrs, doc(alias = "delay_until"))] +#[track_caller] +pub fn sleep_until(deadline: Instant) -> Sleep { + return Sleep::new_timeout(deadline, trace::caller_location()); +} + +/// Waits until `duration` has elapsed. +/// +/// Equivalent to `sleep_until(Instant::now() + duration)`. An asynchronous +/// analog to `std::thread::sleep`. +/// +/// No work is performed while awaiting on the sleep future to complete. `Sleep` +/// operates at millisecond granularity and should not be used for tasks that +/// require high-resolution timers. +/// +/// To run something regularly on a schedule, see [`interval`]. +/// +/// The maximum duration for a sleep is 68719476734 milliseconds (approximately 2.2 years). +/// +/// # Cancellation +/// +/// Canceling a sleep instance is done by dropping the returned future. No additional +/// cleanup work is required. +/// +/// # Examples +/// +/// Wait 100ms and print "100 ms have elapsed". +/// +/// ``` +/// use tokio::time::{sleep, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// sleep(Duration::from_millis(100)).await; +/// println!("100 ms have elapsed"); +/// } +/// ``` +/// +/// See the documentation for the [`Sleep`] type for more examples. +/// +/// # Panics +/// +/// This function panics if there is no current timer set. +/// +/// It can be triggered when [`Builder::enable_time`] or +/// [`Builder::enable_all`] are not included in the builder. +/// +/// It can also panic whenever a timer is created outside of a +/// Tokio runtime. That is why `rt.block_on(sleep(...))` will panic, +/// since the function is executed outside of the runtime. +/// Whereas `rt.block_on(async {sleep(...).await})` doesn't panic. +/// And this is because wrapping the function on an async makes it lazy, +/// and so gets executed inside the runtime successfully without +/// panicking. +/// +/// [`Sleep`]: struct@crate::time::Sleep +/// [`interval`]: crate::time::interval() +/// [`Builder::enable_time`]: crate::runtime::Builder::enable_time +/// [`Builder::enable_all`]: crate::runtime::Builder::enable_all +// Alias for old name in 0.x +#[cfg_attr(docsrs, doc(alias = "delay_for"))] +#[cfg_attr(docsrs, doc(alias = "wait"))] +#[track_caller] +pub fn sleep(duration: Duration) -> Sleep { + let location = trace::caller_location(); + + match Instant::now().checked_add(duration) { + Some(deadline) => Sleep::new_timeout(deadline, location), + None => Sleep::new_timeout(Instant::far_future(), location), + } +} + +pin_project! { + /// Future returned by [`sleep`](sleep) and [`sleep_until`](sleep_until). + /// + /// This type does not implement the `Unpin` trait, which means that if you + /// use it with [`select!`] or by calling `poll`, you have to pin it first. + /// If you use it with `.await`, this does not apply. + /// + /// # Examples + /// + /// Wait 100ms and print "100 ms have elapsed". + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// + /// #[tokio::main] + /// async fn main() { + /// sleep(Duration::from_millis(100)).await; + /// println!("100 ms have elapsed"); + /// } + /// ``` + /// + /// Use with [`select!`]. Pinning the `Sleep` with [`tokio::pin!`] is + /// necessary when the same `Sleep` is selected on multiple times. + /// ```no_run + /// use tokio::time::{self, Duration, Instant}; + /// + /// #[tokio::main] + /// async fn main() { + /// let sleep = time::sleep(Duration::from_millis(10)); + /// tokio::pin!(sleep); + /// + /// loop { + /// tokio::select! { + /// () = &mut sleep => { + /// println!("timer elapsed"); + /// sleep.as_mut().reset(Instant::now() + Duration::from_millis(50)); + /// }, + /// } + /// } + /// } + /// ``` + /// Use in a struct with boxing. By pinning the `Sleep` with a `Box`, the + /// `HasSleep` struct implements `Unpin`, even though `Sleep` does not. + /// ``` + /// use std::future::Future; + /// use std::pin::Pin; + /// use std::task::{Context, Poll}; + /// use tokio::time::Sleep; + /// + /// struct HasSleep { + /// sleep: Pin<Box<Sleep>>, + /// } + /// + /// impl Future for HasSleep { + /// type Output = (); + /// + /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + /// self.sleep.as_mut().poll(cx) + /// } + /// } + /// ``` + /// Use in a struct with pin projection. This method avoids the `Box`, but + /// the `HasSleep` struct will not be `Unpin` as a consequence. + /// ``` + /// use std::future::Future; + /// use std::pin::Pin; + /// use std::task::{Context, Poll}; + /// use tokio::time::Sleep; + /// use pin_project_lite::pin_project; + /// + /// pin_project! { + /// struct HasSleep { + /// #[pin] + /// sleep: Sleep, + /// } + /// } + /// + /// impl Future for HasSleep { + /// type Output = (); + /// + /// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + /// self.project().sleep.poll(cx) + /// } + /// } + /// ``` + /// + /// [`select!`]: ../macro.select.html + /// [`tokio::pin!`]: ../macro.pin.html + // Alias for old name in 0.2 + #[cfg_attr(docsrs, doc(alias = "Delay"))] + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Sleep { + inner: Inner, + + // The link between the `Sleep` instance and the timer that drives it. + #[pin] + entry: TimerEntry, + } +} + +cfg_trace! { + #[derive(Debug)] + struct Inner { + deadline: Instant, + ctx: trace::AsyncOpTracingCtx, + time_source: ClockTime, + } +} + +cfg_not_trace! { + #[derive(Debug)] + struct Inner { + deadline: Instant, + } +} + +impl Sleep { + #[cfg_attr(not(all(tokio_unstable, feature = "tracing")), allow(unused_variables))] + pub(crate) fn new_timeout( + deadline: Instant, + location: Option<&'static Location<'static>>, + ) -> Sleep { + let handle = Handle::current(); + let entry = TimerEntry::new(&handle, deadline); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = { + let time_source = handle.time_source().clone(); + let deadline_tick = time_source.deadline_to_tick(deadline); + let duration = deadline_tick.checked_sub(time_source.now()).unwrap_or(0); + + let location = location.expect("should have location if tracing"); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Sleep", + kind = "timer", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + let async_op_span = resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + duration = duration, + duration.unit = "ms", + duration.op = "override", + ); + + tracing::trace_span!("runtime.resource.async_op", source = "Sleep::new_timeout") + }); + + let async_op_poll_span = + async_op_span.in_scope(|| tracing::trace_span!("runtime.resource.async_op.poll")); + + let ctx = trace::AsyncOpTracingCtx { + async_op_span, + async_op_poll_span, + resource_span, + }; + + Inner { + deadline, + ctx, + time_source, + } + }; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = Inner { deadline }; + + Sleep { inner, entry } + } + + pub(crate) fn far_future(location: Option<&'static Location<'static>>) -> Sleep { + Self::new_timeout(Instant::far_future(), location) + } + + /// Returns the instant at which the future will complete. + pub fn deadline(&self) -> Instant { + self.inner.deadline + } + + /// Returns `true` if `Sleep` has elapsed. + /// + /// A `Sleep` instance is elapsed when the requested duration has elapsed. + pub fn is_elapsed(&self) -> bool { + self.entry.is_elapsed() + } + + /// Resets the `Sleep` instance to a new deadline. + /// + /// Calling this function allows changing the instant at which the `Sleep` + /// future completes without having to create new associated state. + /// + /// This function can be called both before and after the future has + /// completed. + /// + /// To call this method, you will usually combine the call with + /// [`Pin::as_mut`], which lets you call the method without consuming the + /// `Sleep` itself. + /// + /// # Example + /// + /// ``` + /// use tokio::time::{Duration, Instant}; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let sleep = tokio::time::sleep(Duration::from_millis(10)); + /// tokio::pin!(sleep); + /// + /// sleep.as_mut().reset(Instant::now() + Duration::from_millis(20)); + /// # } + /// ``` + /// + /// See also the top-level examples. + /// + /// [`Pin::as_mut`]: fn@std::pin::Pin::as_mut + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.reset_inner(deadline) + } + + fn reset_inner(self: Pin<&mut Self>, deadline: Instant) { + let me = self.project(); + me.entry.reset(deadline); + (*me.inner).deadline = deadline; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + { + let _resource_enter = me.inner.ctx.resource_span.enter(); + me.inner.ctx.async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Sleep::reset"); + let _async_op_enter = me.inner.ctx.async_op_span.enter(); + + me.inner.ctx.async_op_poll_span = + tracing::trace_span!("runtime.resource.async_op.poll"); + + let duration = { + let now = me.inner.time_source.now(); + let deadline_tick = me.inner.time_source.deadline_to_tick(deadline); + deadline_tick.checked_sub(now).unwrap_or(0) + }; + + tracing::trace!( + target: "runtime::resource::state_update", + duration = duration, + duration.unit = "ms", + duration.op = "override", + ); + } + } + + fn poll_elapsed(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + let me = self.project(); + + // Keep track of task budget + #[cfg(all(tokio_unstable, feature = "tracing"))] + let coop = ready!(trace_poll_op!( + "poll_elapsed", + crate::coop::poll_proceed(cx), + )); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let coop = ready!(crate::coop::poll_proceed(cx)); + + let result = me.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + return trace_poll_op!("poll_elapsed", result); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + return result; + } +} + +impl Future for Sleep { + type Output = (); + + // `poll_elapsed` can return an error in two cases: + // + // - AtCapacity: this is a pathological case where far too many + // sleep instances have been scheduled. + // - Shutdown: No timer has been setup, which is a mis-use error. + // + // Both cases are extremely rare, and pretty accurately fit into + // "logic errors", so we just panic in this case. A user couldn't + // really do much better if we passed the error onwards. + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _res_span = self.inner.ctx.resource_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_span = self.inner.ctx.async_op_span.clone().entered(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let _ao_poll_span = self.inner.ctx.async_op_poll_span.clone().entered(); + match ready!(self.as_mut().poll_elapsed(cx)) { + Ok(()) => Poll::Ready(()), + Err(e) => panic!("timer error: {}", e), + } + } +} diff --git a/third_party/rust/tokio/src/time/driver/tests/mod.rs b/third_party/rust/tokio/src/time/driver/tests/mod.rs new file mode 100644 index 0000000000..3ac8c75643 --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/tests/mod.rs @@ -0,0 +1,301 @@ +use std::{task::Context, time::Duration}; + +#[cfg(not(loom))] +use futures::task::noop_waker_ref; + +use crate::loom::sync::Arc; +use crate::loom::thread; +use crate::{ + loom::sync::atomic::{AtomicBool, Ordering}, + park::Unpark, +}; + +use super::{Handle, TimerEntry}; + +struct MockUnpark {} +impl Unpark for MockUnpark { + fn unpark(&self) {} +} +impl MockUnpark { + fn mock() -> Box<dyn Unpark> { + Box::new(Self {}) + } +} + +fn block_on<T>(f: impl std::future::Future<Output = T>) -> T { + #[cfg(loom)] + return loom::future::block_on(f); + + #[cfg(not(loom))] + { + let rt = crate::runtime::Builder::new_current_thread() + .build() + .unwrap(); + rt.block_on(f) + } +} + +fn model(f: impl Fn() + Send + Sync + 'static) { + #[cfg(loom)] + loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn single_timer() { + model(|| { + let clock = crate::time::clock::Clock::new(true, false); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + }); + + thread::yield_now(); + + // This may or may not return Some (depending on how it races with the + // thread). If it does return None, however, the timer should complete + // synchronously. + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn drop_timer() { + model(|| { + let clock = crate::time::clock::Clock::new(true, false); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + }); + + thread::yield_now(); + + // advance 2s in the future. + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn change_waker() { + model(|| { + let clock = crate::time::clock::Clock::new(true, false); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + }); + + thread::yield_now(); + + // advance 2s + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn reset_future() { + model(|| { + let finished_early = Arc::new(AtomicBool::new(false)); + + let clock = crate::time::clock::Clock::new(true, false); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let handle_ = handle.clone(); + let finished_early_ = finished_early.clone(); + let start = clock.now(); + + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, start + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + entry.as_mut().reset(start + Duration::from_secs(2)); + + // shouldn't complete before 2s + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + + finished_early_.store(true, Ordering::Relaxed); + }); + + thread::yield_now(); + + // This may or may not return a wakeup time. + handle.process_at_time(time_source.instant_to_tick(start + Duration::from_millis(1500))); + + assert!(!finished_early.load(Ordering::Relaxed)); + + handle.process_at_time(time_source.instant_to_tick(start + Duration::from_millis(2500))); + + jh.join().unwrap(); + + assert!(finished_early.load(Ordering::Relaxed)); + }) +} + +#[cfg(not(loom))] +fn normal_or_miri<T>(normal: T, miri: T) -> T { + if cfg!(miri) { + miri + } else { + normal + } +} + +#[test] +#[cfg(not(loom))] +fn poll_process_levels() { + let clock = crate::time::clock::Clock::new(true, false); + clock.pause(); + + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source, MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let mut entries = vec![]; + + for i in 0..normal_or_miri(1024, 64) { + let mut entry = Box::pin(TimerEntry::new( + &handle, + clock.now() + Duration::from_millis(i), + )); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(noop_waker_ref())); + + entries.push(entry); + } + + for t in 1..normal_or_miri(1024, 64) { + handle.process_at_time(t as u64); + for (deadline, future) in entries.iter_mut().enumerate() { + let mut context = Context::from_waker(noop_waker_ref()); + if deadline <= t { + assert!(future.as_mut().poll_elapsed(&mut context).is_ready()); + } else { + assert!(future.as_mut().poll_elapsed(&mut context).is_pending()); + } + } + } +} + +#[test] +#[cfg(not(loom))] +fn poll_process_levels_targeted() { + let mut context = Context::from_waker(noop_waker_ref()); + + let clock = crate::time::clock::Clock::new(true, false); + clock.pause(); + + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source, MockUnpark::mock()); + let handle = Handle::new(Arc::new(inner)); + + let e1 = TimerEntry::new(&handle, clock.now() + Duration::from_millis(193)); + pin!(e1); + + handle.process_at_time(62); + assert!(e1.as_mut().poll_elapsed(&mut context).is_pending()); + handle.process_at_time(192); + handle.process_at_time(192); +} + +/* +#[test] +fn balanced_incr_and_decr() { + const OPS: usize = 5; + + fn incr(inner: Arc<Inner>) { + for _ in 0..OPS { + inner.increment().expect("increment should not have failed"); + thread::yield_now(); + } + } + + fn decr(inner: Arc<Inner>) { + let mut ops_performed = 0; + while ops_performed < OPS { + if inner.num(Ordering::Relaxed) > 0 { + ops_performed += 1; + inner.decrement(); + } + thread::yield_now(); + } + } + + loom::model(|| { + let unpark = Box::new(MockUnpark); + let instant = Instant::now(); + + let inner = Arc::new(Inner::new(instant, unpark)); + + let incr_inner = inner.clone(); + let decr_inner = inner.clone(); + + let incr_hndle = thread::spawn(move || incr(incr_inner)); + let decr_hndle = thread::spawn(move || decr(decr_inner)); + + incr_hndle.join().expect("should never fail"); + decr_hndle.join().expect("should never fail"); + + assert_eq!(inner.num(Ordering::SeqCst), 0); + }) +} +*/ diff --git a/third_party/rust/tokio/src/time/driver/wheel/level.rs b/third_party/rust/tokio/src/time/driver/wheel/level.rs new file mode 100644 index 0000000000..878754177b --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/wheel/level.rs @@ -0,0 +1,276 @@ +use crate::time::driver::TimerHandle; + +use crate::time::driver::{EntryList, TimerShared}; + +use std::{fmt, ptr::NonNull}; + +/// Wheel for a single level in the timer. This wheel contains 64 slots. +pub(crate) struct Level { + level: usize, + + /// Bit field tracking which slots currently contain entries. + /// + /// Using a bit field to track slots that contain entries allows avoiding a + /// scan to find entries. This field is updated when entries are added or + /// removed from a slot. + /// + /// The least-significant bit represents slot zero. + occupied: u64, + + /// Slots. We access these via the EntryInner `current_list` as well, so this needs to be an UnsafeCell. + slot: [EntryList; LEVEL_MULT], +} + +/// Indicates when a slot must be processed next. +#[derive(Debug)] +pub(crate) struct Expiration { + /// The level containing the slot. + pub(crate) level: usize, + + /// The slot index. + pub(crate) slot: usize, + + /// The instant at which the slot needs to be processed. + pub(crate) deadline: u64, +} + +/// Level multiplier. +/// +/// Being a power of 2 is very important. +const LEVEL_MULT: usize = 64; + +impl Level { + pub(crate) fn new(level: usize) -> Level { + // A value has to be Copy in order to use syntax like: + // let stack = Stack::default(); + // ... + // slots: [stack; 64], + // + // Alternatively, since Stack is Default one can + // use syntax like: + // let slots: [Stack; 64] = Default::default(); + // + // However, that is only supported for arrays of size + // 32 or fewer. So in our case we have to explicitly + // invoke the constructor for each array element. + let ctor = EntryList::default; + + Level { + level, + occupied: 0, + slot: [ + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ], + } + } + + /// Finds the slot that needs to be processed next and returns the slot and + /// `Instant` at which this slot must be processed. + pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> { + // Use the `occupied` bit field to get the index of the next slot that + // needs to be processed. + let slot = match self.next_occupied_slot(now) { + Some(slot) => slot, + None => return None, + }; + + // From the slot index, calculate the `Instant` at which it needs to be + // processed. This value *must* be in the future with respect to `now`. + + let level_range = level_range(self.level); + let slot_range = slot_range(self.level); + + // Compute the start date of the current level by masking the low bits + // of `now` (`level_range` is a power of 2). + let level_start = now & !(level_range - 1); + let mut deadline = level_start + slot as u64 * slot_range; + + if deadline <= now { + // A timer is in a slot "prior" to the current time. This can occur + // because we do not have an infinite hierarchy of timer levels, and + // eventually a timer scheduled for a very distant time might end up + // being placed in a slot that is beyond the end of all of the + // arrays. + // + // To deal with this, we first limit timers to being scheduled no + // more than MAX_DURATION ticks in the future; that is, they're at + // most one rotation of the top level away. Then, we force timers + // that logically would go into the top+1 level, to instead go into + // the top level's slots. + // + // What this means is that the top level's slots act as a + // pseudo-ring buffer, and we rotate around them indefinitely. If we + // compute a deadline before now, and it's the top level, it + // therefore means we're actually looking at a slot in the future. + debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + + deadline += level_range; + } + + debug_assert!( + deadline >= now, + "deadline={:016X}; now={:016X}; level={}; lr={:016X}, sr={:016X}, slot={}; occupied={:b}", + deadline, + now, + self.level, + level_range, + slot_range, + slot, + self.occupied + ); + + Some(Expiration { + level: self.level, + slot, + deadline, + }) + } + + fn next_occupied_slot(&self, now: u64) -> Option<usize> { + if self.occupied == 0 { + return None; + } + + // Get the slot for now using Maths + let now_slot = (now / slot_range(self.level)) as usize; + let occupied = self.occupied.rotate_right(now_slot as u32); + let zeros = occupied.trailing_zeros() as usize; + let slot = (zeros + now_slot) % 64; + + Some(slot) + } + + pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { + let slot = slot_for(item.cached_when(), self.level); + + self.slot[slot].push_front(item); + + self.occupied |= occupied_bit(slot); + } + + pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) { + let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); + + unsafe { self.slot[slot].remove(item) }; + if self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + // Unset the bit + self.occupied ^= occupied_bit(slot); + } + } + + pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { + self.occupied &= !occupied_bit(slot); + + std::mem::take(&mut self.slot[slot]) + } +} + +impl fmt::Debug for Level { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Level") + .field("occupied", &self.occupied) + .finish() + } +} + +fn occupied_bit(slot: usize) -> u64 { + 1 << slot +} + +fn slot_range(level: usize) -> u64 { + LEVEL_MULT.pow(level as u32) as u64 +} + +fn level_range(level: usize) -> u64 { + LEVEL_MULT as u64 * slot_range(level) +} + +/// Converts a duration (milliseconds) and a level to a slot position. +fn slot_for(duration: u64, level: usize) -> usize { + ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_slot_for() { + for pos in 0..64 { + assert_eq!(pos as usize, slot_for(pos, 0)); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!(pos as usize, slot_for(a as u64, level)); + } + } + } +} diff --git a/third_party/rust/tokio/src/time/driver/wheel/mod.rs b/third_party/rust/tokio/src/time/driver/wheel/mod.rs new file mode 100644 index 0000000000..f088f2cfd6 --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/wheel/mod.rs @@ -0,0 +1,359 @@ +use crate::time::driver::{TimerHandle, TimerShared}; +use crate::time::error::InsertError; + +mod level; +pub(crate) use self::level::Expiration; +use self::level::Level; + +use std::ptr::NonNull; + +use super::EntryList; + +/// Timing wheel implementation. +/// +/// This type provides the hashed timing wheel implementation that backs `Timer` +/// and `DelayQueue`. +/// +/// The structure is generic over `T: Stack`. This allows handling timeout data +/// being stored on the heap or in a slab. In order to support the latter case, +/// the slab must be passed into each function allowing the implementation to +/// lookup timer entries. +/// +/// See `Timer` documentation for some implementation notes. +#[derive(Debug)] +pub(crate) struct Wheel { + /// The number of milliseconds elapsed since the wheel started. + elapsed: u64, + + /// Timer wheel. + /// + /// Levels: + /// + /// * 1 ms slots / 64 ms range + /// * 64 ms slots / ~ 4 sec range + /// * ~ 4 sec slots / ~ 4 min range + /// * ~ 4 min slots / ~ 4 hr range + /// * ~ 4 hr slots / ~ 12 day range + /// * ~ 12 day slots / ~ 2 yr range + levels: Vec<Level>, + + /// Entries queued for firing + pending: EntryList, +} + +/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots +/// each, the timer is able to track time up to 2 years into the future with a +/// precision of 1 millisecond. +const NUM_LEVELS: usize = 6; + +/// The maximum duration of a `Sleep`. +pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; + +impl Wheel { + /// Creates a new timing wheel. + pub(crate) fn new() -> Wheel { + let levels = (0..NUM_LEVELS).map(Level::new).collect(); + + Wheel { + elapsed: 0, + levels, + pending: EntryList::new(), + } + } + + /// Returns the number of milliseconds that have elapsed since the timing + /// wheel's creation. + pub(crate) fn elapsed(&self) -> u64 { + self.elapsed + } + + /// Inserts an entry into the timing wheel. + /// + /// # Arguments + /// + /// * `item`: The item to insert into the wheel. + /// + /// # Return + /// + /// Returns `Ok` when the item is successfully inserted, `Err` otherwise. + /// + /// `Err(Elapsed)` indicates that `when` represents an instant that has + /// already passed. In this case, the caller should fire the timeout + /// immediately. + /// + /// `Err(Invalid)` indicates an invalid `when` argument as been supplied. + /// + /// # Safety + /// + /// This function registers item into an intrusive linked list. The caller + /// must ensure that `item` is pinned and will not be dropped without first + /// being deregistered. + pub(crate) unsafe fn insert( + &mut self, + item: TimerHandle, + ) -> Result<u64, (TimerHandle, InsertError)> { + let when = item.sync_when(); + + if when <= self.elapsed { + return Err((item, InsertError::Elapsed)); + } + + // Get the level at which the entry should be stored + let level = self.level_for(when); + + unsafe { + self.levels[level].add_entry(item); + } + + debug_assert!({ + self.levels[level] + .next_expiration(self.elapsed) + .map(|e| e.deadline >= self.elapsed) + .unwrap_or(true) + }); + + Ok(when) + } + + /// Removes `item` from the timing wheel. + pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) { + unsafe { + let when = item.as_ref().cached_when(); + if when == u64::MAX { + self.pending.remove(item); + } else { + debug_assert!( + self.elapsed <= when, + "elapsed={}; when={}", + self.elapsed, + when + ); + + let level = self.level_for(when); + + self.levels[level].remove_entry(item); + } + } + } + + /// Instant at which to poll. + pub(crate) fn poll_at(&self) -> Option<u64> { + self.next_expiration().map(|expiration| expiration.deadline) + } + + /// Advances the timer up to the instant represented by `now`. + pub(crate) fn poll(&mut self, now: u64) -> Option<TimerHandle> { + loop { + if let Some(handle) = self.pending.pop_back() { + return Some(handle); + } + + // under what circumstances is poll.expiration Some vs. None? + let expiration = self.next_expiration().and_then(|expiration| { + if expiration.deadline > now { + None + } else { + Some(expiration) + } + }); + + match expiration { + Some(ref expiration) if expiration.deadline > now => return None, + Some(ref expiration) => { + self.process_expiration(expiration); + + self.set_elapsed(expiration.deadline); + } + None => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + break; + } + } + } + + self.pending.pop_back() + } + + /// Returns the instant at which the next timeout expires. + fn next_expiration(&self) -> Option<Expiration> { + if !self.pending.is_empty() { + // Expire immediately as we have things pending firing + return Some(Expiration { + level: 0, + slot: 0, + deadline: self.elapsed, + }); + } + + // Check all levels + for level in 0..NUM_LEVELS { + if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { + // There cannot be any expirations at a higher level that happen + // before this one. + debug_assert!(self.no_expirations_before(level + 1, expiration.deadline)); + + return Some(expiration); + } + } + + None + } + + /// Returns the tick at which this timer wheel next needs to perform some + /// processing, or None if there are no timers registered. + pub(super) fn next_expiration_time(&self) -> Option<u64> { + self.next_expiration().map(|ex| ex.deadline) + } + + /// Used for debug assertions + fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { + let mut res = true; + + for l2 in start_level..NUM_LEVELS { + if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) { + if e2.deadline < before { + res = false; + } + } + } + + res + } + + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// queue it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). + pub(crate) fn process_expiration(&mut self, expiration: &Expiration) { + // Note that we need to take _all_ of the entries off the list before + // processing any of them. This is important because it's possible that + // those entries might need to be reinserted into the same slot. + // + // This happens only on the highest level, when an entry is inserted + // more than MAX_DURATION into the future. When this happens, we wrap + // around, and process some entries a multiple of MAX_DURATION before + // they actually need to be dropped down a level. We then reinsert them + // back into the same position; we must make sure we don't then process + // those entries again or we'll end up in an infinite loop. + let mut entries = self.take_entries(expiration); + + while let Some(item) = entries.pop_back() { + if expiration.level == 0 { + debug_assert_eq!(unsafe { item.cached_when() }, expiration.deadline); + } + + // Try to expire the entry; this is cheap (doesn't synchronize) if + // the timer is not expired, and updates cached_when. + match unsafe { item.mark_pending(expiration.deadline) } { + Ok(()) => { + // Item was expired + self.pending.push_front(item); + } + Err(expiration_tick) => { + let level = level_for(expiration.deadline, expiration_tick); + unsafe { + self.levels[level].add_entry(item); + } + } + } + } + } + + fn set_elapsed(&mut self, when: u64) { + assert!( + self.elapsed <= when, + "elapsed={:?}; when={:?}", + self.elapsed, + when + ); + + if when > self.elapsed { + self.elapsed = when; + } + } + + /// Obtains the list of entries that need processing for the given expiration. + /// + fn take_entries(&mut self, expiration: &Expiration) -> EntryList { + self.levels[expiration.level].take_slot(expiration.slot) + } + + fn level_for(&self, when: u64) -> usize { + level_for(self.elapsed, when) + } +} + +fn level_for(elapsed: u64, when: u64) -> usize { + const SLOT_MASK: u64 = (1 << 6) - 1; + + // Mask in the trailing bits ignored by the level calculation in order to cap + // the possible leading zeros + let mut masked = elapsed ^ when | SLOT_MASK; + + if masked >= MAX_DURATION { + // Fudge the timer into the top level + masked = MAX_DURATION - 1; + } + + let leading_zeros = masked.leading_zeros() as usize; + let significant = 63 - leading_zeros; + + significant / 6 +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_level_for() { + for pos in 0..64 { + assert_eq!( + 0, + level_for(0, pos), + "level_for({}) -- binary = {:b}", + pos, + pos + ); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + + if pos > level { + let a = a - 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + + if pos < 64 { + let a = a + 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({}) -- binary = {:b}", + a, + a + ); + } + } + } + } +} diff --git a/third_party/rust/tokio/src/time/driver/wheel/stack.rs b/third_party/rust/tokio/src/time/driver/wheel/stack.rs new file mode 100644 index 0000000000..80651c309e --- /dev/null +++ b/third_party/rust/tokio/src/time/driver/wheel/stack.rs @@ -0,0 +1,112 @@ +use super::{Item, OwnedItem}; +use crate::time::driver::Entry; + +use std::ptr; + +/// A doubly linked stack. +#[derive(Debug)] +pub(crate) struct Stack { + head: Option<OwnedItem>, +} + +impl Default for Stack { + fn default() -> Stack { + Stack { head: None } + } +} + +impl Stack { + pub(crate) fn is_empty(&self) -> bool { + self.head.is_none() + } + + pub(crate) fn push(&mut self, entry: OwnedItem) { + // Get a pointer to the entry to for the prev link + let ptr: *const Entry = &*entry as *const _; + + // Remove the old head entry + let old = self.head.take(); + + unsafe { + // Ensure the entry is not already in a stack. + debug_assert!((*entry.next_stack.get()).is_none()); + debug_assert!((*entry.prev_stack.get()).is_null()); + + if let Some(ref entry) = old.as_ref() { + debug_assert!({ + // The head is not already set to the entry + ptr != &***entry as *const _ + }); + + // Set the previous link on the old head + *entry.prev_stack.get() = ptr; + } + + // Set this entry's next pointer + *entry.next_stack.get() = old; + } + + // Update the head pointer + self.head = Some(entry); + } + + /// Pops an item from the stack. + pub(crate) fn pop(&mut self) -> Option<OwnedItem> { + let entry = self.head.take(); + + unsafe { + if let Some(entry) = entry.as_ref() { + self.head = (*entry.next_stack.get()).take(); + + if let Some(entry) = self.head.as_ref() { + *entry.prev_stack.get() = ptr::null(); + } + + *entry.prev_stack.get() = ptr::null(); + } + } + + entry + } + + pub(crate) fn remove(&mut self, entry: &Item) { + unsafe { + // Ensure that the entry is in fact contained by the stack + debug_assert!({ + // This walks the full linked list even if an entry is found. + let mut next = self.head.as_ref(); + let mut contains = false; + + while let Some(n) = next { + if entry as *const _ == &**n as *const _ { + debug_assert!(!contains); + contains = true; + } + + next = (*n.next_stack.get()).as_ref(); + } + + contains + }); + + // Unlink `entry` from the next node + let next = (*entry.next_stack.get()).take(); + + if let Some(next) = next.as_ref() { + (*next.prev_stack.get()) = *entry.prev_stack.get(); + } + + // Unlink `entry` from the prev node + + if let Some(prev) = (*entry.prev_stack.get()).as_ref() { + *prev.next_stack.get() = next; + } else { + // It is the head + self.head = next; + } + + // Unset the prev pointer + *entry.prev_stack.get() = ptr::null(); + } + } +} diff --git a/third_party/rust/tokio/src/time/error.rs b/third_party/rust/tokio/src/time/error.rs new file mode 100644 index 0000000000..63f0a3b0bd --- /dev/null +++ b/third_party/rust/tokio/src/time/error.rs @@ -0,0 +1,120 @@ +//! Time error types. + +use self::Kind::*; +use std::error; +use std::fmt; + +/// Errors encountered by the timer implementation. +/// +/// Currently, there are two different errors that can occur: +/// +/// * `shutdown` occurs when a timer operation is attempted, but the timer +/// instance has been dropped. In this case, the operation will never be able +/// to complete and the `shutdown` error is returned. This is a permanent +/// error, i.e., once this error is observed, timer operations will never +/// succeed in the future. +/// +/// * `at_capacity` occurs when a timer operation is attempted, but the timer +/// instance is currently handling its maximum number of outstanding sleep instances. +/// In this case, the operation is not able to be performed at the current +/// moment, and `at_capacity` is returned. This is a transient error, i.e., at +/// some point in the future, if the operation is attempted again, it might +/// succeed. Callers that observe this error should attempt to [shed load]. One +/// way to do this would be dropping the future that issued the timer operation. +/// +/// [shed load]: https://en.wikipedia.org/wiki/Load_Shedding +#[derive(Debug, Copy, Clone)] +pub struct Error(Kind); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[repr(u8)] +pub(crate) enum Kind { + Shutdown = 1, + AtCapacity = 2, + Invalid = 3, +} + +impl From<Kind> for Error { + fn from(k: Kind) -> Self { + Error(k) + } +} + +/// Errors returned by `Timeout`. +#[derive(Debug, PartialEq)] +pub struct Elapsed(()); + +#[derive(Debug)] +pub(crate) enum InsertError { + Elapsed, +} + +// ===== impl Error ===== + +impl Error { + /// Creates an error representing a shutdown timer. + pub fn shutdown() -> Error { + Error(Shutdown) + } + + /// Returns `true` if the error was caused by the timer being shutdown. + pub fn is_shutdown(&self) -> bool { + matches!(self.0, Kind::Shutdown) + } + + /// Creates an error representing a timer at capacity. + pub fn at_capacity() -> Error { + Error(AtCapacity) + } + + /// Returns `true` if the error was caused by the timer being at capacity. + pub fn is_at_capacity(&self) -> bool { + matches!(self.0, Kind::AtCapacity) + } + + /// Creates an error representing a misconfigured timer. + pub fn invalid() -> Error { + Error(Invalid) + } + + /// Returns `true` if the error was caused by the timer being misconfigured. + pub fn is_invalid(&self) -> bool { + matches!(self.0, Kind::Invalid) + } +} + +impl error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + use self::Kind::*; + let descr = match self.0 { + Shutdown => "the timer is shutdown, must be called from the context of Tokio runtime", + AtCapacity => "timer is at capacity and cannot create a new entry", + Invalid => "timer duration exceeds maximum duration", + }; + write!(fmt, "{}", descr) + } +} + +// ===== impl Elapsed ===== + +impl Elapsed { + pub(crate) fn new() -> Self { + Elapsed(()) + } +} + +impl fmt::Display for Elapsed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "deadline has elapsed".fmt(fmt) + } +} + +impl std::error::Error for Elapsed {} + +impl From<Elapsed> for std::io::Error { + fn from(_err: Elapsed) -> std::io::Error { + std::io::ErrorKind::TimedOut.into() + } +} diff --git a/third_party/rust/tokio/src/time/instant.rs b/third_party/rust/tokio/src/time/instant.rs new file mode 100644 index 0000000000..f18492930a --- /dev/null +++ b/third_party/rust/tokio/src/time/instant.rs @@ -0,0 +1,223 @@ +#![allow(clippy::trivially_copy_pass_by_ref)] + +use std::fmt; +use std::ops; +use std::time::Duration; + +/// A measurement of a monotonically nondecreasing clock. +/// Opaque and useful only with `Duration`. +/// +/// Instants are always guaranteed to be no less than any previously measured +/// instant when created, and are often useful for tasks such as measuring +/// benchmarks or timing how long an operation takes. +/// +/// Note, however, that instants are not guaranteed to be **steady**. In other +/// words, each tick of the underlying clock may not be the same length (e.g. +/// some seconds may be longer than others). An instant may jump forwards or +/// experience time dilation (slow down or speed up), but it will never go +/// backwards. +/// +/// Instants are opaque types that can only be compared to one another. There is +/// no method to get "the number of seconds" from an instant. Instead, it only +/// allows measuring the duration between two instants (or comparing two +/// instants). +/// +/// The size of an `Instant` struct may vary depending on the target operating +/// system. +/// +/// # Note +/// +/// This type wraps the inner `std` variant and is used to align the Tokio +/// clock for uses of `now()`. This can be useful for testing where you can +/// take advantage of `time::pause()` and `time::advance()`. +#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct Instant { + std: std::time::Instant, +} + +impl Instant { + /// Returns an instant corresponding to "now". + /// + /// # Examples + /// + /// ``` + /// use tokio::time::Instant; + /// + /// let now = Instant::now(); + /// ``` + pub fn now() -> Instant { + variant::now() + } + + /// Create a `tokio::time::Instant` from a `std::time::Instant`. + pub fn from_std(std: std::time::Instant) -> Instant { + Instant { std } + } + + pub(crate) fn far_future() -> Instant { + // Roughly 30 years from now. + // API does not provide a way to obtain max `Instant` + // or convert specific date in the future to instant. + // 1000 years overflows on macOS, 100 years overflows on FreeBSD. + Self::now() + Duration::from_secs(86400 * 365 * 30) + } + + /// Convert the value into a `std::time::Instant`. + pub fn into_std(self) -> std::time::Instant { + self.std + } + + /// Returns the amount of time elapsed from another instant to this one, or + /// zero duration if that instant is later than this one. + pub fn duration_since(&self, earlier: Instant) -> Duration { + self.std.saturating_duration_since(earlier.std) + } + + /// Returns the amount of time elapsed from another instant to this one, or + /// None if that instant is later than this one. + /// + /// # Examples + /// + /// ``` + /// use tokio::time::{Duration, Instant, sleep}; + /// + /// #[tokio::main] + /// async fn main() { + /// let now = Instant::now(); + /// sleep(Duration::new(1, 0)).await; + /// let new_now = Instant::now(); + /// println!("{:?}", new_now.checked_duration_since(now)); + /// println!("{:?}", now.checked_duration_since(new_now)); // None + /// } + /// ``` + pub fn checked_duration_since(&self, earlier: Instant) -> Option<Duration> { + self.std.checked_duration_since(earlier.std) + } + + /// Returns the amount of time elapsed from another instant to this one, or + /// zero duration if that instant is later than this one. + /// + /// # Examples + /// + /// ``` + /// use tokio::time::{Duration, Instant, sleep}; + /// + /// #[tokio::main] + /// async fn main() { + /// let now = Instant::now(); + /// sleep(Duration::new(1, 0)).await; + /// let new_now = Instant::now(); + /// println!("{:?}", new_now.saturating_duration_since(now)); + /// println!("{:?}", now.saturating_duration_since(new_now)); // 0ns + /// } + /// ``` + pub fn saturating_duration_since(&self, earlier: Instant) -> Duration { + self.std.saturating_duration_since(earlier.std) + } + + /// Returns the amount of time elapsed since this instant was created, + /// or zero duration if that this instant is in the future. + /// + /// # Examples + /// + /// ``` + /// use tokio::time::{Duration, Instant, sleep}; + /// + /// #[tokio::main] + /// async fn main() { + /// let instant = Instant::now(); + /// let three_secs = Duration::from_secs(3); + /// sleep(three_secs).await; + /// assert!(instant.elapsed() >= three_secs); + /// } + /// ``` + pub fn elapsed(&self) -> Duration { + Instant::now().saturating_duration_since(*self) + } + + /// Returns `Some(t)` where `t` is the time `self + duration` if `t` can be + /// represented as `Instant` (which means it's inside the bounds of the + /// underlying data structure), `None` otherwise. + pub fn checked_add(&self, duration: Duration) -> Option<Instant> { + self.std.checked_add(duration).map(Instant::from_std) + } + + /// Returns `Some(t)` where `t` is the time `self - duration` if `t` can be + /// represented as `Instant` (which means it's inside the bounds of the + /// underlying data structure), `None` otherwise. + pub fn checked_sub(&self, duration: Duration) -> Option<Instant> { + self.std.checked_sub(duration).map(Instant::from_std) + } +} + +impl From<std::time::Instant> for Instant { + fn from(time: std::time::Instant) -> Instant { + Instant::from_std(time) + } +} + +impl From<Instant> for std::time::Instant { + fn from(time: Instant) -> std::time::Instant { + time.into_std() + } +} + +impl ops::Add<Duration> for Instant { + type Output = Instant; + + fn add(self, other: Duration) -> Instant { + Instant::from_std(self.std + other) + } +} + +impl ops::AddAssign<Duration> for Instant { + fn add_assign(&mut self, rhs: Duration) { + *self = *self + rhs; + } +} + +impl ops::Sub for Instant { + type Output = Duration; + + fn sub(self, rhs: Instant) -> Duration { + self.std.saturating_duration_since(rhs.std) + } +} + +impl ops::Sub<Duration> for Instant { + type Output = Instant; + + fn sub(self, rhs: Duration) -> Instant { + Instant::from_std(self.std - rhs) + } +} + +impl ops::SubAssign<Duration> for Instant { + fn sub_assign(&mut self, rhs: Duration) { + *self = *self - rhs; + } +} + +impl fmt::Debug for Instant { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.std.fmt(fmt) + } +} + +#[cfg(not(feature = "test-util"))] +mod variant { + use super::Instant; + + pub(super) fn now() -> Instant { + Instant::from_std(std::time::Instant::now()) + } +} + +#[cfg(feature = "test-util")] +mod variant { + use super::Instant; + + pub(super) fn now() -> Instant { + crate::time::clock::now() + } +} diff --git a/third_party/rust/tokio/src/time/interval.rs b/third_party/rust/tokio/src/time/interval.rs new file mode 100644 index 0000000000..8ecb15b389 --- /dev/null +++ b/third_party/rust/tokio/src/time/interval.rs @@ -0,0 +1,531 @@ +use crate::future::poll_fn; +use crate::time::{sleep_until, Duration, Instant, Sleep}; +use crate::util::trace; + +use std::panic::Location; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{convert::TryInto, future::Future}; + +/// Creates new [`Interval`] that yields with interval of `period`. The first +/// tick completes immediately. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). +/// +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. +/// +/// This function is equivalent to +/// [`interval_at(Instant::now(), period)`](interval_at). +/// +/// # Panics +/// +/// This function panics if `period` is zero. +/// +/// # Examples +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut interval = time::interval(Duration::from_millis(10)); +/// +/// interval.tick().await; // ticks immediately +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms +/// +/// // approximately 20ms have elapsed. +/// } +/// ``` +/// +/// A simple example using `interval` to execute a task every two seconds. +/// +/// The difference between `interval` and [`sleep`] is that an [`Interval`] +/// measures the time since the last tick, which means that [`.tick().await`] +/// may wait for a shorter time than the duration specified for the interval +/// if some time has passed between calls to [`.tick().await`]. +/// +/// If the tick in the example below was replaced with [`sleep`], the task +/// would only be executed once every three seconds, and not every two +/// seconds. +/// +/// ``` +/// use tokio::time; +/// +/// async fn task_that_takes_a_second() { +/// println!("hello"); +/// time::sleep(time::Duration::from_secs(1)).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let mut interval = time::interval(time::Duration::from_secs(2)); +/// for _i in 0..5 { +/// interval.tick().await; +/// task_that_takes_a_second().await; +/// } +/// } +/// ``` +/// +/// [`sleep`]: crate::time::sleep() +/// [`.tick().await`]: Interval::tick +#[track_caller] +pub fn interval(period: Duration) -> Interval { + assert!(period > Duration::new(0, 0), "`period` must be non-zero."); + internal_interval_at(Instant::now(), period, trace::caller_location()) +} + +/// Creates new [`Interval`] that yields with interval of `period` with the +/// first tick completing at `start`. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). +/// +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. +/// +/// # Panics +/// +/// This function panics if `period` is zero. +/// +/// # Examples +/// +/// ``` +/// use tokio::time::{interval_at, Duration, Instant}; +/// +/// #[tokio::main] +/// async fn main() { +/// let start = Instant::now() + Duration::from_millis(50); +/// let mut interval = interval_at(start, Duration::from_millis(10)); +/// +/// interval.tick().await; // ticks after 50ms +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms +/// +/// // approximately 70ms have elapsed. +/// } +/// ``` +#[track_caller] +pub fn interval_at(start: Instant, period: Duration) -> Interval { + assert!(period > Duration::new(0, 0), "`period` must be non-zero."); + internal_interval_at(start, period, trace::caller_location()) +} + +#[cfg_attr(not(all(tokio_unstable, feature = "tracing")), allow(unused_variables))] +fn internal_interval_at( + start: Instant, + period: Duration, + location: Option<&'static Location<'static>>, +) -> Interval { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = location.expect("should have location if tracing"); + + tracing::trace_span!( + "runtime.resource", + concrete_type = "Interval", + kind = "timer", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ) + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let delay = resource_span.in_scope(|| Box::pin(sleep_until(start))); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let delay = Box::pin(sleep_until(start)); + + Interval { + delay, + period, + missed_tick_behavior: Default::default(), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, + } +} + +/// Defines the behavior of an [`Interval`] when it misses a tick. +/// +/// Sometimes, an [`Interval`]'s tick is missed. For example, consider the +/// following: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// # async fn task_that_takes_one_to_three_millis() {} +/// +/// #[tokio::main] +/// async fn main() { +/// // ticks every 2 milliseconds +/// let mut interval = time::interval(Duration::from_millis(2)); +/// for _ in 0..5 { +/// interval.tick().await; +/// // if this takes more than 2 milliseconds, a tick will be delayed +/// task_that_takes_one_to_three_millis().await; +/// } +/// } +/// ``` +/// +/// Generally, a tick is missed if too much time is spent without calling +/// [`Interval::tick()`]. +/// +/// By default, when a tick is missed, [`Interval`] fires ticks as quickly as it +/// can until it is "caught up" in time to where it should be. +/// `MissedTickBehavior` can be used to specify a different behavior for +/// [`Interval`] to exhibit. Each variant represents a different strategy. +/// +/// Note that because the executor cannot guarantee exact precision with timers, +/// these strategies will only apply when the delay is greater than 5 +/// milliseconds. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MissedTickBehavior { + /// Ticks as fast as possible until caught up. + /// + /// When this strategy is used, [`Interval`] schedules ticks "normally" (the + /// same as it would have if the ticks hadn't been delayed), which results + /// in it firing ticks as fast as possible until it is caught up in time to + /// where it should be. Unlike [`Delay`] and [`Skip`], the ticks yielded + /// when `Burst` is used (the [`Instant`]s that [`tick`](Interval::tick) + /// yields) aren't different than they would have been if a tick had not + /// been missed. Like [`Skip`], and unlike [`Delay`], the ticks may be + /// shortened. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work | work | work -| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration}; + /// # async fn task_that_takes_200_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// + /// task_that_takes_200_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // Since we are more than 100ms after the start of `interval`, this will + /// // also resolve immediately. + /// interval.tick().await; + /// + /// // Also resolves immediately, because it was supposed to resolve at + /// // 150ms after the start of `interval` + /// interval.tick().await; + /// + /// // Resolves immediately + /// interval.tick().await; + /// + /// // Since we have gotten to 200ms after the start of `interval`, this + /// // will resolve after 50ms + /// interval.tick().await; + /// # } + /// ``` + /// + /// This is the default behavior when [`Interval`] is created with + /// [`interval`] and [`interval_at`]. + /// + /// [`Delay`]: MissedTickBehavior::Delay + /// [`Skip`]: MissedTickBehavior::Skip + Burst, + + /// Tick at multiples of `period` from when [`tick`] was called, rather than + /// from `start`. + /// + /// When this strategy is used and [`Interval`] has missed a tick, instead + /// of scheduling ticks to fire at multiples of `period` from `start` (the + /// time when the first tick was fired), it schedules all future ticks to + /// happen at a regular `period` from the point when [`tick`] was called. + /// Unlike [`Burst`] and [`Skip`], ticks are not shortened, and they aren't + /// guaranteed to happen at a multiple of `period` from `start` any longer. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work -----| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_more_than_50_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + /// + /// task_that_takes_more_than_50_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // But this one, rather than also resolving immediately, as might happen + /// // with the `Burst` or `Skip` behaviors, will not resolve until + /// // 50ms after the call to `tick` up above. That is, in `tick`, when we + /// // recognize that we missed a tick, we schedule the next tick to happen + /// // 50ms (or whatever the `period` is) from right then, not from when + /// // were were *supposed* to tick + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Skip`]: MissedTickBehavior::Skip + /// [`tick`]: Interval::tick + Delay, + + /// Skips missed ticks and tick on the next multiple of `period` from + /// `start`. + /// + /// When this strategy is used, [`Interval`] schedules the next tick to fire + /// at the next-closest tick that is a multiple of `period` away from + /// `start` (the point where [`Interval`] first ticked). Like [`Burst`], all + /// ticks remain multiples of `period` away from `start`, but unlike + /// [`Burst`], the ticks may not be *one* multiple of `period` away from the + /// last tick. Like [`Delay`], the ticks are no longer the same as they + /// would have been if ticks had not been missed, but unlike [`Delay`], and + /// like [`Burst`], the ticks may be shortened to be less than one `period` + /// away from each other. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work ---| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_75_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + /// + /// task_that_takes_75_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // This one will resolve after 25ms, 100ms after the start of + /// // `interval`, which is the closest multiple of `period` from the start + /// // of `interval` after the call to `tick` up above. + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Delay`]: MissedTickBehavior::Delay + Skip, +} + +impl MissedTickBehavior { + /// If a tick is missed, this method is called to determine when the next tick should happen. + fn next_timeout(&self, timeout: Instant, now: Instant, period: Duration) -> Instant { + match self { + Self::Burst => timeout + period, + Self::Delay => now + period, + Self::Skip => { + now + period + - Duration::from_nanos( + ((now - timeout).as_nanos() % period.as_nanos()) + .try_into() + // This operation is practically guaranteed not to + // fail, as in order for it to fail, `period` would + // have to be longer than `now - timeout`, and both + // would have to be longer than 584 years. + // + // If it did fail, there's not a good way to pass + // the error along to the user, so we just panic. + .expect( + "too much time has elapsed since the interval was supposed to tick", + ), + ) + } + } + } +} + +impl Default for MissedTickBehavior { + /// Returns [`MissedTickBehavior::Burst`]. + /// + /// For most usecases, the [`Burst`] strategy is what is desired. + /// Additionally, to preserve backwards compatibility, the [`Burst`] + /// strategy must be the default. For these reasons, + /// [`MissedTickBehavior::Burst`] is the default for [`MissedTickBehavior`]. + /// See [`Burst`] for more details. + /// + /// [`Burst`]: MissedTickBehavior::Burst + fn default() -> Self { + Self::Burst + } +} + +/// Interval returned by [`interval`] and [`interval_at`]. +/// +/// This type allows you to wait on a sequence of instants with a certain +/// duration between each instant. Unlike calling [`sleep`] in a loop, this lets +/// you count the time spent between the calls to [`sleep`] as well. +/// +/// An `Interval` can be turned into a `Stream` with [`IntervalStream`]. +/// +/// [`IntervalStream`]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.IntervalStream.html +/// [`sleep`]: crate::time::sleep +#[derive(Debug)] +pub struct Interval { + /// Future that completes the next time the `Interval` yields a value. + delay: Pin<Box<Sleep>>, + + /// The duration between values yielded by `Interval`. + period: Duration, + + /// The strategy `Interval` should use when a tick is missed. + missed_tick_behavior: MissedTickBehavior, + + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, +} + +impl Interval { + /// Completes when the next instant in the interval has been reached. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. If `tick` is used as the branch in a `tokio::select!` and + /// another branch completes first, then no tick has been consumed. + /// + /// # Examples + /// + /// ``` + /// use tokio::time; + /// + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut interval = time::interval(Duration::from_millis(10)); + /// + /// interval.tick().await; + /// interval.tick().await; + /// interval.tick().await; + /// + /// // approximately 20ms have elapsed. + /// } + /// ``` + pub async fn tick(&mut self) -> Instant { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let instant = trace::async_op( + || poll_fn(|cx| self.poll_tick(cx)), + resource_span, + "Interval::tick", + "poll_tick", + false, + ); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let instant = poll_fn(|cx| self.poll_tick(cx)); + + instant.await + } + + /// Polls for the next instant in the interval to be reached. + /// + /// This method can return the following values: + /// + /// * `Poll::Pending` if the next instant has not yet been reached. + /// * `Poll::Ready(instant)` if the next instant has been reached. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when the instant has elapsed. Note that on multiple + /// calls to `poll_tick`, only the [`Waker`](std::task::Waker) from the + /// [`Context`] passed to the most recent call is scheduled to receive a + /// wakeup. + pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { + // Wait for the delay to be done + ready!(Pin::new(&mut self.delay).poll(cx)); + + // Get the time when we were scheduled to tick + let timeout = self.delay.deadline(); + + let now = Instant::now(); + + // If a tick was not missed, and thus we are being called before the + // next tick is due, just schedule the next tick normally, one `period` + // after `timeout` + // + // However, if a tick took excessively long and we are now behind, + // schedule the next tick according to how the user specified with + // `MissedTickBehavior` + let next = if now > timeout + Duration::from_millis(5) { + self.missed_tick_behavior + .next_timeout(timeout, now, self.period) + } else { + timeout + self.period + }; + + self.delay.as_mut().reset(next); + + // Return the time when we were scheduled to tick + Poll::Ready(timeout) + } + + /// Resets the interval to complete one period after the current time. + /// + /// This method ignores [`MissedTickBehavior`] strategy. + /// + /// # Examples + /// + /// ``` + /// use tokio::time; + /// + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut interval = time::interval(Duration::from_millis(100)); + /// + /// interval.tick().await; + /// + /// time::sleep(Duration::from_millis(50)).await; + /// interval.reset(); + /// + /// interval.tick().await; + /// interval.tick().await; + /// + /// // approximately 250ms have elapsed. + /// } + /// ``` + pub fn reset(&mut self) { + self.delay.as_mut().reset(Instant::now() + self.period); + } + + /// Returns the [`MissedTickBehavior`] strategy currently being used. + pub fn missed_tick_behavior(&self) -> MissedTickBehavior { + self.missed_tick_behavior + } + + /// Sets the [`MissedTickBehavior`] strategy that should be used. + pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) { + self.missed_tick_behavior = behavior; + } + + /// Returns the period of the interval. + pub fn period(&self) -> Duration { + self.period + } +} diff --git a/third_party/rust/tokio/src/time/mod.rs b/third_party/rust/tokio/src/time/mod.rs new file mode 100644 index 0000000000..281990ef9a --- /dev/null +++ b/third_party/rust/tokio/src/time/mod.rs @@ -0,0 +1,114 @@ +//! Utilities for tracking time. +//! +//! This module provides a number of types for executing code after a set period +//! of time. +//! +//! * [`Sleep`] is a future that does no work and completes at a specific [`Instant`] +//! in time. +//! +//! * [`Interval`] is a stream yielding a value at a fixed period. It is +//! initialized with a [`Duration`] and repeatedly yields each time the duration +//! elapses. +//! +//! * [`Timeout`]: Wraps a future or stream, setting an upper bound to the amount +//! of time it is allowed to execute. If the future or stream does not +//! complete in time, then it is canceled and an error is returned. +//! +//! These types are sufficient for handling a large number of scenarios +//! involving time. +//! +//! These types must be used from within the context of the [`Runtime`](crate::runtime::Runtime). +//! +//! # Examples +//! +//! Wait 100ms and print "100 ms have elapsed" +//! +//! ``` +//! use std::time::Duration; +//! use tokio::time::sleep; +//! +//! #[tokio::main] +//! async fn main() { +//! sleep(Duration::from_millis(100)).await; +//! println!("100 ms have elapsed"); +//! } +//! ``` +//! +//! Require that an operation takes no more than 1s. +//! +//! ``` +//! use tokio::time::{timeout, Duration}; +//! +//! async fn long_future() { +//! // do work here +//! } +//! +//! # async fn dox() { +//! let res = timeout(Duration::from_secs(1), long_future()).await; +//! +//! if res.is_err() { +//! println!("operation timed out"); +//! } +//! # } +//! ``` +//! +//! A simple example using [`interval`] to execute a task every two seconds. +//! +//! The difference between [`interval`] and [`sleep`] is that an [`interval`] +//! measures the time since the last tick, which means that `.tick().await` may +//! wait for a shorter time than the duration specified for the interval +//! if some time has passed between calls to `.tick().await`. +//! +//! If the tick in the example below was replaced with [`sleep`], the task +//! would only be executed once every three seconds, and not every two +//! seconds. +//! +//! ``` +//! use tokio::time; +//! +//! async fn task_that_takes_a_second() { +//! println!("hello"); +//! time::sleep(time::Duration::from_secs(1)).await +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let mut interval = time::interval(time::Duration::from_secs(2)); +//! for _i in 0..5 { +//! interval.tick().await; +//! task_that_takes_a_second().await; +//! } +//! } +//! ``` +//! +//! [`interval`]: crate::time::interval() + +mod clock; +pub(crate) use self::clock::Clock; +#[cfg(feature = "test-util")] +pub use clock::{advance, pause, resume}; + +pub(crate) mod driver; + +#[doc(inline)] +pub use driver::sleep::{sleep, sleep_until, Sleep}; + +pub mod error; + +mod instant; +pub use self::instant::Instant; + +mod interval; +pub use interval::{interval, interval_at, Interval, MissedTickBehavior}; + +mod timeout; +#[doc(inline)] +pub use timeout::{timeout, timeout_at, Timeout}; + +#[cfg(test)] +#[cfg(not(loom))] +mod tests; + +// Re-export for convenience +#[doc(no_inline)] +pub use std::time::Duration; diff --git a/third_party/rust/tokio/src/time/tests/mod.rs b/third_party/rust/tokio/src/time/tests/mod.rs new file mode 100644 index 0000000000..35e1060aca --- /dev/null +++ b/third_party/rust/tokio/src/time/tests/mod.rs @@ -0,0 +1,22 @@ +mod test_sleep; + +use crate::time::{self, Instant}; +use std::time::Duration; + +fn assert_send<T: Send>() {} +fn assert_sync<T: Sync>() {} + +#[test] +fn registration_is_send_and_sync() { + use crate::time::Sleep; + + assert_send::<Sleep>(); + assert_sync::<Sleep>(); +} + +#[test] +#[should_panic] +fn sleep_is_eager() { + let when = Instant::now() + Duration::from_millis(100); + let _ = time::sleep_until(when); +} diff --git a/third_party/rust/tokio/src/time/tests/test_sleep.rs b/third_party/rust/tokio/src/time/tests/test_sleep.rs new file mode 100644 index 0000000000..77ca07e319 --- /dev/null +++ b/third_party/rust/tokio/src/time/tests/test_sleep.rs @@ -0,0 +1,443 @@ +//use crate::time::driver::{Driver, Entry, Handle}; + +/* +macro_rules! poll { + ($e:expr) => { + $e.enter(|cx, e| e.poll_elapsed(cx)) + }; +} + +#[test] +fn frozen_utility_returns_correct_advanced_duration() { + let clock = Clock::new(); + clock.pause(); + let start = clock.now(); + + clock.advance(ms(10)); + assert_eq!(clock.now() - start, ms(10)); +} + +#[test] +fn immediate_sleep() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + let when = clock.now(); + let mut e = task::spawn(sleep_until(&handle, when)); + + assert_ready_ok!(poll!(e)); + + assert_ok!(driver.park_timeout(Duration::from_millis(1000))); + + // The time has not advanced. The `turn` completed immediately. + assert_eq!(clock.now() - start, ms(1000)); +} + +#[test] +fn delayed_sleep_level_0() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + for &i in &[1, 10, 60] { + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, start + ms(i))); + + // The sleep instance has not elapsed. + assert_pending!(poll!(e)); + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(i)); + + assert_ready_ok!(poll!(e)); + } +} + +#[test] +fn sub_ms_delayed_sleep() { + let (mut driver, clock, handle) = setup(); + + for _ in 0..5 { + let deadline = clock.now() + ms(1) + Duration::new(0, 1); + + let mut e = task::spawn(sleep_until(&handle, deadline)); + + assert_pending!(poll!(e)); + + assert_ok!(driver.park()); + assert_ready_ok!(poll!(e)); + + assert!(clock.now() >= deadline); + + clock.advance(Duration::new(0, 1)); + } +} + +#[test] +fn delayed_sleep_wrapping_level_0() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + assert_ok!(driver.park_timeout(ms(5))); + assert_eq!(clock.now() - start, ms(5)); + + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(60))); + + assert_pending!(poll!(e)); + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(64)); + assert_pending!(poll!(e)); + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(65)); + + assert_ready_ok!(poll!(e)); +} + +#[test] +fn timer_wrapping_with_higher_levels() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Set sleep to hit level 1 + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(64))); + assert_pending!(poll!(e1)); + + // Turn a bit + assert_ok!(driver.park_timeout(ms(5))); + + // Set timeout such that it will hit level 0, but wrap + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(60))); + assert_pending!(poll!(e2)); + + // This should result in s1 firing + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(64)); + + assert_ready_ok!(poll!(e1)); + assert_pending!(poll!(e2)); + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(65)); + + assert_ready_ok!(poll!(e1)); +} + +#[test] +fn sleep_with_deadline_in_past() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create `Sleep` that elapsed immediately. + let mut e = task::spawn(sleep_until(&handle, clock.now() - ms(100))); + + // Even though the `Sleep` expires in the past, it is not ready yet + // because the timer must observe it. + assert_ready_ok!(poll!(e)); + + // Turn the timer, it runs for the elapsed time + assert_ok!(driver.park_timeout(ms(1000))); + + // The time has not advanced. The `turn` completed immediately. + assert_eq!(clock.now() - start, ms(1000)); +} + +#[test] +fn delayed_sleep_level_1() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234))); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + // Turn the timer, this will wake up to cascade the timer down. + assert_ok!(driver.park_timeout(ms(1000))); + assert_eq!(clock.now() - start, ms(192)); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + // Turn the timer again + assert_ok!(driver.park_timeout(ms(1000))); + assert_eq!(clock.now() - start, ms(234)); + + // The sleep has elapsed. + assert_ready_ok!(poll!(e)); + + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234))); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + // Turn the timer with a smaller timeout than the cascade. + assert_ok!(driver.park_timeout(ms(100))); + assert_eq!(clock.now() - start, ms(100)); + + assert_pending!(poll!(e)); + + // Turn the timer, this will wake up to cascade the timer down. + assert_ok!(driver.park_timeout(ms(1000))); + assert_eq!(clock.now() - start, ms(192)); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + // Turn the timer again + assert_ok!(driver.park_timeout(ms(1000))); + assert_eq!(clock.now() - start, ms(234)); + + // The sleep has elapsed. + assert_ready_ok!(poll!(e)); +} + +#[test] +fn concurrently_set_two_timers_second_one_shorter() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(500))); + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(200))); + + // The sleep has not elapsed + assert_pending!(poll!(e1)); + assert_pending!(poll!(e2)); + + // Sleep until a cascade + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(192)); + + // Sleep until the second timer. + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(200)); + + // The shorter sleep fires + assert_ready_ok!(poll!(e2)); + assert_pending!(poll!(e1)); + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(448)); + + assert_pending!(poll!(e1)); + + // Turn again, this time the time will advance to the second sleep + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(500)); + + assert_ready_ok!(poll!(e1)); +} + +#[test] +fn short_sleep() { + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1))); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + // Turn the timer, but not enough time will go by. + assert_ok!(driver.park()); + + // The sleep has elapsed. + assert_ready_ok!(poll!(e)); + + // The time has advanced to the point of the sleep elapsing. + assert_eq!(clock.now() - start, ms(1)); +} + +#[test] +fn sorta_long_sleep_until() { + const MIN_5: u64 = 5 * 60 * 1000; + + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MIN_5))); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + let cascades = &[262_144, 262_144 + 9 * 4096, 262_144 + 9 * 4096 + 15 * 64]; + + for &elapsed in cascades { + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(elapsed)); + + assert_pending!(poll!(e)); + } + + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(MIN_5)); + + // The sleep has elapsed. + assert_ready_ok!(poll!(e)); +} + +#[test] +fn very_long_sleep() { + const MO_5: u64 = 5 * 30 * 24 * 60 * 60 * 1000; + + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MO_5))); + + // The sleep has not elapsed. + assert_pending!(poll!(e)); + + let cascades = &[ + 12_884_901_888, + 12_952_010_752, + 12_959_875_072, + 12_959_997_952, + ]; + + for &elapsed in cascades { + assert_ok!(driver.park()); + assert_eq!(clock.now() - start, ms(elapsed)); + + assert_pending!(poll!(e)); + } + + // Turn the timer, but not enough time will go by. + assert_ok!(driver.park()); + + // The time has advanced to the point of the sleep elapsing. + assert_eq!(clock.now() - start, ms(MO_5)); + + // The sleep has elapsed. + assert_ready_ok!(poll!(e)); +} + +#[test] +fn unpark_is_delayed() { + // A special park that will take much longer than the requested duration + struct MockPark(Clock); + + struct MockUnpark; + + impl Park for MockPark { + type Unpark = MockUnpark; + type Error = (); + + fn unpark(&self) -> Self::Unpark { + MockUnpark + } + + fn park(&mut self) -> Result<(), Self::Error> { + panic!("parking forever"); + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + assert_eq!(duration, ms(0)); + self.0.advance(ms(436)); + Ok(()) + } + + fn shutdown(&mut self) {} + } + + impl Unpark for MockUnpark { + fn unpark(&self) {} + } + + let clock = Clock::new(); + clock.pause(); + let start = clock.now(); + let mut driver = Driver::new(MockPark(clock.clone()), clock.clone()); + let handle = driver.handle(); + + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(100))); + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(101))); + let mut e3 = task::spawn(sleep_until(&handle, clock.now() + ms(200))); + + assert_pending!(poll!(e1)); + assert_pending!(poll!(e2)); + assert_pending!(poll!(e3)); + + assert_ok!(driver.park()); + + assert_eq!(clock.now() - start, ms(500)); + + assert_ready_ok!(poll!(e1)); + assert_ready_ok!(poll!(e2)); + assert_ready_ok!(poll!(e3)); +} + +#[test] +fn set_timeout_at_deadline_greater_than_max_timer() { + const YR_1: u64 = 365 * 24 * 60 * 60 * 1000; + const YR_5: u64 = 5 * YR_1; + + let (mut driver, clock, handle) = setup(); + let start = clock.now(); + + for _ in 0..5 { + assert_ok!(driver.park_timeout(ms(YR_1))); + } + + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1))); + assert_pending!(poll!(e)); + + assert_ok!(driver.park_timeout(ms(1000))); + assert_eq!(clock.now() - start, ms(YR_5) + ms(1)); + + assert_ready_ok!(poll!(e)); +} + +fn setup() -> (Driver<MockPark>, Clock, Handle) { + let clock = Clock::new(); + clock.pause(); + let driver = Driver::new(MockPark(clock.clone()), clock.clone()); + let handle = driver.handle(); + + (driver, clock, handle) +} + +fn sleep_until(handle: &Handle, when: Instant) -> Arc<Entry> { + Entry::new(&handle, when, ms(0)) +} + +struct MockPark(Clock); + +struct MockUnpark; + +impl Park for MockPark { + type Unpark = MockUnpark; + type Error = (); + + fn unpark(&self) -> Self::Unpark { + MockUnpark + } + + fn park(&mut self) -> Result<(), Self::Error> { + panic!("parking forever"); + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.0.advance(duration); + Ok(()) + } + + fn shutdown(&mut self) {} +} + +impl Unpark for MockUnpark { + fn unpark(&self) {} +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} +*/ diff --git a/third_party/rust/tokio/src/time/timeout.rs b/third_party/rust/tokio/src/time/timeout.rs new file mode 100644 index 0000000000..4a93089e8e --- /dev/null +++ b/third_party/rust/tokio/src/time/timeout.rs @@ -0,0 +1,202 @@ +//! Allows a future to execute for a maximum amount of time. +//! +//! See [`Timeout`] documentation for more details. +//! +//! [`Timeout`]: struct@Timeout + +use crate::{ + coop, + time::{error::Elapsed, sleep_until, Duration, Instant, Sleep}, + util::trace, +}; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; + +/// Requires a `Future` to complete before the specified duration has elapsed. +/// +/// If the future completes before the duration has elapsed, then the completed +/// value is returned. Otherwise, an error is returned and the future is +/// canceled. +/// +/// # Cancelation +/// +/// Cancelling a timeout is done by dropping the future. No additional cleanup +/// or other work is required. +/// +/// The original future may be obtained by calling [`Timeout::into_inner`]. This +/// consumes the `Timeout`. +/// +/// # Examples +/// +/// Create a new `Timeout` set to expire in 10 milliseconds. +/// +/// ```rust +/// use tokio::time::timeout; +/// use tokio::sync::oneshot; +/// +/// use std::time::Duration; +/// +/// # async fn dox() { +/// let (tx, rx) = oneshot::channel(); +/// # tx.send(()).unwrap(); +/// +/// // Wrap the future with a `Timeout` set to expire in 10 milliseconds. +/// if let Err(_) = timeout(Duration::from_millis(10), rx).await { +/// println!("did not receive value within 10 ms"); +/// } +/// # } +/// ``` +/// +/// # Panics +/// +/// This function panics if there is no current timer set. +/// +/// It can be triggered when [`Builder::enable_time`] or +/// [`Builder::enable_all`] are not included in the builder. +/// +/// It can also panic whenever a timer is created outside of a +/// Tokio runtime. That is why `rt.block_on(sleep(...))` will panic, +/// since the function is executed outside of the runtime. +/// Whereas `rt.block_on(async {sleep(...).await})` doesn't panic. +/// And this is because wrapping the function on an async makes it lazy, +/// and so gets executed inside the runtime successfully without +/// panicking. +/// +/// [`Builder::enable_time`]: crate::runtime::Builder::enable_time +/// [`Builder::enable_all`]: crate::runtime::Builder::enable_all +#[track_caller] +pub fn timeout<T>(duration: Duration, future: T) -> Timeout<T> +where + T: Future, +{ + let location = trace::caller_location(); + + let deadline = Instant::now().checked_add(duration); + let delay = match deadline { + Some(deadline) => Sleep::new_timeout(deadline, location), + None => Sleep::far_future(location), + }; + Timeout::new_with_delay(future, delay) +} + +/// Requires a `Future` to complete before the specified instant in time. +/// +/// If the future completes before the instant is reached, then the completed +/// value is returned. Otherwise, an error is returned. +/// +/// # Cancelation +/// +/// Cancelling a timeout is done by dropping the future. No additional cleanup +/// or other work is required. +/// +/// The original future may be obtained by calling [`Timeout::into_inner`]. This +/// consumes the `Timeout`. +/// +/// # Examples +/// +/// Create a new `Timeout` set to expire in 10 milliseconds. +/// +/// ```rust +/// use tokio::time::{Instant, timeout_at}; +/// use tokio::sync::oneshot; +/// +/// use std::time::Duration; +/// +/// # async fn dox() { +/// let (tx, rx) = oneshot::channel(); +/// # tx.send(()).unwrap(); +/// +/// // Wrap the future with a `Timeout` set to expire 10 milliseconds into the +/// // future. +/// if let Err(_) = timeout_at(Instant::now() + Duration::from_millis(10), rx).await { +/// println!("did not receive value within 10 ms"); +/// } +/// # } +/// ``` +pub fn timeout_at<T>(deadline: Instant, future: T) -> Timeout<T> +where + T: Future, +{ + let delay = sleep_until(deadline); + + Timeout { + value: future, + delay, + } +} + +pin_project! { + /// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at). + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[derive(Debug)] + pub struct Timeout<T> { + #[pin] + value: T, + #[pin] + delay: Sleep, + } +} + +impl<T> Timeout<T> { + pub(crate) fn new_with_delay(value: T, delay: Sleep) -> Timeout<T> { + Timeout { value, delay } + } + + /// Gets a reference to the underlying value in this timeout. + pub fn get_ref(&self) -> &T { + &self.value + } + + /// Gets a mutable reference to the underlying value in this timeout. + pub fn get_mut(&mut self) -> &mut T { + &mut self.value + } + + /// Consumes this timeout, returning the underlying value. + pub fn into_inner(self) -> T { + self.value + } +} + +impl<T> Future for Timeout<T> +where + T: Future, +{ + type Output = Result<T::Output, Elapsed>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + let had_budget_before = coop::has_budget_remaining(); + + // First, try polling the future + if let Poll::Ready(v) = me.value.poll(cx) { + return Poll::Ready(Ok(v)); + } + + let has_budget_now = coop::has_budget_remaining(); + + let delay = me.delay; + + let poll_delay = || -> Poll<Self::Output> { + match delay.poll(cx) { + Poll::Ready(()) => Poll::Ready(Err(Elapsed::new())), + Poll::Pending => Poll::Pending, + } + }; + + if let (true, false) = (had_budget_before, has_budget_now) { + // if it is the underlying future that exhausted the budget, we poll + // the `delay` with an unconstrained one. This prevents pathological + // cases where the underlying future always exhausts the budget and + // we never get a chance to evaluate whether the timeout was hit or + // not. + coop::with_unconstrained(poll_delay) + } else { + poll_delay() + } + } +} diff --git a/third_party/rust/tokio/src/util/atomic_cell.rs b/third_party/rust/tokio/src/util/atomic_cell.rs new file mode 100644 index 0000000000..07e37303a7 --- /dev/null +++ b/third_party/rust/tokio/src/util/atomic_cell.rs @@ -0,0 +1,51 @@ +use crate::loom::sync::atomic::AtomicPtr; + +use std::ptr; +use std::sync::atomic::Ordering::AcqRel; + +pub(crate) struct AtomicCell<T> { + data: AtomicPtr<T>, +} + +unsafe impl<T: Send> Send for AtomicCell<T> {} +unsafe impl<T: Send> Sync for AtomicCell<T> {} + +impl<T> AtomicCell<T> { + pub(crate) fn new(data: Option<Box<T>>) -> AtomicCell<T> { + AtomicCell { + data: AtomicPtr::new(to_raw(data)), + } + } + + pub(crate) fn swap(&self, val: Option<Box<T>>) -> Option<Box<T>> { + let old = self.data.swap(to_raw(val), AcqRel); + from_raw(old) + } + + pub(crate) fn set(&self, val: Box<T>) { + let _ = self.swap(Some(val)); + } + + pub(crate) fn take(&self) -> Option<Box<T>> { + self.swap(None) + } +} + +fn to_raw<T>(data: Option<Box<T>>) -> *mut T { + data.map(Box::into_raw).unwrap_or(ptr::null_mut()) +} + +fn from_raw<T>(val: *mut T) -> Option<Box<T>> { + if val.is_null() { + None + } else { + Some(unsafe { Box::from_raw(val) }) + } +} + +impl<T> Drop for AtomicCell<T> { + fn drop(&mut self) { + // Free any data still held by the cell + let _ = self.take(); + } +} diff --git a/third_party/rust/tokio/src/util/bit.rs b/third_party/rust/tokio/src/util/bit.rs new file mode 100644 index 0000000000..a43c2c2d36 --- /dev/null +++ b/third_party/rust/tokio/src/util/bit.rs @@ -0,0 +1,77 @@ +use std::fmt; + +#[derive(Clone, Copy, PartialEq)] +pub(crate) struct Pack { + mask: usize, + shift: u32, +} + +impl Pack { + /// Value is packed in the `width` least-significant bits. + pub(crate) const fn least_significant(width: u32) -> Pack { + let mask = mask_for(width); + + Pack { mask, shift: 0 } + } + + /// Value is packed in the `width` more-significant bits. + pub(crate) const fn then(&self, width: u32) -> Pack { + let shift = pointer_width() - self.mask.leading_zeros(); + let mask = mask_for(width) << shift; + + Pack { mask, shift } + } + + /// Width, in bits, dedicated to storing the value. + pub(crate) const fn width(&self) -> u32 { + pointer_width() - (self.mask >> self.shift).leading_zeros() + } + + /// Max representable value. + pub(crate) const fn max_value(&self) -> usize { + (1 << self.width()) - 1 + } + + pub(crate) fn pack(&self, value: usize, base: usize) -> usize { + assert!(value <= self.max_value()); + (base & !self.mask) | (value << self.shift) + } + + /// Packs the value with `base`, losing any bits of `value` that fit. + /// + /// If `value` is larger than the max value that can be represented by the + /// allotted width, the most significant bits are truncated. + pub(crate) fn pack_lossy(&self, value: usize, base: usize) -> usize { + self.pack(value & self.max_value(), base) + } + + pub(crate) fn unpack(&self, src: usize) -> usize { + unpack(src, self.mask, self.shift) + } +} + +impl fmt::Debug for Pack { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "Pack {{ mask: {:b}, shift: {} }}", + self.mask, self.shift + ) + } +} + +/// Returns the width of a pointer in bits. +pub(crate) const fn pointer_width() -> u32 { + std::mem::size_of::<usize>() as u32 * 8 +} + +/// Returns a `usize` with the right-most `n` bits set. +pub(crate) const fn mask_for(n: u32) -> usize { + let shift = 1usize.wrapping_shl(n - 1); + shift | (shift - 1) +} + +/// Unpacks a value using a mask & shift. +pub(crate) const fn unpack(src: usize, mask: usize, shift: u32) -> usize { + (src & mask) >> shift +} diff --git a/third_party/rust/tokio/src/util/error.rs b/third_party/rust/tokio/src/util/error.rs new file mode 100644 index 0000000000..8f252c0c91 --- /dev/null +++ b/third_party/rust/tokio/src/util/error.rs @@ -0,0 +1,17 @@ +/// Error string explaining that the Tokio context hasn't been instantiated. +pub(crate) const CONTEXT_MISSING_ERROR: &str = + "there is no reactor running, must be called from the context of a Tokio 1.x runtime"; + +// some combinations of features might not use this +#[allow(dead_code)] +/// Error string explaining that the Tokio context is shutting down and cannot drive timers. +pub(crate) const RUNTIME_SHUTTING_DOWN_ERROR: &str = + "A Tokio 1.x context was found, but it is being shutdown."; + +// some combinations of features might not use this +#[allow(dead_code)] +/// Error string explaining that the Tokio context is not available because the +/// thread-local storing it has been destroyed. This usually only happens during +/// destructors of other thread-locals. +pub(crate) const THREAD_LOCAL_DESTROYED_ERROR: &str = + "The Tokio context thread-local variable has been destroyed."; diff --git a/third_party/rust/tokio/src/util/idle_notified_set.rs b/third_party/rust/tokio/src/util/idle_notified_set.rs new file mode 100644 index 0000000000..71f3a32a85 --- /dev/null +++ b/third_party/rust/tokio/src/util/idle_notified_set.rs @@ -0,0 +1,463 @@ +//! This module defines an `IdleNotifiedSet`, which is a collection of elements. +//! Each element is intended to correspond to a task, and the collection will +//! keep track of which tasks have had their waker notified, and which have not. +//! +//! Each entry in the set holds some user-specified value. The value's type is +//! specified using the `T` parameter. It will usually be a `JoinHandle` or +//! similar. + +use std::marker::PhantomPinned; +use std::mem::ManuallyDrop; +use std::ptr::NonNull; +use std::task::{Context, Waker}; + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::linked_list::{self, Link}; +use crate::util::{waker_ref, Wake}; + +type LinkedList<T> = + linked_list::LinkedList<ListEntry<T>, <ListEntry<T> as linked_list::Link>::Target>; + +/// This is the main handle to the collection. +pub(crate) struct IdleNotifiedSet<T> { + lists: Arc<Lists<T>>, + length: usize, +} + +/// A handle to an entry that is guaranteed to be stored in the idle or notified +/// list of its `IdleNotifiedSet`. This value borrows the `IdleNotifiedSet` +/// mutably to prevent the entry from being moved to the `Neither` list, which +/// only the `IdleNotifiedSet` may do. +/// +/// The main consequence of being stored in one of the lists is that the `value` +/// field has not yet been consumed. +/// +/// Note: This entry can be moved from the idle to the notified list while this +/// object exists by waking its waker. +pub(crate) struct EntryInOneOfTheLists<'a, T> { + entry: Arc<ListEntry<T>>, + set: &'a mut IdleNotifiedSet<T>, +} + +type Lists<T> = Mutex<ListsInner<T>>; + +/// The linked lists hold strong references to the ListEntry items, and the +/// ListEntry items also hold a strong reference back to the Lists object, but +/// the destructor of the `IdleNotifiedSet` will clear the two lists, so once +/// that object is destroyed, no ref-cycles will remain. +struct ListsInner<T> { + notified: LinkedList<T>, + idle: LinkedList<T>, + /// Whenever an element in the `notified` list is woken, this waker will be + /// notified and consumed, if it exists. + waker: Option<Waker>, +} + +/// Which of the two lists in the shared Lists object is this entry stored in? +/// +/// If the value is `Idle`, then an entry's waker may move it to the notified +/// list. Otherwise, only the `IdleNotifiedSet` may move it. +/// +/// If the value is `Neither`, then it is still possible that the entry is in +/// some third external list (this happens in `drain`). +#[derive(Copy, Clone, Eq, PartialEq)] +enum List { + Notified, + Idle, + Neither, +} + +/// An entry in the list. +/// +/// # Safety +/// +/// The `my_list` field must only be accessed while holding the mutex in +/// `parent`. It is an invariant that the value of `my_list` corresponds to +/// which linked list in the `parent` holds this entry. Once this field takes +/// the value `Neither`, then it may never be modified again. +/// +/// If the value of `my_list` is `Notified` or `Idle`, then the `pointers` field +/// must only be accessed while holding the mutex. If the value of `my_list` is +/// `Neither`, then the `pointers` field may be accessed by the +/// `IdleNotifiedSet` (this happens inside `drain`). +/// +/// The `value` field is owned by the `IdleNotifiedSet` and may only be accessed +/// by the `IdleNotifiedSet`. The operation that sets the value of `my_list` to +/// `Neither` assumes ownership of the `value`, and it must either drop it or +/// move it out from this entry to prevent it from getting leaked. (Since the +/// two linked lists are emptied in the destructor of `IdleNotifiedSet`, the +/// value should not be leaked.) +/// +/// This type is `#[repr(C)]` because its `linked_list::Link` implementation +/// requires that `pointers` is the first field. +#[repr(C)] +struct ListEntry<T> { + /// The linked list pointers of the list this entry is in. + pointers: linked_list::Pointers<ListEntry<T>>, + /// Pointer to the shared `Lists` struct. + parent: Arc<Lists<T>>, + /// The value stored in this entry. + value: UnsafeCell<ManuallyDrop<T>>, + /// Used to remember which list this entry is in. + my_list: UnsafeCell<List>, + /// Required by the `linked_list::Pointers` field. + _pin: PhantomPinned, +} + +// With mutable access to the `IdleNotifiedSet`, you can get mutable access to +// the values. +unsafe impl<T: Send> Send for IdleNotifiedSet<T> {} +// With the current API we strictly speaking don't even need `T: Sync`, but we +// require it anyway to support adding &self APIs that access the values in the +// future. +unsafe impl<T: Sync> Sync for IdleNotifiedSet<T> {} + +// These impls control when it is safe to create a Waker. Since the waker does +// not allow access to the value in any way (including its destructor), it is +// not necessary for `T` to be Send or Sync. +unsafe impl<T> Send for ListEntry<T> {} +unsafe impl<T> Sync for ListEntry<T> {} + +impl<T> IdleNotifiedSet<T> { + /// Create a new IdleNotifiedSet. + pub(crate) fn new() -> Self { + let lists = Mutex::new(ListsInner { + notified: LinkedList::new(), + idle: LinkedList::new(), + waker: None, + }); + + IdleNotifiedSet { + lists: Arc::new(lists), + length: 0, + } + } + + pub(crate) fn len(&self) -> usize { + self.length + } + + pub(crate) fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Insert the given value into the `idle` list. + pub(crate) fn insert_idle(&mut self, value: T) -> EntryInOneOfTheLists<'_, T> { + self.length += 1; + + let entry = Arc::new(ListEntry { + parent: self.lists.clone(), + value: UnsafeCell::new(ManuallyDrop::new(value)), + my_list: UnsafeCell::new(List::Idle), + pointers: linked_list::Pointers::new(), + _pin: PhantomPinned, + }); + + { + let mut lock = self.lists.lock(); + lock.idle.push_front(entry.clone()); + } + + // Safety: We just put the entry in the idle list, so it is in one of the lists. + EntryInOneOfTheLists { entry, set: self } + } + + /// Pop an entry from the notified list to poll it. The entry is moved to + /// the idle list atomically. + pub(crate) fn pop_notified(&mut self, waker: &Waker) -> Option<EntryInOneOfTheLists<'_, T>> { + // We don't decrement the length because this call moves the entry to + // the idle list rather than removing it. + if self.length == 0 { + // Fast path. + return None; + } + + let mut lock = self.lists.lock(); + + let should_update_waker = match lock.waker.as_mut() { + Some(cur_waker) => !waker.will_wake(cur_waker), + None => true, + }; + if should_update_waker { + lock.waker = Some(waker.clone()); + } + + // Pop the entry, returning None if empty. + let entry = lock.notified.pop_back()?; + + lock.idle.push_front(entry.clone()); + + // Safety: We are holding the lock. + entry.my_list.with_mut(|ptr| unsafe { + *ptr = List::Idle; + }); + + drop(lock); + + // Safety: We just put the entry in the idle list, so it is in one of the lists. + Some(EntryInOneOfTheLists { entry, set: self }) + } + + /// Call a function on every element in this list. + pub(crate) fn for_each<F: FnMut(&mut T)>(&mut self, mut func: F) { + fn get_ptrs<T>(list: &mut LinkedList<T>, ptrs: &mut Vec<*mut T>) { + let mut node = list.last(); + + while let Some(entry) = node { + ptrs.push(entry.value.with_mut(|ptr| { + let ptr: *mut ManuallyDrop<T> = ptr; + let ptr: *mut T = ptr.cast(); + ptr + })); + + let prev = entry.pointers.get_prev(); + node = prev.map(|prev| unsafe { &*prev.as_ptr() }); + } + } + + // Atomically get a raw pointer to the value of every entry. + // + // Since this only locks the mutex once, it is not possible for a value + // to get moved from the idle list to the notified list during the + // operation, which would otherwise result in some value being listed + // twice. + let mut ptrs = Vec::with_capacity(self.len()); + { + let mut lock = self.lists.lock(); + + get_ptrs(&mut lock.idle, &mut ptrs); + get_ptrs(&mut lock.notified, &mut ptrs); + } + debug_assert_eq!(ptrs.len(), ptrs.capacity()); + + for ptr in ptrs { + // Safety: When we grabbed the pointers, the entries were in one of + // the two lists. This means that their value was valid at the time, + // and it must still be valid because we are the IdleNotifiedSet, + // and only we can remove an entry from the two lists. (It's + // possible that an entry is moved from one list to the other during + // this loop, but that is ok.) + func(unsafe { &mut *ptr }); + } + } + + /// Remove all entries in both lists, applying some function to each element. + /// + /// The closure is called on all elements even if it panics. Having it panic + /// twice is a double-panic, and will abort the application. + pub(crate) fn drain<F: FnMut(T)>(&mut self, func: F) { + if self.length == 0 { + // Fast path. + return; + } + self.length = 0; + + // The LinkedList is not cleared on panic, so we use a bomb to clear it. + // + // This value has the invariant that any entry in its `all_entries` list + // has `my_list` set to `Neither` and that the value has not yet been + // dropped. + struct AllEntries<T, F: FnMut(T)> { + all_entries: LinkedList<T>, + func: F, + } + + impl<T, F: FnMut(T)> AllEntries<T, F> { + fn pop_next(&mut self) -> bool { + if let Some(entry) = self.all_entries.pop_back() { + // Safety: We just took this value from the list, so we can + // destroy the value in the entry. + entry + .value + .with_mut(|ptr| unsafe { (self.func)(ManuallyDrop::take(&mut *ptr)) }); + true + } else { + false + } + } + } + + impl<T, F: FnMut(T)> Drop for AllEntries<T, F> { + fn drop(&mut self) { + while self.pop_next() {} + } + } + + let mut all_entries = AllEntries { + all_entries: LinkedList::new(), + func, + }; + + // Atomically move all entries to the new linked list in the AllEntries + // object. + { + let mut lock = self.lists.lock(); + unsafe { + // Safety: We are holding the lock and `all_entries` is a new + // LinkedList. + move_to_new_list(&mut lock.idle, &mut all_entries.all_entries); + move_to_new_list(&mut lock.notified, &mut all_entries.all_entries); + } + } + + // Keep destroying entries in the list until it is empty. + // + // If the closure panics, then the destructor of the `AllEntries` bomb + // ensures that we keep running the destructor on the remaining values. + // A second panic will abort the program. + while all_entries.pop_next() {} + } +} + +/// # Safety +/// +/// The mutex for the entries must be held, and the target list must be such +/// that setting `my_list` to `Neither` is ok. +unsafe fn move_to_new_list<T>(from: &mut LinkedList<T>, to: &mut LinkedList<T>) { + while let Some(entry) = from.pop_back() { + entry.my_list.with_mut(|ptr| { + *ptr = List::Neither; + }); + to.push_front(entry); + } +} + +impl<'a, T> EntryInOneOfTheLists<'a, T> { + /// Remove this entry from the list it is in, returning the value associated + /// with the entry. + /// + /// This consumes the value, since it is no longer guaranteed to be in a + /// list. + pub(crate) fn remove(self) -> T { + self.set.length -= 1; + + { + let mut lock = self.set.lists.lock(); + + // Safety: We are holding the lock so there is no race, and we will + // remove the entry afterwards to uphold invariants. + let old_my_list = self.entry.my_list.with_mut(|ptr| unsafe { + let old_my_list = *ptr; + *ptr = List::Neither; + old_my_list + }); + + let list = match old_my_list { + List::Idle => &mut lock.idle, + List::Notified => &mut lock.notified, + // An entry in one of the lists is in one of the lists. + List::Neither => unreachable!(), + }; + + unsafe { + // Safety: We just checked that the entry is in this particular + // list. + list.remove(ListEntry::as_raw(&self.entry)).unwrap(); + } + } + + // By setting `my_list` to `Neither`, we have taken ownership of the + // value. We return it to the caller. + // + // Safety: We have a mutable reference to the `IdleNotifiedSet` that + // owns this entry, so we can use its permission to access the value. + self.entry + .value + .with_mut(|ptr| unsafe { ManuallyDrop::take(&mut *ptr) }) + } + + /// Access the value in this entry together with a context for its waker. + pub(crate) fn with_value_and_context<F, U>(&mut self, func: F) -> U + where + F: FnOnce(&mut T, &mut Context<'_>) -> U, + T: 'static, + { + let waker = waker_ref(&self.entry); + + let mut context = Context::from_waker(&waker); + + // Safety: We have a mutable reference to the `IdleNotifiedSet` that + // owns this entry, so we can use its permission to access the value. + self.entry + .value + .with_mut(|ptr| unsafe { func(&mut *ptr, &mut context) }) + } +} + +impl<T> Drop for IdleNotifiedSet<T> { + fn drop(&mut self) { + // Clear both lists. + self.drain(drop); + + #[cfg(debug_assertions)] + if !std::thread::panicking() { + let lock = self.lists.lock(); + assert!(lock.idle.is_empty()); + assert!(lock.notified.is_empty()); + } + } +} + +impl<T: 'static> Wake for ListEntry<T> { + fn wake_by_ref(me: &Arc<Self>) { + let mut lock = me.parent.lock(); + + // Safety: We are holding the lock and we will update the lists to + // maintain invariants. + let old_my_list = me.my_list.with_mut(|ptr| unsafe { + let old_my_list = *ptr; + if old_my_list == List::Idle { + *ptr = List::Notified; + } + old_my_list + }); + + if old_my_list == List::Idle { + // We move ourself to the notified list. + let me = unsafe { + // Safety: We just checked that we are in this particular list. + lock.idle.remove(NonNull::from(&**me)).unwrap() + }; + lock.notified.push_front(me); + + if let Some(waker) = lock.waker.take() { + drop(lock); + waker.wake(); + } + } + } + + fn wake(me: Arc<Self>) { + Self::wake_by_ref(&me) + } +} + +/// # Safety +/// +/// `ListEntry` is forced to be !Unpin. +unsafe impl<T> linked_list::Link for ListEntry<T> { + type Handle = Arc<ListEntry<T>>; + type Target = ListEntry<T>; + + fn as_raw(handle: &Self::Handle) -> NonNull<ListEntry<T>> { + let ptr: *const ListEntry<T> = Arc::as_ptr(handle); + // Safety: We can't get a null pointer from `Arc::as_ptr`. + unsafe { NonNull::new_unchecked(ptr as *mut ListEntry<T>) } + } + + unsafe fn from_raw(ptr: NonNull<ListEntry<T>>) -> Arc<ListEntry<T>> { + Arc::from_raw(ptr.as_ptr()) + } + + unsafe fn pointers( + target: NonNull<ListEntry<T>>, + ) -> NonNull<linked_list::Pointers<ListEntry<T>>> { + // Safety: The pointers struct is the first field and ListEntry is + // `#[repr(C)]` so this cast is safe. + // + // We do this rather than doing a field access since `std::ptr::addr_of` + // is too new for our MSRV. + target.cast() + } +} diff --git a/third_party/rust/tokio/src/util/linked_list.rs b/third_party/rust/tokio/src/util/linked_list.rs new file mode 100644 index 0000000000..e6bdde68c7 --- /dev/null +++ b/third_party/rust/tokio/src/util/linked_list.rs @@ -0,0 +1,693 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + +//! An intrusive double linked list of data. +//! +//! The data structure supports tracking pinned nodes. Most of the data +//! structure's APIs are `unsafe` as they require the caller to ensure the +//! specified node is actually contained by the list. + +use core::cell::UnsafeCell; +use core::fmt; +use core::marker::{PhantomData, PhantomPinned}; +use core::mem::ManuallyDrop; +use core::ptr::{self, NonNull}; + +/// An intrusive linked list. +/// +/// Currently, the list is not emptied on drop. It is the caller's +/// responsibility to ensure the list is empty before dropping it. +pub(crate) struct LinkedList<L, T> { + /// Linked list head + head: Option<NonNull<T>>, + + /// Linked list tail + tail: Option<NonNull<T>>, + + /// Node type marker. + _marker: PhantomData<*const L>, +} + +unsafe impl<L: Link> Send for LinkedList<L, L::Target> where L::Target: Send {} +unsafe impl<L: Link> Sync for LinkedList<L, L::Target> where L::Target: Sync {} + +/// Defines how a type is tracked within a linked list. +/// +/// In order to support storing a single type within multiple lists, accessing +/// the list pointers is decoupled from the entry type. +/// +/// # Safety +/// +/// Implementations must guarantee that `Target` types are pinned in memory. In +/// other words, when a node is inserted, the value will not be moved as long as +/// it is stored in the list. +pub(crate) unsafe trait Link { + /// Handle to the list entry. + /// + /// This is usually a pointer-ish type. + type Handle; + + /// Node type. + type Target; + + /// Convert the handle to a raw pointer without consuming the handle. + #[allow(clippy::wrong_self_convention)] + fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target>; + + /// Convert the raw pointer to a handle + unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle; + + /// Return the pointers for a node + /// + /// # Safety + /// + /// The resulting pointer should have the same tag in the stacked-borrows + /// stack as the argument. In particular, the method may not create an + /// intermediate reference in the process of creating the resulting raw + /// pointer. + unsafe fn pointers(target: NonNull<Self::Target>) -> NonNull<Pointers<Self::Target>>; +} + +/// Previous / next pointers. +pub(crate) struct Pointers<T> { + inner: UnsafeCell<PointersInner<T>>, +} +/// We do not want the compiler to put the `noalias` attribute on mutable +/// references to this type, so the type has been made `!Unpin` with a +/// `PhantomPinned` field. +/// +/// Additionally, we never access the `prev` or `next` fields directly, as any +/// such access would implicitly involve the creation of a reference to the +/// field, which we want to avoid since the fields are not `!Unpin`, and would +/// hence be given the `noalias` attribute if we were to do such an access. +/// As an alternative to accessing the fields directly, the `Pointers` type +/// provides getters and setters for the two fields, and those are implemented +/// using raw pointer casts and offsets, which is valid since the struct is +/// #[repr(C)]. +/// +/// See this link for more information: +/// <https://github.com/rust-lang/rust/pull/82834> +#[repr(C)] +struct PointersInner<T> { + /// The previous node in the list. null if there is no previous node. + /// + /// This field is accessed through pointer manipulation, so it is not dead code. + #[allow(dead_code)] + prev: Option<NonNull<T>>, + + /// The next node in the list. null if there is no previous node. + /// + /// This field is accessed through pointer manipulation, so it is not dead code. + #[allow(dead_code)] + next: Option<NonNull<T>>, + + /// This type is !Unpin due to the heuristic from: + /// <https://github.com/rust-lang/rust/pull/82834> + _pin: PhantomPinned, +} + +unsafe impl<T: Send> Send for Pointers<T> {} +unsafe impl<T: Sync> Sync for Pointers<T> {} + +// ===== impl LinkedList ===== + +impl<L, T> LinkedList<L, T> { + /// Creates an empty linked list. + pub(crate) const fn new() -> LinkedList<L, T> { + LinkedList { + head: None, + tail: None, + _marker: PhantomData, + } + } +} + +impl<L: Link> LinkedList<L, L::Target> { + /// Adds an element first in the list. + pub(crate) fn push_front(&mut self, val: L::Handle) { + // The value should not be dropped, it is being inserted into the list + let val = ManuallyDrop::new(val); + let ptr = L::as_raw(&*val); + assert_ne!(self.head, Some(ptr)); + unsafe { + L::pointers(ptr).as_mut().set_next(self.head); + L::pointers(ptr).as_mut().set_prev(None); + + if let Some(head) = self.head { + L::pointers(head).as_mut().set_prev(Some(ptr)); + } + + self.head = Some(ptr); + + if self.tail.is_none() { + self.tail = Some(ptr); + } + } + } + + /// Removes the last element from a list and returns it, or None if it is + /// empty. + pub(crate) fn pop_back(&mut self) -> Option<L::Handle> { + unsafe { + let last = self.tail?; + self.tail = L::pointers(last).as_ref().get_prev(); + + if let Some(prev) = L::pointers(last).as_ref().get_prev() { + L::pointers(prev).as_mut().set_next(None); + } else { + self.head = None + } + + L::pointers(last).as_mut().set_prev(None); + L::pointers(last).as_mut().set_next(None); + + Some(L::from_raw(last)) + } + } + + /// Returns whether the linked list does not contain any node + pub(crate) fn is_empty(&self) -> bool { + if self.head.is_some() { + return false; + } + + assert!(self.tail.is_none()); + true + } + + /// Removes the specified node from the list + /// + /// # Safety + /// + /// The caller **must** ensure that `node` is currently contained by + /// `self` or not contained by any other list. + pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> { + if let Some(prev) = L::pointers(node).as_ref().get_prev() { + debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); + L::pointers(prev) + .as_mut() + .set_next(L::pointers(node).as_ref().get_next()); + } else { + if self.head != Some(node) { + return None; + } + + self.head = L::pointers(node).as_ref().get_next(); + } + + if let Some(next) = L::pointers(node).as_ref().get_next() { + debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); + L::pointers(next) + .as_mut() + .set_prev(L::pointers(node).as_ref().get_prev()); + } else { + // This might be the last item in the list + if self.tail != Some(node) { + return None; + } + + self.tail = L::pointers(node).as_ref().get_prev(); + } + + L::pointers(node).as_mut().set_next(None); + L::pointers(node).as_mut().set_prev(None); + + Some(L::from_raw(node)) + } +} + +impl<L: Link> fmt::Debug for LinkedList<L, L::Target> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LinkedList") + .field("head", &self.head) + .field("tail", &self.tail) + .finish() + } +} + +#[cfg(any( + feature = "fs", + feature = "rt", + all(unix, feature = "process"), + feature = "signal", + feature = "sync", +))] +impl<L: Link> LinkedList<L, L::Target> { + pub(crate) fn last(&self) -> Option<&L::Target> { + let tail = self.tail.as_ref()?; + unsafe { Some(&*tail.as_ptr()) } + } +} + +impl<L: Link> Default for LinkedList<L, L::Target> { + fn default() -> Self { + Self::new() + } +} + +// ===== impl DrainFilter ===== + +cfg_io_readiness! { + pub(crate) struct DrainFilter<'a, T: Link, F> { + list: &'a mut LinkedList<T, T::Target>, + filter: F, + curr: Option<NonNull<T::Target>>, + } + + impl<T: Link> LinkedList<T, T::Target> { + pub(crate) fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, F> + where + F: FnMut(&mut T::Target) -> bool, + { + let curr = self.head; + DrainFilter { + curr, + filter, + list: self, + } + } + } + + impl<'a, T, F> Iterator for DrainFilter<'a, T, F> + where + T: Link, + F: FnMut(&mut T::Target) -> bool, + { + type Item = T::Handle; + + fn next(&mut self) -> Option<Self::Item> { + while let Some(curr) = self.curr { + // safety: the pointer references data contained by the list + self.curr = unsafe { T::pointers(curr).as_ref() }.get_next(); + + // safety: the value is still owned by the linked list. + if (self.filter)(unsafe { &mut *curr.as_ptr() }) { + return unsafe { self.list.remove(curr) }; + } + } + + None + } + } +} + +// ===== impl Pointers ===== + +impl<T> Pointers<T> { + /// Create a new set of empty pointers + pub(crate) fn new() -> Pointers<T> { + Pointers { + inner: UnsafeCell::new(PointersInner { + prev: None, + next: None, + _pin: PhantomPinned, + }), + } + } + + pub(crate) fn get_prev(&self) -> Option<NonNull<T>> { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option<NonNull<T>>; + ptr::read(prev) + } + } + pub(crate) fn get_next(&self) -> Option<NonNull<T>> { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option<NonNull<T>>; + let next = prev.add(1); + ptr::read(next) + } + } + + fn set_prev(&mut self, value: Option<NonNull<T>>) { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *mut Option<NonNull<T>>; + ptr::write(prev, value); + } + } + fn set_next(&mut self, value: Option<NonNull<T>>) { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *mut Option<NonNull<T>>; + let next = prev.add(1); + ptr::write(next, value); + } + } +} + +impl<T> fmt::Debug for Pointers<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let prev = self.get_prev(); + let next = self.get_next(); + f.debug_struct("Pointers") + .field("prev", &prev) + .field("next", &next) + .finish() + } +} + +#[cfg(test)] +#[cfg(not(loom))] +mod tests { + use super::*; + + use std::pin::Pin; + + #[derive(Debug)] + #[repr(C)] + struct Entry { + pointers: Pointers<Entry>, + val: i32, + } + + unsafe impl<'a> Link for &'a Entry { + type Handle = Pin<&'a Entry>; + type Target = Entry; + + fn as_raw(handle: &Pin<&'_ Entry>) -> NonNull<Entry> { + NonNull::from(handle.get_ref()) + } + + unsafe fn from_raw(ptr: NonNull<Entry>) -> Pin<&'a Entry> { + Pin::new_unchecked(&*ptr.as_ptr()) + } + + unsafe fn pointers(target: NonNull<Entry>) -> NonNull<Pointers<Entry>> { + target.cast() + } + } + + fn entry(val: i32) -> Pin<Box<Entry>> { + Box::pin(Entry { + pointers: Pointers::new(), + val, + }) + } + + fn ptr(r: &Pin<Box<Entry>>) -> NonNull<Entry> { + r.as_ref().get_ref().into() + } + + fn collect_list(list: &mut LinkedList<&'_ Entry, <&'_ Entry as Link>::Target>) -> Vec<i32> { + let mut ret = vec![]; + + while let Some(entry) = list.pop_back() { + ret.push(entry.val); + } + + ret + } + + fn push_all<'a>( + list: &mut LinkedList<&'a Entry, <&'_ Entry as Link>::Target>, + entries: &[Pin<&'a Entry>], + ) { + for entry in entries.iter() { + list.push_front(*entry); + } + } + + macro_rules! assert_clean { + ($e:ident) => {{ + assert!($e.pointers.get_next().is_none()); + assert!($e.pointers.get_prev().is_none()); + }}; + } + + macro_rules! assert_ptr_eq { + ($a:expr, $b:expr) => {{ + // Deal with mapping a Pin<&mut T> -> Option<NonNull<T>> + assert_eq!(Some($a.as_ref().get_ref().into()), $b) + }}; + } + + #[test] + fn const_new() { + const _: LinkedList<&Entry, <&Entry as Link>::Target> = LinkedList::new(); + } + + #[test] + fn push_and_drain() { + let a = entry(5); + let b = entry(7); + let c = entry(31); + + let mut list = LinkedList::new(); + assert!(list.is_empty()); + + list.push_front(a.as_ref()); + assert!(!list.is_empty()); + list.push_front(b.as_ref()); + list.push_front(c.as_ref()); + + let items: Vec<i32> = collect_list(&mut list); + assert_eq!([5, 7, 31].to_vec(), items); + + assert!(list.is_empty()); + } + + #[test] + fn push_pop_push_pop() { + let a = entry(5); + let b = entry(7); + + let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); + + list.push_front(a.as_ref()); + + let entry = list.pop_back().unwrap(); + assert_eq!(5, entry.val); + assert!(list.is_empty()); + + list.push_front(b.as_ref()); + + let entry = list.pop_back().unwrap(); + assert_eq!(7, entry.val); + + assert!(list.is_empty()); + assert!(list.pop_back().is_none()); + } + + #[test] + fn remove_by_address() { + let a = entry(5); + let b = entry(7); + let c = entry(31); + + unsafe { + // Remove first + let mut list = LinkedList::new(); + + push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); + assert!(list.remove(ptr(&a)).is_some()); + assert_clean!(a); + // `a` should be no longer there and can't be removed twice + assert!(list.remove(ptr(&a)).is_none()); + assert!(!list.is_empty()); + + assert!(list.remove(ptr(&b)).is_some()); + assert_clean!(b); + // `b` should be no longer there and can't be removed twice + assert!(list.remove(ptr(&b)).is_none()); + assert!(!list.is_empty()); + + assert!(list.remove(ptr(&c)).is_some()); + assert_clean!(c); + // `b` should be no longer there and can't be removed twice + assert!(list.remove(ptr(&c)).is_none()); + assert!(list.is_empty()); + } + + unsafe { + // Remove middle + let mut list = LinkedList::new(); + + push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); + + assert!(list.remove(ptr(&a)).is_some()); + assert_clean!(a); + + assert_ptr_eq!(b, list.head); + assert_ptr_eq!(c, b.pointers.get_next()); + assert_ptr_eq!(b, c.pointers.get_prev()); + + let items = collect_list(&mut list); + assert_eq!([31, 7].to_vec(), items); + } + + unsafe { + // Remove middle + let mut list = LinkedList::new(); + + push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); + + assert!(list.remove(ptr(&b)).is_some()); + assert_clean!(b); + + assert_ptr_eq!(c, a.pointers.get_next()); + assert_ptr_eq!(a, c.pointers.get_prev()); + + let items = collect_list(&mut list); + assert_eq!([31, 5].to_vec(), items); + } + + unsafe { + // Remove last + // Remove middle + let mut list = LinkedList::new(); + + push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]); + + assert!(list.remove(ptr(&c)).is_some()); + assert_clean!(c); + + assert!(b.pointers.get_next().is_none()); + assert_ptr_eq!(b, list.tail); + + let items = collect_list(&mut list); + assert_eq!([7, 5].to_vec(), items); + } + + unsafe { + // Remove first of two + let mut list = LinkedList::new(); + + push_all(&mut list, &[b.as_ref(), a.as_ref()]); + + assert!(list.remove(ptr(&a)).is_some()); + + assert_clean!(a); + + // a should be no longer there and can't be removed twice + assert!(list.remove(ptr(&a)).is_none()); + + assert_ptr_eq!(b, list.head); + assert_ptr_eq!(b, list.tail); + + assert!(b.pointers.get_next().is_none()); + assert!(b.pointers.get_prev().is_none()); + + let items = collect_list(&mut list); + assert_eq!([7].to_vec(), items); + } + + unsafe { + // Remove last of two + let mut list = LinkedList::new(); + + push_all(&mut list, &[b.as_ref(), a.as_ref()]); + + assert!(list.remove(ptr(&b)).is_some()); + + assert_clean!(b); + + assert_ptr_eq!(a, list.head); + assert_ptr_eq!(a, list.tail); + + assert!(a.pointers.get_next().is_none()); + assert!(a.pointers.get_prev().is_none()); + + let items = collect_list(&mut list); + assert_eq!([5].to_vec(), items); + } + + unsafe { + // Remove last item + let mut list = LinkedList::new(); + + push_all(&mut list, &[a.as_ref()]); + + assert!(list.remove(ptr(&a)).is_some()); + assert_clean!(a); + + assert!(list.head.is_none()); + assert!(list.tail.is_none()); + let items = collect_list(&mut list); + assert!(items.is_empty()); + } + + unsafe { + // Remove missing + let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); + + list.push_front(b.as_ref()); + list.push_front(a.as_ref()); + + assert!(list.remove(ptr(&c)).is_none()); + } + } + + #[cfg(not(target_arch = "wasm32"))] + proptest::proptest! { + #[test] + fn fuzz_linked_list(ops: Vec<usize>) { + run_fuzz(ops); + } + } + + fn run_fuzz(ops: Vec<usize>) { + use std::collections::VecDeque; + + #[derive(Debug)] + enum Op { + Push, + Pop, + Remove(usize), + } + + let ops = ops + .iter() + .map(|i| match i % 3 { + 0 => Op::Push, + 1 => Op::Pop, + 2 => Op::Remove(i / 3), + _ => unreachable!(), + }) + .collect::<Vec<_>>(); + + let mut ll = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); + let mut reference = VecDeque::new(); + + let entries: Vec<_> = (0..ops.len()).map(|i| entry(i as i32)).collect(); + + for (i, op) in ops.iter().enumerate() { + match op { + Op::Push => { + reference.push_front(i as i32); + assert_eq!(entries[i].val, i as i32); + + ll.push_front(entries[i].as_ref()); + } + Op::Pop => { + if reference.is_empty() { + assert!(ll.is_empty()); + continue; + } + + let v = reference.pop_back(); + assert_eq!(v, ll.pop_back().map(|v| v.val)); + } + Op::Remove(n) => { + if reference.is_empty() { + assert!(ll.is_empty()); + continue; + } + + let idx = n % reference.len(); + let expect = reference.remove(idx).unwrap(); + + unsafe { + let entry = ll.remove(ptr(&entries[expect as usize])).unwrap(); + assert_eq!(expect, entry.val); + } + } + } + } + } +} diff --git a/third_party/rust/tokio/src/util/mod.rs b/third_party/rust/tokio/src/util/mod.rs new file mode 100644 index 0000000000..618f554380 --- /dev/null +++ b/third_party/rust/tokio/src/util/mod.rs @@ -0,0 +1,83 @@ +cfg_io_driver! { + pub(crate) mod bit; + pub(crate) mod slab; +} + +#[cfg(feature = "rt")] +pub(crate) mod atomic_cell; + +#[cfg(any( + // io driver uses `WakeList` directly + feature = "net", + feature = "process", + // `sync` enables `Notify` and `batch_semaphore`, which require `WakeList`. + feature = "sync", + // `fs` uses `batch_semaphore`, which requires `WakeList`. + feature = "fs", + // rt and signal use `Notify`, which requires `WakeList`. + feature = "rt", + feature = "signal", +))] +mod wake_list; +#[cfg(any( + feature = "net", + feature = "process", + feature = "sync", + feature = "fs", + feature = "rt", + feature = "signal", +))] +pub(crate) use wake_list::WakeList; + +#[cfg(any( + feature = "fs", + feature = "net", + feature = "process", + feature = "rt", + feature = "sync", + feature = "signal", + feature = "time", +))] +pub(crate) mod linked_list; + +#[cfg(any(feature = "rt-multi-thread", feature = "macros"))] +mod rand; + +cfg_rt! { + cfg_unstable! { + mod idle_notified_set; + pub(crate) use idle_notified_set::IdleNotifiedSet; + } + + mod wake; + pub(crate) use wake::WakerRef; + pub(crate) use wake::{waker_ref, Wake}; + + mod sync_wrapper; + pub(crate) use sync_wrapper::SyncWrapper; + + mod vec_deque_cell; + pub(crate) use vec_deque_cell::VecDequeCell; +} + +cfg_rt_multi_thread! { + pub(crate) use self::rand::FastRand; + + mod try_lock; + pub(crate) use try_lock::TryLock; +} + +pub(crate) mod trace; + +#[cfg(any(feature = "macros"))] +#[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] +pub use self::rand::thread_rng_n; + +#[cfg(any( + feature = "rt", + feature = "time", + feature = "net", + feature = "process", + all(unix, feature = "signal") +))] +pub(crate) mod error; diff --git a/third_party/rust/tokio/src/util/pad.rs b/third_party/rust/tokio/src/util/pad.rs new file mode 100644 index 0000000000..bf0913ca85 --- /dev/null +++ b/third_party/rust/tokio/src/util/pad.rs @@ -0,0 +1,52 @@ +use core::fmt; +use core::ops::{Deref, DerefMut}; + +#[derive(Clone, Copy, Default, Hash, PartialEq, Eq)] +// Starting from Intel's Sandy Bridge, spatial prefetcher is now pulling pairs of 64-byte cache +// lines at a time, so we have to align to 128 bytes rather than 64. +// +// Sources: +// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf +// - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107 +#[cfg_attr(target_arch = "x86_64", repr(align(128)))] +#[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))] +pub(crate) struct CachePadded<T> { + value: T, +} + +unsafe impl<T: Send> Send for CachePadded<T> {} +unsafe impl<T: Sync> Sync for CachePadded<T> {} + +impl<T> CachePadded<T> { + pub(crate) fn new(t: T) -> CachePadded<T> { + CachePadded::<T> { value: t } + } +} + +impl<T> Deref for CachePadded<T> { + type Target = T; + + fn deref(&self) -> &T { + &self.value + } +} + +impl<T> DerefMut for CachePadded<T> { + fn deref_mut(&mut self) -> &mut T { + &mut self.value + } +} + +impl<T: fmt::Debug> fmt::Debug for CachePadded<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CachePadded") + .field("value", &self.value) + .finish() + } +} + +impl<T> From<T> for CachePadded<T> { + fn from(t: T) -> Self { + CachePadded::new(t) + } +} diff --git a/third_party/rust/tokio/src/util/rand.rs b/third_party/rust/tokio/src/util/rand.rs new file mode 100644 index 0000000000..6b19c8be95 --- /dev/null +++ b/third_party/rust/tokio/src/util/rand.rs @@ -0,0 +1,64 @@ +use std::cell::Cell; + +/// Fast random number generate. +/// +/// Implement xorshift64+: 2 32-bit xorshift sequences added together. +/// Shift triplet `[17,7,16]` was calculated as indicated in Marsaglia's +/// Xorshift paper: <https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf> +/// This generator passes the SmallCrush suite, part of TestU01 framework: +/// <http://simul.iro.umontreal.ca/testu01/tu01.html> +#[derive(Debug)] +pub(crate) struct FastRand { + one: Cell<u32>, + two: Cell<u32>, +} + +impl FastRand { + /// Initializes a new, thread-local, fast random number generator. + pub(crate) fn new(seed: u64) -> FastRand { + let one = (seed >> 32) as u32; + let mut two = seed as u32; + + if two == 0 { + // This value cannot be zero + two = 1; + } + + FastRand { + one: Cell::new(one), + two: Cell::new(two), + } + } + + pub(crate) fn fastrand_n(&self, n: u32) -> u32 { + // This is similar to fastrand() % n, but faster. + // See https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + let mul = (self.fastrand() as u64).wrapping_mul(n as u64); + (mul >> 32) as u32 + } + + fn fastrand(&self) -> u32 { + let mut s1 = self.one.get(); + let s0 = self.two.get(); + + s1 ^= s1 << 17; + s1 = s1 ^ s0 ^ s1 >> 7 ^ s0 >> 16; + + self.one.set(s0); + self.two.set(s1); + + s0.wrapping_add(s1) + } +} + +// Used by the select macro and `StreamMap` +#[cfg(any(feature = "macros"))] +#[doc(hidden)] +#[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] +pub fn thread_rng_n(n: u32) -> u32 { + thread_local! { + static THREAD_RNG: FastRand = FastRand::new(crate::loom::rand::seed()); + } + + THREAD_RNG.with(|rng| rng.fastrand_n(n)) +} diff --git a/third_party/rust/tokio/src/util/slab.rs b/third_party/rust/tokio/src/util/slab.rs new file mode 100644 index 0000000000..214fa08dc8 --- /dev/null +++ b/third_party/rust/tokio/src/util/slab.rs @@ -0,0 +1,855 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicBool, AtomicUsize}; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::bit; +use std::fmt; +use std::mem; +use std::ops; +use std::ptr; +use std::sync::atomic::Ordering::Relaxed; + +/// Amortized allocation for homogeneous data types. +/// +/// The slab pre-allocates chunks of memory to store values. It uses a similar +/// growing strategy as `Vec`. When new capacity is needed, the slab grows by +/// 2x. +/// +/// # Pages +/// +/// Unlike `Vec`, growing does not require moving existing elements. Instead of +/// being a continuous chunk of memory for all elements, `Slab` is an array of +/// arrays. The top-level array is an array of pages. Each page is 2x bigger +/// than the previous one. When the slab grows, a new page is allocated. +/// +/// Pages are lazily initialized. +/// +/// # Allocating +/// +/// When allocating an object, first previously used slots are reused. If no +/// previously used slot is available, a new slot is initialized in an existing +/// page. If all pages are full, then a new page is allocated. +/// +/// When an allocated object is released, it is pushed into it's page's free +/// list. Allocating scans all pages for a free slot. +/// +/// # Indexing +/// +/// The slab is able to index values using an address. Even when the indexed +/// object has been released, it is still safe to index. This is a key ability +/// for using the slab with the I/O driver. Addresses are registered with the +/// OS's selector and I/O resources can be released without synchronizing with +/// the OS. +/// +/// # Compaction +/// +/// `Slab::compact` will release pages that have been allocated but are no +/// longer used. This is done by scanning the pages and finding pages with no +/// allocated objects. These pages are then freed. +/// +/// # Synchronization +/// +/// The `Slab` structure is able to provide (mostly) unsynchronized reads to +/// values stored in the slab. Insertions and removals are synchronized. Reading +/// objects via `Ref` is fully unsynchronized. Indexing objects uses amortized +/// synchronization. +/// +pub(crate) struct Slab<T> { + /// Array of pages. Each page is synchronized. + pages: [Arc<Page<T>>; NUM_PAGES], + + /// Caches the array pointer & number of initialized slots. + cached: [CachedPage<T>; NUM_PAGES], +} + +/// Allocate values in the associated slab. +pub(crate) struct Allocator<T> { + /// Pages in the slab. The first page has a capacity of 16 elements. Each + /// following page has double the capacity of the previous page. + /// + /// Each returned `Ref` holds a reference count to this `Arc`. + pages: [Arc<Page<T>>; NUM_PAGES], +} + +/// References a slot in the slab. Indexing a slot using an `Address` is memory +/// safe even if the slot has been released or the page has been deallocated. +/// However, it is not guaranteed that the slot has not been reused and is now +/// represents a different value. +/// +/// The I/O driver uses a counter to track the slot's generation. Once accessing +/// the slot, the generations are compared. If they match, the value matches the +/// address. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct Address(usize); + +/// An entry in the slab. +pub(crate) trait Entry: Default { + /// Resets the entry's value and track the generation. + fn reset(&self); +} + +/// A reference to a value stored in the slab. +pub(crate) struct Ref<T> { + value: *const Value<T>, +} + +/// Maximum number of pages a slab can contain. +const NUM_PAGES: usize = 19; + +/// Minimum number of slots a page can contain. +const PAGE_INITIAL_SIZE: usize = 32; +const PAGE_INDEX_SHIFT: u32 = PAGE_INITIAL_SIZE.trailing_zeros() + 1; + +/// A page in the slab. +struct Page<T> { + /// Slots. + slots: Mutex<Slots<T>>, + + // Number of slots currently being used. This is not guaranteed to be up to + // date and should only be used as a hint. + used: AtomicUsize, + + // Set to `true` when the page has been allocated. + allocated: AtomicBool, + + // The number of slots the page can hold. + len: usize, + + // Length of all previous pages combined. + prev_len: usize, +} + +struct CachedPage<T> { + /// Pointer to the page's slots. + slots: *const Slot<T>, + + /// Number of initialized slots. + init: usize, +} + +/// Page state. +struct Slots<T> { + /// Slots. + slots: Vec<Slot<T>>, + + head: usize, + + /// Number of slots currently in use. + used: usize, +} + +unsafe impl<T: Sync> Sync for Page<T> {} +unsafe impl<T: Sync> Send for Page<T> {} +unsafe impl<T: Sync> Sync for CachedPage<T> {} +unsafe impl<T: Sync> Send for CachedPage<T> {} +unsafe impl<T: Sync> Sync for Ref<T> {} +unsafe impl<T: Sync> Send for Ref<T> {} + +/// A slot in the slab. Contains slot-specific metadata. +/// +/// `#[repr(C)]` guarantees that the struct starts w/ `value`. We use pointer +/// math to map a value pointer to an index in the page. +#[repr(C)] +struct Slot<T> { + /// Pointed to by `Ref`. + value: UnsafeCell<Value<T>>, + + /// Next entry in the free list. + next: u32, + + /// Makes miri happy by making mutable references not take exclusive access. + /// + /// Could probably also be fixed by replacing `slots` with a raw-pointer + /// based equivalent. + _pin: std::marker::PhantomPinned, +} + +/// Value paired with a reference to the page. +struct Value<T> { + /// Value stored in the value. + value: T, + + /// Pointer to the page containing the slot. + /// + /// A raw pointer is used as this creates a ref cycle. + page: *const Page<T>, +} + +impl<T> Slab<T> { + /// Create a new, empty, slab. + pub(crate) fn new() -> Slab<T> { + // Initializing arrays is a bit annoying. Instead of manually writing + // out an array and every single entry, `Default::default()` is used to + // initialize the array, then the array is iterated and each value is + // initialized. + let mut slab = Slab { + pages: Default::default(), + cached: Default::default(), + }; + + let mut len = PAGE_INITIAL_SIZE; + let mut prev_len: usize = 0; + + for page in &mut slab.pages { + let page = Arc::get_mut(page).unwrap(); + page.len = len; + page.prev_len = prev_len; + len *= 2; + prev_len += page.len; + + // Ensure we don't exceed the max address space. + debug_assert!( + page.len - 1 + page.prev_len < (1 << 24), + "max = {:b}", + page.len - 1 + page.prev_len + ); + } + + slab + } + + /// Returns a new `Allocator`. + /// + /// The `Allocator` supports concurrent allocation of objects. + pub(crate) fn allocator(&self) -> Allocator<T> { + Allocator { + pages: self.pages.clone(), + } + } + + /// Returns a reference to the value stored at the given address. + /// + /// `&mut self` is used as the call may update internal cached state. + pub(crate) fn get(&mut self, addr: Address) -> Option<&T> { + let page_idx = addr.page(); + let slot_idx = self.pages[page_idx].slot(addr); + + // If the address references a slot that was last seen as uninitialized, + // the `CachedPage` is updated. This requires acquiring the page lock + // and updating the slot pointer and initialized offset. + if self.cached[page_idx].init <= slot_idx { + self.cached[page_idx].refresh(&self.pages[page_idx]); + } + + // If the address **still** references an uninitialized slot, then the + // address is invalid and `None` is returned. + if self.cached[page_idx].init <= slot_idx { + return None; + } + + // Get a reference to the value. The lifetime of the returned reference + // is bound to `&self`. The only way to invalidate the underlying memory + // is to call `compact()`. The lifetimes prevent calling `compact()` + // while references to values are outstanding. + // + // The referenced data is never mutated. Only `&self` references are + // used and the data is `Sync`. + Some(self.cached[page_idx].get(slot_idx)) + } + + /// Calls the given function with a reference to each slot in the slab. The + /// slot may not be in-use. + /// + /// This is used by the I/O driver during the shutdown process to notify + /// each pending task. + pub(crate) fn for_each(&mut self, mut f: impl FnMut(&T)) { + for page_idx in 0..self.pages.len() { + // It is required to avoid holding the lock when calling the + // provided function. The function may attempt to acquire the lock + // itself. If we hold the lock here while calling `f`, a deadlock + // situation is possible. + // + // Instead of iterating the slots directly in `page`, which would + // require holding the lock, the cache is updated and the slots are + // iterated from the cache. + self.cached[page_idx].refresh(&self.pages[page_idx]); + + for slot_idx in 0..self.cached[page_idx].init { + f(self.cached[page_idx].get(slot_idx)); + } + } + } + + // Release memory back to the allocator. + // + // If pages are empty, the underlying memory is released back to the + // allocator. + pub(crate) fn compact(&mut self) { + // Iterate each page except the very first one. The very first page is + // never freed. + for (idx, page) in self.pages.iter().enumerate().skip(1) { + if page.used.load(Relaxed) != 0 || !page.allocated.load(Relaxed) { + // If the page has slots in use or the memory has not been + // allocated then it cannot be compacted. + continue; + } + + let mut slots = match page.slots.try_lock() { + Some(slots) => slots, + // If the lock cannot be acquired due to being held by another + // thread, don't try to compact the page. + _ => continue, + }; + + if slots.used > 0 || slots.slots.capacity() == 0 { + // The page is in use or it has not yet been allocated. Either + // way, there is no more work to do. + continue; + } + + page.allocated.store(false, Relaxed); + + // Remove the slots vector from the page. This is done so that the + // freeing process is done outside of the lock's critical section. + let vec = mem::take(&mut slots.slots); + slots.head = 0; + + // Drop the lock so we can drop the vector outside the lock below. + drop(slots); + + debug_assert!( + self.cached[idx].slots.is_null() || self.cached[idx].slots == vec.as_ptr(), + "cached = {:?}; actual = {:?}", + self.cached[idx].slots, + vec.as_ptr(), + ); + + // Clear cache + self.cached[idx].slots = ptr::null(); + self.cached[idx].init = 0; + + drop(vec); + } + } +} + +impl<T> fmt::Debug for Slab<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + debug(fmt, "Slab", &self.pages[..]) + } +} + +impl<T: Entry> Allocator<T> { + /// Allocate a new entry and return a handle to the entry. + /// + /// Scans pages from smallest to biggest, stopping when a slot is found. + /// Pages are allocated if necessary. + /// + /// Returns `None` if the slab is full. + pub(crate) fn allocate(&self) -> Option<(Address, Ref<T>)> { + // Find the first available slot. + for page in &self.pages[..] { + if let Some((addr, val)) = Page::allocate(page) { + return Some((addr, val)); + } + } + + None + } +} + +impl<T> fmt::Debug for Allocator<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + debug(fmt, "slab::Allocator", &self.pages[..]) + } +} + +impl<T> ops::Deref for Ref<T> { + type Target = T; + + fn deref(&self) -> &T { + // Safety: `&mut` is never handed out to the underlying value. The page + // is not freed until all `Ref` values are dropped. + unsafe { &(*self.value).value } + } +} + +impl<T> Drop for Ref<T> { + fn drop(&mut self) { + // Safety: `&mut` is never handed out to the underlying value. The page + // is not freed until all `Ref` values are dropped. + let _ = unsafe { (*self.value).release() }; + } +} + +impl<T: fmt::Debug> fmt::Debug for Ref<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(fmt) + } +} + +impl<T: Entry> Page<T> { + // Allocates an object, returns the ref and address. + // + // `self: &Arc<Page<T>>` is avoided here as this would not work with the + // loom `Arc`. + fn allocate(me: &Arc<Page<T>>) -> Option<(Address, Ref<T>)> { + // Before acquiring the lock, use the `used` hint. + if me.used.load(Relaxed) == me.len { + return None; + } + + // Allocating objects requires synchronization + let mut locked = me.slots.lock(); + + if locked.head < locked.slots.len() { + // Re-use an already initialized slot. + // + // Help out the borrow checker + let locked = &mut *locked; + + // Get the index of the slot at the head of the free stack. This is + // the slot that will be reused. + let idx = locked.head; + let slot = &locked.slots[idx]; + + // Update the free stack head to point to the next slot. + locked.head = slot.next as usize; + + // Increment the number of used slots + locked.used += 1; + me.used.store(locked.used, Relaxed); + + // Reset the slot + slot.value.with(|ptr| unsafe { (*ptr).value.reset() }); + + // Return a reference to the slot + Some((me.addr(idx), locked.gen_ref(idx, me))) + } else if me.len == locked.slots.len() { + // The page is full + None + } else { + // No initialized slots are available, but the page has more + // capacity. Initialize a new slot. + let idx = locked.slots.len(); + + if idx == 0 { + // The page has not yet been allocated. Allocate the storage for + // all page slots. + locked.slots.reserve_exact(me.len); + } + + // Initialize a new slot + locked.slots.push(Slot { + value: UnsafeCell::new(Value { + value: Default::default(), + page: Arc::as_ptr(me), + }), + next: 0, + _pin: std::marker::PhantomPinned, + }); + + // Increment the head to indicate the free stack is empty + locked.head += 1; + + // Increment the number of used slots + locked.used += 1; + me.used.store(locked.used, Relaxed); + me.allocated.store(true, Relaxed); + + debug_assert_eq!(locked.slots.len(), locked.head); + + Some((me.addr(idx), locked.gen_ref(idx, me))) + } + } +} + +impl<T> Page<T> { + /// Returns the slot index within the current page referenced by the given + /// address. + fn slot(&self, addr: Address) -> usize { + addr.0 - self.prev_len + } + + /// Returns the address for the given slot. + fn addr(&self, slot: usize) -> Address { + Address(slot + self.prev_len) + } +} + +impl<T> Default for Page<T> { + fn default() -> Page<T> { + Page { + used: AtomicUsize::new(0), + allocated: AtomicBool::new(false), + slots: Mutex::new(Slots { + slots: Vec::new(), + head: 0, + used: 0, + }), + len: 0, + prev_len: 0, + } + } +} + +impl<T> Page<T> { + /// Release a slot into the page's free list. + fn release(&self, value: *const Value<T>) { + let mut locked = self.slots.lock(); + + let idx = locked.index_for(value); + locked.slots[idx].next = locked.head as u32; + locked.head = idx; + locked.used -= 1; + + self.used.store(locked.used, Relaxed); + } +} + +impl<T> CachedPage<T> { + /// Refreshes the cache. + fn refresh(&mut self, page: &Page<T>) { + let slots = page.slots.lock(); + + if !slots.slots.is_empty() { + self.slots = slots.slots.as_ptr(); + self.init = slots.slots.len(); + } + } + + /// Gets a value by index. + fn get(&self, idx: usize) -> &T { + assert!(idx < self.init); + + // Safety: Pages are allocated concurrently, but are only ever + // **deallocated** by `Slab`. `Slab` will always have a more + // conservative view on the state of the slot array. Once `CachedPage` + // sees a slot pointer and initialized offset, it will remain valid + // until `compact()` is called. The `compact()` function also updates + // `CachedPage`. + unsafe { + let slot = self.slots.add(idx); + let value = slot as *const Value<T>; + + &(*value).value + } + } +} + +impl<T> Default for CachedPage<T> { + fn default() -> CachedPage<T> { + CachedPage { + slots: ptr::null(), + init: 0, + } + } +} + +impl<T> Slots<T> { + /// Maps a slot pointer to an offset within the current page. + /// + /// The pointer math removes the `usize` index from the `Ref` struct, + /// shrinking the struct to a single pointer size. The contents of the + /// function is safe, the resulting `usize` is bounds checked before being + /// used. + /// + /// # Panics + /// + /// panics if the provided slot pointer is not contained by the page. + fn index_for(&self, slot: *const Value<T>) -> usize { + use std::mem; + + let base = &self.slots[0] as *const _ as usize; + + assert!(base != 0, "page is unallocated"); + + let slot = slot as usize; + let width = mem::size_of::<Slot<T>>(); + + assert!(slot >= base, "unexpected pointer"); + + let idx = (slot - base) / width; + assert!(idx < self.slots.len() as usize); + + idx + } + + /// Generates a `Ref` for the slot at the given index. This involves bumping the page's ref count. + fn gen_ref(&self, idx: usize, page: &Arc<Page<T>>) -> Ref<T> { + assert!(idx < self.slots.len()); + mem::forget(page.clone()); + + let vec_ptr = self.slots.as_ptr(); + let slot: *const Slot<T> = unsafe { vec_ptr.add(idx) }; + let value: *const Value<T> = slot as *const Value<T>; + + Ref { value } + } +} + +impl<T> Value<T> { + /// Releases the slot, returning the `Arc<Page<T>>` logically owned by the ref. + fn release(&self) -> Arc<Page<T>> { + // Safety: called by `Ref`, which owns an `Arc<Page<T>>` instance. + let page = unsafe { Arc::from_raw(self.page) }; + page.release(self as *const _); + page + } +} + +impl Address { + fn page(self) -> usize { + // Since every page is twice as large as the previous page, and all page + // sizes are powers of two, we can determine the page index that + // contains a given address by shifting the address down by the smallest + // page size and looking at how many twos places necessary to represent + // that number, telling us what power of two page size it fits inside + // of. We can determine the number of twos places by counting the number + // of leading zeros (unused twos places) in the number's binary + // representation, and subtracting that count from the total number of + // bits in a word. + let slot_shifted = (self.0 + PAGE_INITIAL_SIZE) >> PAGE_INDEX_SHIFT; + (bit::pointer_width() - slot_shifted.leading_zeros()) as usize + } + + pub(crate) const fn as_usize(self) -> usize { + self.0 + } + + pub(crate) fn from_usize(src: usize) -> Address { + Address(src) + } +} + +fn debug<T>(fmt: &mut fmt::Formatter<'_>, name: &str, pages: &[Arc<Page<T>>]) -> fmt::Result { + let mut capacity = 0; + let mut len = 0; + + for page in pages { + if page.allocated.load(Relaxed) { + capacity += page.len; + len += page.used.load(Relaxed); + } + } + + fmt.debug_struct(name) + .field("len", &len) + .field("capacity", &capacity) + .finish() +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering::SeqCst; + + struct Foo { + cnt: AtomicUsize, + id: AtomicUsize, + } + + impl Default for Foo { + fn default() -> Foo { + Foo { + cnt: AtomicUsize::new(0), + id: AtomicUsize::new(0), + } + } + } + + impl Entry for Foo { + fn reset(&self) { + self.cnt.fetch_add(1, SeqCst); + } + } + + #[test] + fn insert_remove() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + + let (addr1, foo1) = alloc.allocate().unwrap(); + foo1.id.store(1, SeqCst); + assert_eq!(0, foo1.cnt.load(SeqCst)); + + let (addr2, foo2) = alloc.allocate().unwrap(); + foo2.id.store(2, SeqCst); + assert_eq!(0, foo2.cnt.load(SeqCst)); + + assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst)); + assert_eq!(2, slab.get(addr2).unwrap().id.load(SeqCst)); + + drop(foo1); + + assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst)); + + let (addr3, foo3) = alloc.allocate().unwrap(); + assert_eq!(addr3, addr1); + assert_eq!(1, foo3.cnt.load(SeqCst)); + foo3.id.store(3, SeqCst); + assert_eq!(3, slab.get(addr3).unwrap().id.load(SeqCst)); + + drop(foo2); + drop(foo3); + + slab.compact(); + + // The first page is never released + assert!(slab.get(addr1).is_some()); + assert!(slab.get(addr2).is_some()); + assert!(slab.get(addr3).is_some()); + } + + #[test] + fn insert_many() { + const MANY: usize = normal_or_miri(10_000, 50); + + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for i in 0..MANY { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + entries.push((addr, val)); + } + + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(i, v.id.load(SeqCst)); + assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + + entries.clear(); + + for i in 0..MANY { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(MANY - i, SeqCst); + entries.push((addr, val)); + } + + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(MANY - i, v.id.load(SeqCst)); + assert_eq!(MANY - i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + + #[test] + fn insert_drop_reverse() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for i in 0..normal_or_miri(10_000, 100) { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + entries.push((addr, val)); + } + + for _ in 0..10 { + // Drop 1000 in reverse + for _ in 0..normal_or_miri(1_000, 10) { + entries.pop(); + } + + // Check remaining + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(i, v.id.load(SeqCst)); + assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + } + + #[test] + fn no_compaction_if_page_still_in_use() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries1 = vec![]; + let mut entries2 = vec![]; + + for i in 0..normal_or_miri(10_000, 100) { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + if i % 2 == 0 { + entries1.push((addr, val, i)); + } else { + entries2.push(val); + } + } + + drop(entries2); + + for (addr, _, i) in &entries1 { + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + + const fn normal_or_miri(normal: usize, miri: usize) -> usize { + if cfg!(miri) { + miri + } else { + normal + } + } + + #[test] + fn compact_all() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for _ in 0..2 { + entries.clear(); + + for i in 0..normal_or_miri(10_000, 100) { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + entries.push((addr, val)); + } + + let mut addrs = vec![]; + + for (addr, _) in entries.drain(..) { + addrs.push(addr); + } + + slab.compact(); + + // The first page is never freed + for addr in &addrs[PAGE_INITIAL_SIZE..] { + assert!(slab.get(*addr).is_none()); + } + } + } + + #[test] + fn issue_3014() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for _ in 0..normal_or_miri(5, 2) { + entries.clear(); + + // Allocate a few pages + 1 + for i in 0..(32 + 64 + 128 + 1) { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + entries.push((addr, val, i)); + } + + for (addr, val, i) in &entries { + assert_eq!(*i, val.id.load(SeqCst)); + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + + // Release the last entry + entries.pop(); + + // Compact + slab.compact(); + + // Check all the addresses + + for (addr, val, i) in &entries { + assert_eq!(*i, val.id.load(SeqCst)); + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + } +} diff --git a/third_party/rust/tokio/src/util/sync_wrapper.rs b/third_party/rust/tokio/src/util/sync_wrapper.rs new file mode 100644 index 0000000000..5ffc8f96b1 --- /dev/null +++ b/third_party/rust/tokio/src/util/sync_wrapper.rs @@ -0,0 +1,26 @@ +//! This module contains a type that can make `Send + !Sync` types `Sync` by +//! disallowing all immutable access to the value. +//! +//! A similar primitive is provided in the `sync_wrapper` crate. + +pub(crate) struct SyncWrapper<T> { + value: T, +} + +// safety: The SyncWrapper being send allows you to send the inner value across +// thread boundaries. +unsafe impl<T: Send> Send for SyncWrapper<T> {} + +// safety: An immutable reference to a SyncWrapper is useless, so moving such an +// immutable reference across threads is safe. +unsafe impl<T> Sync for SyncWrapper<T> {} + +impl<T> SyncWrapper<T> { + pub(crate) fn new(value: T) -> Self { + Self { value } + } + + pub(crate) fn into_inner(self) -> T { + self.value + } +} diff --git a/third_party/rust/tokio/src/util/trace.rs b/third_party/rust/tokio/src/util/trace.rs new file mode 100644 index 0000000000..6080e2358a --- /dev/null +++ b/third_party/rust/tokio/src/util/trace.rs @@ -0,0 +1,99 @@ +cfg_trace! { + cfg_rt! { + use core::{ + pin::Pin, + task::{Context, Poll}, + }; + use pin_project_lite::pin_project; + use std::future::Future; + pub(crate) use tracing::instrument::Instrumented; + + #[inline] + #[track_caller] + pub(crate) fn task<F>(task: F, kind: &'static str, name: Option<&str>) -> Instrumented<F> { + use tracing::instrument::Instrument; + let location = std::panic::Location::caller(); + let span = tracing::trace_span!( + target: "tokio::task", + "runtime.spawn", + %kind, + task.name = %name.unwrap_or_default(), + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + task.instrument(span) + } + + pub(crate) fn async_op<P,F>(inner: P, resource_span: tracing::Span, source: &str, poll_op_name: &'static str, inherits_child_attrs: bool) -> InstrumentedAsyncOp<F> + where P: FnOnce() -> F { + resource_span.in_scope(|| { + let async_op_span = tracing::trace_span!("runtime.resource.async_op", source = source, inherits_child_attrs = inherits_child_attrs); + let enter = async_op_span.enter(); + let async_op_poll_span = tracing::trace_span!("runtime.resource.async_op.poll"); + let inner = inner(); + drop(enter); + let tracing_ctx = AsyncOpTracingCtx { + async_op_span, + async_op_poll_span, + resource_span: resource_span.clone(), + }; + InstrumentedAsyncOp { + inner, + tracing_ctx, + poll_op_name, + } + }) + } + + #[derive(Debug, Clone)] + pub(crate) struct AsyncOpTracingCtx { + pub(crate) async_op_span: tracing::Span, + pub(crate) async_op_poll_span: tracing::Span, + pub(crate) resource_span: tracing::Span, + } + + + pin_project! { + #[derive(Debug, Clone)] + pub(crate) struct InstrumentedAsyncOp<F> { + #[pin] + pub(crate) inner: F, + pub(crate) tracing_ctx: AsyncOpTracingCtx, + pub(crate) poll_op_name: &'static str + } + } + + impl<F: Future> Future for InstrumentedAsyncOp<F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let poll_op_name = &*this.poll_op_name; + let _res_enter = this.tracing_ctx.resource_span.enter(); + let _async_op_enter = this.tracing_ctx.async_op_span.enter(); + let _async_op_poll_enter = this.tracing_ctx.async_op_poll_span.enter(); + trace_poll_op!(poll_op_name, this.inner.poll(cx)) + } + } + } +} +cfg_time! { + #[track_caller] + pub(crate) fn caller_location() -> Option<&'static std::panic::Location<'static>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + return Some(std::panic::Location::caller()); + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + None + } +} + +cfg_not_trace! { + cfg_rt! { + #[inline] + pub(crate) fn task<F>(task: F, _: &'static str, _name: Option<&str>) -> F { + // nop + task + } + } +} diff --git a/third_party/rust/tokio/src/util/try_lock.rs b/third_party/rust/tokio/src/util/try_lock.rs new file mode 100644 index 0000000000..8b0edb4a87 --- /dev/null +++ b/third_party/rust/tokio/src/util/try_lock.rs @@ -0,0 +1,80 @@ +use crate::loom::sync::atomic::AtomicBool; + +use std::cell::UnsafeCell; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::Ordering::SeqCst; + +pub(crate) struct TryLock<T> { + locked: AtomicBool, + data: UnsafeCell<T>, +} + +pub(crate) struct LockGuard<'a, T> { + lock: &'a TryLock<T>, + _p: PhantomData<std::rc::Rc<()>>, +} + +unsafe impl<T: Send> Send for TryLock<T> {} +unsafe impl<T: Send> Sync for TryLock<T> {} + +unsafe impl<T: Sync> Sync for LockGuard<'_, T> {} + +macro_rules! new { + ($data:ident) => { + TryLock { + locked: AtomicBool::new(false), + data: UnsafeCell::new($data), + } + }; +} + +impl<T> TryLock<T> { + #[cfg(not(loom))] + /// Create a new `TryLock` + pub(crate) const fn new(data: T) -> TryLock<T> { + new!(data) + } + + #[cfg(loom)] + /// Create a new `TryLock` + pub(crate) fn new(data: T) -> TryLock<T> { + new!(data) + } + + /// Attempt to acquire lock + pub(crate) fn try_lock(&self) -> Option<LockGuard<'_, T>> { + if self + .locked + .compare_exchange(false, true, SeqCst, SeqCst) + .is_err() + { + return None; + } + + Some(LockGuard { + lock: self, + _p: PhantomData, + }) + } +} + +impl<T> Deref for LockGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.lock.data.get() } + } +} + +impl<T> DerefMut for LockGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.lock.data.get() } + } +} + +impl<T> Drop for LockGuard<'_, T> { + fn drop(&mut self) { + self.lock.locked.store(false, SeqCst); + } +} diff --git a/third_party/rust/tokio/src/util/vec_deque_cell.rs b/third_party/rust/tokio/src/util/vec_deque_cell.rs new file mode 100644 index 0000000000..b4e124c151 --- /dev/null +++ b/third_party/rust/tokio/src/util/vec_deque_cell.rs @@ -0,0 +1,53 @@ +use crate::loom::cell::UnsafeCell; + +use std::collections::VecDeque; +use std::marker::PhantomData; + +/// This type is like VecDeque, except that it is not Sync and can be modified +/// through immutable references. +pub(crate) struct VecDequeCell<T> { + inner: UnsafeCell<VecDeque<T>>, + _not_sync: PhantomData<*const ()>, +} + +// This is Send for the same reasons that RefCell<VecDeque<T>> is Send. +unsafe impl<T: Send> Send for VecDequeCell<T> {} + +impl<T> VecDequeCell<T> { + pub(crate) fn with_capacity(cap: usize) -> Self { + Self { + inner: UnsafeCell::new(VecDeque::with_capacity(cap)), + _not_sync: PhantomData, + } + } + + /// Safety: This method may not be called recursively. + #[inline] + unsafe fn with_inner<F, R>(&self, f: F) -> R + where + F: FnOnce(&mut VecDeque<T>) -> R, + { + // safety: This type is not Sync, so concurrent calls of this method + // cannot happen. Furthermore, the caller guarantees that the method is + // not called recursively. Finally, this is the only place that can + // create mutable references to the inner VecDeque. This ensures that + // any mutable references created here are exclusive. + self.inner.with_mut(|ptr| f(&mut *ptr)) + } + + pub(crate) fn pop_front(&self) -> Option<T> { + unsafe { self.with_inner(VecDeque::pop_front) } + } + + pub(crate) fn push_back(&self, item: T) { + unsafe { + self.with_inner(|inner| inner.push_back(item)); + } + } + + /// Replaces the inner VecDeque with an empty VecDeque and return the current + /// contents. + pub(crate) fn take(&self) -> VecDeque<T> { + unsafe { self.with_inner(|inner| std::mem::take(inner)) } + } +} diff --git a/third_party/rust/tokio/src/util/wake.rs b/third_party/rust/tokio/src/util/wake.rs new file mode 100644 index 0000000000..5526cbc63a --- /dev/null +++ b/third_party/rust/tokio/src/util/wake.rs @@ -0,0 +1,80 @@ +use crate::loom::sync::Arc; + +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::Deref; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +/// Simplified waking interface based on Arcs. +pub(crate) trait Wake: Send + Sync + Sized + 'static { + /// Wake by value. + fn wake(arc_self: Arc<Self>); + + /// Wake by reference. + fn wake_by_ref(arc_self: &Arc<Self>); +} + +/// A `Waker` that is only valid for a given lifetime. +#[derive(Debug)] +pub(crate) struct WakerRef<'a> { + waker: ManuallyDrop<Waker>, + _p: PhantomData<&'a ()>, +} + +impl Deref for WakerRef<'_> { + type Target = Waker; + + fn deref(&self) -> &Waker { + &self.waker + } +} + +/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`. +pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> { + let ptr = Arc::as_ptr(wake) as *const (); + + let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) }; + + WakerRef { + waker: ManuallyDrop::new(waker), + _p: PhantomData, + } +} + +fn waker_vtable<W: Wake>() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_arc_raw::<W>, + wake_arc_raw::<W>, + wake_by_ref_arc_raw::<W>, + drop_arc_raw::<W>, + ) +} + +unsafe fn inc_ref_count<T: Wake>(data: *const ()) { + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T)); + + // Now increase refcount, but don't drop new refcount either + let _arc_clone: ManuallyDrop<_> = arc.clone(); +} + +unsafe fn clone_arc_raw<T: Wake>(data: *const ()) -> RawWaker { + inc_ref_count::<T>(data); + RawWaker::new(data, waker_vtable::<T>()) +} + +unsafe fn wake_arc_raw<T: Wake>(data: *const ()) { + let arc: Arc<T> = Arc::from_raw(data as *const T); + Wake::wake(arc); +} + +// used by `waker_ref` +unsafe fn wake_by_ref_arc_raw<T: Wake>(data: *const ()) { + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T)); + Wake::wake_by_ref(&arc); +} + +unsafe fn drop_arc_raw<T: Wake>(data: *const ()) { + drop(Arc::<T>::from_raw(data as *const T)) +} diff --git a/third_party/rust/tokio/src/util/wake_list.rs b/third_party/rust/tokio/src/util/wake_list.rs new file mode 100644 index 0000000000..aa569dd17b --- /dev/null +++ b/third_party/rust/tokio/src/util/wake_list.rs @@ -0,0 +1,53 @@ +use core::mem::MaybeUninit; +use core::ptr; +use std::task::Waker; + +const NUM_WAKERS: usize = 32; + +pub(crate) struct WakeList { + inner: [MaybeUninit<Waker>; NUM_WAKERS], + curr: usize, +} + +impl WakeList { + pub(crate) fn new() -> Self { + Self { + inner: unsafe { + // safety: Create an uninitialized array of `MaybeUninit`. The + // `assume_init` is safe because the type we are claiming to + // have initialized here is a bunch of `MaybeUninit`s, which do + // not require initialization. + MaybeUninit::uninit().assume_init() + }, + curr: 0, + } + } + + #[inline] + pub(crate) fn can_push(&self) -> bool { + self.curr < NUM_WAKERS + } + + pub(crate) fn push(&mut self, val: Waker) { + debug_assert!(self.can_push()); + + self.inner[self.curr] = MaybeUninit::new(val); + self.curr += 1; + } + + pub(crate) fn wake_all(&mut self) { + assert!(self.curr <= NUM_WAKERS); + while self.curr > 0 { + self.curr -= 1; + let waker = unsafe { ptr::read(self.inner[self.curr].as_mut_ptr()) }; + waker.wake(); + } + } +} + +impl Drop for WakeList { + fn drop(&mut self) { + let slice = ptr::slice_from_raw_parts_mut(self.inner.as_mut_ptr() as *mut Waker, self.curr); + unsafe { ptr::drop_in_place(slice) }; + } +} |