| author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
|---|---|---|
| committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
| commit | 698f8c2f01ea549d77d7dc3338a12e04c11057b9 (patch) | |
| tree | 173a775858bd501c378080a10dca74132f05bc50 /library/std/src/sys_common | |
| parent | Initial commit. (diff) | |
| download | rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.tar.xz, rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.zip | |
Adding upstream version 1.64.0+dfsg1. (tag: upstream/1.64.0+dfsg1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'library/std/src/sys_common')
28 files changed, 4159 insertions, 0 deletions
diff --git a/library/std/src/sys_common/backtrace.rs b/library/std/src/sys_common/backtrace.rs new file mode 100644 index 000000000..31164afdc --- /dev/null +++ b/library/std/src/sys_common/backtrace.rs @@ -0,0 +1,183 @@ +use crate::backtrace_rs::{self, BacktraceFmt, BytesOrWideString, PrintFmt}; +use crate::borrow::Cow; +/// Common code for printing the backtrace in the same way across the different +/// supported platforms. +use crate::env; +use crate::fmt; +use crate::io; +use crate::io::prelude::*; +use crate::path::{self, Path, PathBuf}; +use crate::sys_common::mutex::StaticMutex; + +/// Max number of frames to print. +const MAX_NB_FRAMES: usize = 100; + +// SAFETY: Don't attempt to lock this reentrantly. +pub unsafe fn lock() -> impl Drop { + static LOCK: StaticMutex = StaticMutex::new(); + LOCK.lock() +} + +/// Prints the current backtrace. +pub fn print(w: &mut dyn Write, format: PrintFmt) -> io::Result<()> { + // There are issues currently linking libbacktrace into tests, and in + // general during libstd's own unit tests we're not testing this path. In + // test mode immediately return here to optimize away any references to the + // libbacktrace symbols + if cfg!(test) { + return Ok(()); + } + + // Use a lock to prevent mixed output in multithreading context. + // Some platforms also requires it, like `SymFromAddr` on Windows. + unsafe { + let _lock = lock(); + _print(w, format) + } +} + +unsafe fn _print(w: &mut dyn Write, format: PrintFmt) -> io::Result<()> { + struct DisplayBacktrace { + format: PrintFmt, + } + impl fmt::Display for DisplayBacktrace { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + unsafe { _print_fmt(fmt, self.format) } + } + } + write!(w, "{}", DisplayBacktrace { format }) +} + +unsafe fn _print_fmt(fmt: &mut fmt::Formatter<'_>, print_fmt: PrintFmt) -> fmt::Result { + // Always 'fail' to get the cwd when running under Miri - + // this allows Miri to display backtraces in isolation mode + let cwd = if !cfg!(miri) { env::current_dir().ok() } else { None }; + + let mut print_path = move |fmt: &mut fmt::Formatter<'_>, bows: BytesOrWideString<'_>| { + output_filename(fmt, bows, print_fmt, cwd.as_ref()) + }; + writeln!(fmt, "stack backtrace:")?; + let mut bt_fmt = BacktraceFmt::new(fmt, print_fmt, &mut print_path); + bt_fmt.add_context()?; + let mut idx = 0; + let mut res = Ok(()); + // Start immediately if we're not using a short backtrace. + let mut start = print_fmt != PrintFmt::Short; + backtrace_rs::trace_unsynchronized(|frame| { + if print_fmt == PrintFmt::Short && idx > MAX_NB_FRAMES { + return false; + } + + let mut hit = false; + let mut stop = false; + backtrace_rs::resolve_frame_unsynchronized(frame, |symbol| { + hit = true; + if print_fmt == PrintFmt::Short { + if let Some(sym) = symbol.name().and_then(|s| s.as_str()) { + if start && sym.contains("__rust_begin_short_backtrace") { + stop = true; + return; + } + if sym.contains("__rust_end_short_backtrace") { + start = true; + return; + } + } + } + + if start { + res = bt_fmt.frame().symbol(frame, symbol); + } + }); + if stop { + return false; + } + if !hit && start { + res = bt_fmt.frame().print_raw(frame.ip(), None, None, None); + } + + idx += 1; + res.is_ok() + }); + res?; + bt_fmt.finish()?; + if print_fmt == PrintFmt::Short { + writeln!( + fmt, + "note: Some details are omitted, \ + run with `RUST_BACKTRACE=full` for a verbose backtrace." + )?; + } + Ok(()) +} + +/// Fixed frame used to clean the backtrace with `RUST_BACKTRACE=1`. 
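Editor's note: the `_print_fmt` loop above trims a `PrintFmt::Short` backtrace by toggling a `start` flag on the two marker symbols. A minimal, self-contained sketch of that state machine over plain symbol names (the `my_crate::*` frames are hypothetical, and real frames may lack symbols entirely):

```rust
/// Sketch of the short-backtrace filter in `_print_fmt` above: printing is
/// off until `__rust_end_short_backtrace` is seen, and stops again at
/// `__rust_begin_short_backtrace`.
fn filter_short<'a>(frames: impl IntoIterator<Item = &'a str>) -> Vec<&'a str> {
    let mut printing = false; // PrintFmt::Short starts with printing disabled
    let mut kept = Vec::new();
    for sym in frames {
        if sym.contains("__rust_end_short_backtrace") {
            printing = true; // user frames follow this marker
        } else if printing && sym.contains("__rust_begin_short_backtrace") {
            break; // only runtime setup frames follow; stop here
        } else if printing {
            kept.push(sym);
        }
    }
    kept
}

fn main() {
    let trace = [
        "std::panicking::begin_panic", // panic machinery: skipped
        "__rust_end_short_backtrace",
        "my_crate::do_work",           // user frames: kept
        "my_crate::main",
        "__rust_begin_short_backtrace",
        "std::rt::lang_start",         // runtime setup: trimmed
    ];
    assert_eq!(filter_short(trace), ["my_crate::do_work", "my_crate::main"]);
}
```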
Note that +/// this is only inline(never) when backtraces in libstd are enabled, otherwise +/// it's fine to optimize away. +#[cfg_attr(feature = "backtrace", inline(never))] +pub fn __rust_begin_short_backtrace<F, T>(f: F) -> T +where + F: FnOnce() -> T, +{ + let result = f(); + + // prevent this frame from being tail-call optimised away + crate::hint::black_box(()); + + result +} + +/// Fixed frame used to clean the backtrace with `RUST_BACKTRACE=1`. Note that +/// this is only inline(never) when backtraces in libstd are enabled, otherwise +/// it's fine to optimize away. +#[cfg_attr(feature = "backtrace", inline(never))] +pub fn __rust_end_short_backtrace<F, T>(f: F) -> T +where + F: FnOnce() -> T, +{ + let result = f(); + + // prevent this frame from being tail-call optimised away + crate::hint::black_box(()); + + result +} + +/// Prints the filename of the backtrace frame. +/// +/// See also `output`. +pub fn output_filename( + fmt: &mut fmt::Formatter<'_>, + bows: BytesOrWideString<'_>, + print_fmt: PrintFmt, + cwd: Option<&PathBuf>, +) -> fmt::Result { + let file: Cow<'_, Path> = match bows { + #[cfg(unix)] + BytesOrWideString::Bytes(bytes) => { + use crate::os::unix::prelude::*; + Path::new(crate::ffi::OsStr::from_bytes(bytes)).into() + } + #[cfg(not(unix))] + BytesOrWideString::Bytes(bytes) => { + Path::new(crate::str::from_utf8(bytes).unwrap_or("<unknown>")).into() + } + #[cfg(windows)] + BytesOrWideString::Wide(wide) => { + use crate::os::windows::prelude::*; + Cow::Owned(crate::ffi::OsString::from_wide(wide).into()) + } + #[cfg(not(windows))] + BytesOrWideString::Wide(_wide) => Path::new("<unknown>").into(), + }; + if print_fmt == PrintFmt::Short && file.is_absolute() { + if let Some(cwd) = cwd { + if let Ok(stripped) = file.strip_prefix(&cwd) { + if let Some(s) = stripped.to_str() { + return write!(fmt, ".{}{s}", path::MAIN_SEPARATOR); + } + } + } + } + fmt::Display::fmt(&file.display(), fmt) +} diff --git a/library/std/src/sys_common/condvar.rs b/library/std/src/sys_common/condvar.rs new file mode 100644 index 000000000..f3ac1061b --- /dev/null +++ b/library/std/src/sys_common/condvar.rs @@ -0,0 +1,56 @@ +use crate::sys::locks as imp; +use crate::sys_common::mutex::MovableMutex; +use crate::time::Duration; + +mod check; + +type CondvarCheck = <imp::MovableMutex as check::CondvarCheck>::Check; + +/// An OS-based condition variable. +pub struct Condvar { + inner: imp::MovableCondvar, + check: CondvarCheck, +} + +impl Condvar { + /// Creates a new condition variable for use. + #[inline] + pub const fn new() -> Self { + Self { inner: imp::MovableCondvar::new(), check: CondvarCheck::new() } + } + + /// Signals one waiter on this condition variable to wake up. + #[inline] + pub fn notify_one(&self) { + unsafe { self.inner.notify_one() }; + } + + /// Awakens all current waiters on this condition variable. + #[inline] + pub fn notify_all(&self) { + unsafe { self.inner.notify_all() }; + } + + /// Waits for a signal on the specified mutex. + /// + /// Behavior is undefined if the mutex is not locked by the current thread. + /// + /// May panic if used with more than one mutex. + #[inline] + pub unsafe fn wait(&self, mutex: &MovableMutex) { + self.check.verify(mutex); + self.inner.wait(mutex.raw()) + } + + /// Waits for a signal on the specified mutex with a timeout duration + /// specified by `dur` (a relative time into the future). + /// + /// Behavior is undefined if the mutex is not locked by the current thread. + /// + /// May panic if used with more than one mutex. 
+ #[inline] + pub unsafe fn wait_timeout(&self, mutex: &MovableMutex, dur: Duration) -> bool { + self.check.verify(mutex); + self.inner.wait_timeout(mutex.raw(), dur) + } +} diff --git a/library/std/src/sys_common/condvar/check.rs b/library/std/src/sys_common/condvar/check.rs new file mode 100644 index 000000000..ce8f36704 --- /dev/null +++ b/library/std/src/sys_common/condvar/check.rs @@ -0,0 +1,57 @@ +use crate::ptr; +use crate::sync::atomic::{AtomicPtr, Ordering}; +use crate::sys::locks as imp; +use crate::sys_common::lazy_box::{LazyBox, LazyInit}; +use crate::sys_common::mutex::MovableMutex; + +pub trait CondvarCheck { + type Check; +} + +/// For boxed mutexes, a `Condvar` will check it's only ever used with the same +/// mutex, based on its (stable) address. +impl<T: LazyInit> CondvarCheck for LazyBox<T> { + type Check = SameMutexCheck; +} + +pub struct SameMutexCheck { + addr: AtomicPtr<()>, +} + +#[allow(dead_code)] +impl SameMutexCheck { + pub const fn new() -> Self { + Self { addr: AtomicPtr::new(ptr::null_mut()) } + } + pub fn verify(&self, mutex: &MovableMutex) { + let addr = mutex.raw() as *const imp::Mutex as *const () as *mut _; + // Relaxed is okay here because we never read through `self.addr`, and only use it to + // compare addresses. + match self.addr.compare_exchange( + ptr::null_mut(), + addr, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => {} // Stored the address + Err(n) if n == addr => {} // Lost a race to store the same address + _ => panic!("attempted to use a condition variable with two mutexes"), + } + } +} + +/// Unboxed mutexes may move, so `Condvar` can not require its address to stay +/// constant. +impl CondvarCheck for imp::Mutex { + type Check = NoCheck; +} + +pub struct NoCheck; + +#[allow(dead_code)] +impl NoCheck { + pub const fn new() -> Self { + Self + } + pub fn verify(&self, _: &MovableMutex) {} +} diff --git a/library/std/src/sys_common/fs.rs b/library/std/src/sys_common/fs.rs new file mode 100644 index 000000000..617ac52e5 --- /dev/null +++ b/library/std/src/sys_common/fs.rs @@ -0,0 +1,51 @@ +#![allow(dead_code)] // not used on all platforms + +use crate::fs; +use crate::io::{self, Error, ErrorKind}; +use crate::path::Path; + +pub(crate) const NOT_FILE_ERROR: Error = io::const_io_error!( + ErrorKind::InvalidInput, + "the source path is neither a regular file nor a symlink to a regular file", +); + +pub fn copy(from: &Path, to: &Path) -> io::Result<u64> { + let mut reader = fs::File::open(from)?; + let metadata = reader.metadata()?; + + if !metadata.is_file() { + return Err(NOT_FILE_ERROR); + } + + let mut writer = fs::File::create(to)?; + let perm = metadata.permissions(); + + let ret = io::copy(&mut reader, &mut writer)?; + writer.set_permissions(perm)?; + Ok(ret) +} + +pub fn remove_dir_all(path: &Path) -> io::Result<()> { + let filetype = fs::symlink_metadata(path)?.file_type(); + if filetype.is_symlink() { fs::remove_file(path) } else { remove_dir_all_recursive(path) } +} + +fn remove_dir_all_recursive(path: &Path) -> io::Result<()> { + for child in fs::read_dir(path)? 
{ + let child = child?; + if child.file_type()?.is_dir() { + remove_dir_all_recursive(&child.path())?; + } else { + fs::remove_file(&child.path())?; + } + } + fs::remove_dir(path) +} + +pub fn try_exists(path: &Path) -> io::Result<bool> { + match fs::metadata(path) { + Ok(_) => Ok(true), + Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(false), + Err(error) => Err(error), + } +} diff --git a/library/std/src/sys_common/io.rs b/library/std/src/sys_common/io.rs new file mode 100644 index 000000000..d1e9fed41 --- /dev/null +++ b/library/std/src/sys_common/io.rs @@ -0,0 +1,49 @@ +// Bare metal platforms usually have very small amounts of RAM +// (in the order of hundreds of KB) +pub const DEFAULT_BUF_SIZE: usize = if cfg!(target_os = "espidf") { 512 } else { 8 * 1024 }; + +#[cfg(test)] +#[allow(dead_code)] // not used on emscripten +pub mod test { + use crate::env; + use crate::fs; + use crate::path::{Path, PathBuf}; + use crate::thread; + use rand::RngCore; + + pub struct TempDir(PathBuf); + + impl TempDir { + pub fn join(&self, path: &str) -> PathBuf { + let TempDir(ref p) = *self; + p.join(path) + } + + pub fn path(&self) -> &Path { + let TempDir(ref p) = *self; + p + } + } + + impl Drop for TempDir { + fn drop(&mut self) { + // Gee, seeing how we're testing the fs module I sure hope that we + // at least implement this correctly! + let TempDir(ref p) = *self; + let result = fs::remove_dir_all(p); + // Avoid panicking while panicking as this causes the process to + // immediately abort, without displaying test results. + if !thread::panicking() { + result.unwrap(); + } + } + } + + pub fn tmpdir() -> TempDir { + let p = env::temp_dir(); + let mut r = rand::thread_rng(); + let ret = p.join(&format!("rust-{}", r.next_u32())); + fs::create_dir(&ret).unwrap(); + TempDir(ret) + } +} diff --git a/library/std/src/sys_common/lazy_box.rs b/library/std/src/sys_common/lazy_box.rs new file mode 100644 index 000000000..63c3316bd --- /dev/null +++ b/library/std/src/sys_common/lazy_box.rs @@ -0,0 +1,90 @@ +#![allow(dead_code)] // Only used on some platforms. + +// This is used to wrap pthread {Mutex, Condvar, RwLock} in. + +use crate::marker::PhantomData; +use crate::ops::{Deref, DerefMut}; +use crate::ptr::null_mut; +use crate::sync::atomic::{ + AtomicPtr, + Ordering::{AcqRel, Acquire}, +}; + +pub(crate) struct LazyBox<T: LazyInit> { + ptr: AtomicPtr<T>, + _phantom: PhantomData<T>, +} + +pub(crate) trait LazyInit { + /// This is called before the box is allocated, to provide the value to + /// move into the new box. + /// + /// It might be called more than once per LazyBox, as multiple threads + /// might race to initialize it concurrently, each constructing and initializing + /// their own box. All but one of them will be passed to `cancel_init` right after. + fn init() -> Box<Self>; + + /// Any surplus boxes from `init()` that lost the initialization race + /// are passed to this function for disposal. + /// + /// The default implementation calls destroy(). + fn cancel_init(x: Box<Self>) { + Self::destroy(x); + } + + /// This is called to destroy a used box. + /// + /// The default implementation just drops it. 
+ fn destroy(_: Box<Self>) {} +} + +impl<T: LazyInit> LazyBox<T> { + #[inline] + pub const fn new() -> Self { + Self { ptr: AtomicPtr::new(null_mut()), _phantom: PhantomData } + } + + #[inline] + fn get_pointer(&self) -> *mut T { + let ptr = self.ptr.load(Acquire); + if ptr.is_null() { self.initialize() } else { ptr } + } + + #[cold] + fn initialize(&self) -> *mut T { + let new_ptr = Box::into_raw(T::init()); + match self.ptr.compare_exchange(null_mut(), new_ptr, AcqRel, Acquire) { + Ok(_) => new_ptr, + Err(ptr) => { + // Lost the race to another thread. + // Drop the box we created, and use the one from the other thread instead. + T::cancel_init(unsafe { Box::from_raw(new_ptr) }); + ptr + } + } + } +} + +impl<T: LazyInit> Deref for LazyBox<T> { + type Target = T; + #[inline] + fn deref(&self) -> &T { + unsafe { &*self.get_pointer() } + } +} + +impl<T: LazyInit> DerefMut for LazyBox<T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.get_pointer() } + } +} + +impl<T: LazyInit> Drop for LazyBox<T> { + fn drop(&mut self) { + let ptr = *self.ptr.get_mut(); + if !ptr.is_null() { + T::destroy(unsafe { Box::from_raw(ptr) }); + } + } +} diff --git a/library/std/src/sys_common/memchr.rs b/library/std/src/sys_common/memchr.rs new file mode 100644 index 000000000..b219e8789 --- /dev/null +++ b/library/std/src/sys_common/memchr.rs @@ -0,0 +1,51 @@ +// Original implementation taken from rust-memchr. +// Copyright 2015 Andrew Gallant, bluss and Nicolas Koch + +use crate::sys::memchr as sys; + +#[cfg(test)] +mod tests; + +/// A safe interface to `memchr`. +/// +/// Returns the index corresponding to the first occurrence of `needle` in +/// `haystack`, or `None` if one is not found. +/// +/// memchr reduces to super-optimized machine code at around an order of +/// magnitude faster than `haystack.iter().position(|&b| b == needle)`. +/// (See benchmarks.) +/// +/// # Examples +/// +/// This shows how to find the first position of a byte in a byte string. +/// +/// ```ignore (cannot-doctest-private-modules) +/// use memchr::memchr; +/// +/// let haystack = b"the quick brown fox"; +/// assert_eq!(memchr(b'k', haystack), Some(8)); +/// ``` +#[inline] +pub fn memchr(needle: u8, haystack: &[u8]) -> Option<usize> { + sys::memchr(needle, haystack) +} + +/// A safe interface to `memrchr`. +/// +/// Returns the index corresponding to the last occurrence of `needle` in +/// `haystack`, or `None` if one is not found. +/// +/// # Examples +/// +/// This shows how to find the last position of a byte in a byte string. +/// +/// ```ignore (cannot-doctest-private-modules) +/// use memchr::memrchr; +/// +/// let haystack = b"the quick brown fox"; +/// assert_eq!(memrchr(b'o', haystack), Some(17)); +/// ``` +#[inline] +pub fn memrchr(needle: u8, haystack: &[u8]) -> Option<usize> { + sys::memrchr(needle, haystack) +} diff --git a/library/std/src/sys_common/memchr/tests.rs b/library/std/src/sys_common/memchr/tests.rs new file mode 100644 index 000000000..557d749c7 --- /dev/null +++ b/library/std/src/sys_common/memchr/tests.rs @@ -0,0 +1,86 @@ +// Original implementation taken from rust-memchr. 
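Editor's note: the `initialize` race above is the heart of `LazyBox`: every contending thread may allocate a candidate box, but a single `compare_exchange` picks the winner, and losers free their allocation and use the winner's. A self-contained miniature of the same pattern (the leaked `static` is deliberate, mirroring a lock that lives for the whole process):

```rust
use std::ptr::null_mut;
use std::sync::atomic::{
    AtomicPtr,
    Ordering::{AcqRel, Acquire},
};

// One global slot, lazily filled exactly once even under contention.
static SLOT: AtomicPtr<String> = AtomicPtr::new(null_mut());

fn get() -> &'static String {
    let ptr = SLOT.load(Acquire);
    if !ptr.is_null() {
        return unsafe { &*ptr };
    }
    // Build a candidate; only one thread's candidate gets installed.
    let candidate = Box::into_raw(Box::new(String::from("initialized once")));
    match SLOT.compare_exchange(null_mut(), candidate, AcqRel, Acquire) {
        Ok(_) => unsafe { &*candidate },
        Err(winner) => {
            // Lost the race: reclaim and drop our candidate, use the winner's.
            drop(unsafe { Box::from_raw(candidate) });
            unsafe { &*winner }
        }
    }
}

fn main() {
    let handles: Vec<_> = (0..4).map(|_| std::thread::spawn(get)).collect();
    let first = get() as *const String;
    for h in handles {
        // Every thread observes the same single allocation.
        assert!(std::ptr::eq(h.join().unwrap(), first));
    }
}
```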
+// Copyright 2015 Andrew Gallant, bluss and Nicolas Koch + +// test the implementations for the current platform +use super::{memchr, memrchr}; + +#[test] +fn matches_one() { + assert_eq!(Some(0), memchr(b'a', b"a")); +} + +#[test] +fn matches_begin() { + assert_eq!(Some(0), memchr(b'a', b"aaaa")); +} + +#[test] +fn matches_end() { + assert_eq!(Some(4), memchr(b'z', b"aaaaz")); +} + +#[test] +fn matches_nul() { + assert_eq!(Some(4), memchr(b'\x00', b"aaaa\x00")); +} + +#[test] +fn matches_past_nul() { + assert_eq!(Some(5), memchr(b'z', b"aaaa\x00z")); +} + +#[test] +fn no_match_empty() { + assert_eq!(None, memchr(b'a', b"")); +} + +#[test] +fn no_match() { + assert_eq!(None, memchr(b'a', b"xyz")); +} + +#[test] +fn matches_one_reversed() { + assert_eq!(Some(0), memrchr(b'a', b"a")); +} + +#[test] +fn matches_begin_reversed() { + assert_eq!(Some(3), memrchr(b'a', b"aaaa")); +} + +#[test] +fn matches_end_reversed() { + assert_eq!(Some(0), memrchr(b'z', b"zaaaa")); +} + +#[test] +fn matches_nul_reversed() { + assert_eq!(Some(4), memrchr(b'\x00', b"aaaa\x00")); +} + +#[test] +fn matches_past_nul_reversed() { + assert_eq!(Some(0), memrchr(b'z', b"z\x00aaaa")); +} + +#[test] +fn no_match_empty_reversed() { + assert_eq!(None, memrchr(b'a', b"")); +} + +#[test] +fn no_match_reversed() { + assert_eq!(None, memrchr(b'a', b"xyz")); +} + +#[test] +fn each_alignment() { + let mut data = [1u8; 64]; + let needle = 2; + let pos = 40; + data[pos] = needle; + for start in 0..16 { + assert_eq!(Some(pos - start), memchr(needle, &data[start..])); + } +} diff --git a/library/std/src/sys_common/mod.rs b/library/std/src/sys_common/mod.rs new file mode 100644 index 000000000..80f56bf75 --- /dev/null +++ b/library/std/src/sys_common/mod.rs @@ -0,0 +1,89 @@ +//! Platform-independent platform abstraction +//! +//! This is the platform-independent portion of the standard library's +//! platform abstraction layer, whereas `std::sys` is the +//! platform-specific portion. +//! +//! The relationship between `std::sys_common`, `std::sys` and the +//! rest of `std` is complex, with dependencies going in all +//! directions: `std` depending on `sys_common`, `sys_common` +//! depending on `sys`, and `sys` depending on `sys_common` and `std`. +//! This is because `sys_common` not only contains platform-independent code, +//! but also code that is shared between the different platforms in `sys`. +//! Ideally all that shared code should be moved to `sys::common`, +//! and the dependencies between `std`, `sys_common` and `sys` all would form a dag. +//! Progress on this is tracked in #84187. + +#![allow(missing_docs)] +#![allow(missing_debug_implementations)] + +#[cfg(test)] +mod tests; + +pub mod backtrace; +pub mod condvar; +pub mod fs; +pub mod io; +pub mod lazy_box; +pub mod memchr; +pub mod mutex; +pub mod process; +pub mod remutex; +pub mod rwlock; +pub mod thread; +pub mod thread_info; +pub mod thread_local_dtor; +pub mod thread_local_key; +pub mod thread_parker; +pub mod wtf8; + +cfg_if::cfg_if! 
{ + if #[cfg(any(target_os = "l4re", + target_os = "hermit", + feature = "restricted-std", + all(target_family = "wasm", not(target_os = "emscripten")), + all(target_vendor = "fortanix", target_env = "sgx")))] { + pub use crate::sys::net; + } else { + pub mod net; + } +} + +// common error constructors + +/// A trait for viewing representations from std types +#[doc(hidden)] +pub trait AsInner<Inner: ?Sized> { + fn as_inner(&self) -> &Inner; +} + +/// A trait for viewing representations from std types +#[doc(hidden)] +pub trait AsInnerMut<Inner: ?Sized> { + fn as_inner_mut(&mut self) -> &mut Inner; +} + +/// A trait for extracting representations from std types +#[doc(hidden)] +pub trait IntoInner<Inner> { + fn into_inner(self) -> Inner; +} + +/// A trait for creating std types from internal representations +#[doc(hidden)] +pub trait FromInner<Inner> { + fn from_inner(inner: Inner) -> Self; +} + +// Computes (value*numer)/denom without overflow, as long as both +// (numer*denom) and the overall result fit into i64 (which is the case +// for our time conversions). +#[allow(dead_code)] // not used on all platforms +pub fn mul_div_u64(value: u64, numer: u64, denom: u64) -> u64 { + let q = value / denom; + let r = value % denom; + // Decompose value as (value/denom*denom + value%denom), + // substitute into (value*numer)/denom and simplify. + // r < denom, so (denom*numer) is the upper bound of (r*numer) + q * numer + r * numer / denom +} diff --git a/library/std/src/sys_common/mutex.rs b/library/std/src/sys_common/mutex.rs new file mode 100644 index 000000000..48479f5bd --- /dev/null +++ b/library/std/src/sys_common/mutex.rs @@ -0,0 +1,93 @@ +use crate::sys::locks as imp; + +/// An OS-based mutual exclusion lock, meant for use in static variables. +/// +/// This mutex has a const constructor ([`StaticMutex::new`]), does not +/// implement `Drop` to cleanup resources, and causes UB when used reentrantly. +/// +/// This mutex does not implement poisoning. +/// +/// This is a wrapper around `imp::Mutex` that does *not* call `init()` and +/// `destroy()`. +pub struct StaticMutex(imp::Mutex); + +unsafe impl Sync for StaticMutex {} + +impl StaticMutex { + /// Creates a new mutex for use. + #[inline] + pub const fn new() -> Self { + Self(imp::Mutex::new()) + } + + /// Calls raw_lock() and then returns an RAII guard to guarantee the mutex + /// will be unlocked. + /// + /// It is undefined behaviour to call this function while locked by the + /// same thread. + #[inline] + pub unsafe fn lock(&'static self) -> StaticMutexGuard { + self.0.lock(); + StaticMutexGuard(&self.0) + } +} + +#[must_use] +pub struct StaticMutexGuard(&'static imp::Mutex); + +impl Drop for StaticMutexGuard { + #[inline] + fn drop(&mut self) { + unsafe { + self.0.unlock(); + } + } +} + +/// An OS-based mutual exclusion lock. +/// +/// This mutex cleans up its resources in its `Drop` implementation, may safely +/// be moved (when not borrowed), and does not cause UB when used reentrantly. +/// +/// This mutex does not implement poisoning. +/// +/// This is either a wrapper around `LazyBox<imp::Mutex>` or `imp::Mutex`, +/// depending on the platform. It is boxed on platforms where `imp::Mutex` may +/// not be moved. +pub struct MovableMutex(imp::MovableMutex); + +unsafe impl Sync for MovableMutex {} + +impl MovableMutex { + /// Creates a new mutex. 
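Editor's note: the `q`/`r` decomposition in `mul_div_u64` above deserves a worked case. Repeating the helper verbatim, the first assertion below (the same value used in `sys_common/tests.rs`) would overflow a naive `value * numer / denom`, since `10^12 * 10^9` exceeds `u64::MAX ≈ 1.8 * 10^19`:

```rust
// Verbatim copy of the helper from sys_common/mod.rs above.
fn mul_div_u64(value: u64, numer: u64, denom: u64) -> u64 {
    let q = value / denom;
    let r = value % denom;
    // r < denom, so r * numer <= denom * numer, which is assumed to fit.
    q * numer + r * numer / denom
}

fn main() {
    // value * numer = 10^21 would overflow u64, but the decomposition keeps
    // every intermediate in range: q * numer = 10^15, r * numer = 10^9.
    assert_eq!(
        mul_div_u64(1_000_000_000_001, 1_000_000_000, 1_000_000),
        1_000_000_000_001_000
    );
    // Cross-check against 128-bit arithmetic.
    assert_eq!(
        (1_000_000_000_001u128 * 1_000_000_000 / 1_000_000) as u64,
        1_000_000_000_001_000
    );
}
```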
+ #[inline] + pub const fn new() -> Self { + Self(imp::MovableMutex::new()) + } + + pub(super) fn raw(&self) -> &imp::Mutex { + &self.0 + } + + /// Locks the mutex blocking the current thread until it is available. + #[inline] + pub fn raw_lock(&self) { + unsafe { self.0.lock() } + } + + /// Attempts to lock the mutex without blocking, returning whether it was + /// successfully acquired or not. + #[inline] + pub fn try_lock(&self) -> bool { + unsafe { self.0.try_lock() } + } + + /// Unlocks the mutex. + /// + /// Behavior is undefined if the current thread does not actually hold the + /// mutex. + #[inline] + pub unsafe fn raw_unlock(&self) { + self.0.unlock() + } +} diff --git a/library/std/src/sys_common/net.rs b/library/std/src/sys_common/net.rs new file mode 100644 index 000000000..33d336c43 --- /dev/null +++ b/library/std/src/sys_common/net.rs @@ -0,0 +1,737 @@ +#[cfg(test)] +mod tests; + +use crate::cmp; +use crate::ffi::CString; +use crate::fmt; +use crate::io::{self, ErrorKind, IoSlice, IoSliceMut}; +use crate::mem; +use crate::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; +use crate::ptr; +use crate::sys::net::netc as c; +use crate::sys::net::{cvt, cvt_gai, cvt_r, init, wrlen_t, Socket}; +use crate::sys_common::{FromInner, IntoInner}; +use crate::time::Duration; + +use libc::{c_int, c_void}; + +cfg_if::cfg_if! { + if #[cfg(any( + target_os = "dragonfly", target_os = "freebsd", + target_os = "ios", target_os = "macos", target_os = "watchos", + target_os = "openbsd", target_os = "netbsd", target_os = "illumos", + target_os = "solaris", target_os = "haiku", target_os = "l4re"))] { + use crate::sys::net::netc::IPV6_JOIN_GROUP as IPV6_ADD_MEMBERSHIP; + use crate::sys::net::netc::IPV6_LEAVE_GROUP as IPV6_DROP_MEMBERSHIP; + } else { + use crate::sys::net::netc::IPV6_ADD_MEMBERSHIP; + use crate::sys::net::netc::IPV6_DROP_MEMBERSHIP; + } +} + +cfg_if::cfg_if! { + if #[cfg(any( + target_os = "linux", target_os = "android", + target_os = "dragonfly", target_os = "freebsd", + target_os = "openbsd", target_os = "netbsd", + target_os = "haiku"))] { + use libc::MSG_NOSIGNAL; + } else { + const MSG_NOSIGNAL: c_int = 0x0; + } +} + +cfg_if::cfg_if! 
{ + if #[cfg(any( + target_os = "dragonfly", target_os = "freebsd", + target_os = "openbsd", target_os = "netbsd", + target_os = "solaris", target_os = "illumos"))] { + use libc::c_uchar; + type IpV4MultiCastType = c_uchar; + } else { + type IpV4MultiCastType = c_int; + } +} + +//////////////////////////////////////////////////////////////////////////////// +// sockaddr and misc bindings +//////////////////////////////////////////////////////////////////////////////// + +pub fn setsockopt<T>( + sock: &Socket, + level: c_int, + option_name: c_int, + option_value: T, +) -> io::Result<()> { + unsafe { + cvt(c::setsockopt( + sock.as_raw(), + level, + option_name, + &option_value as *const T as *const _, + mem::size_of::<T>() as c::socklen_t, + ))?; + Ok(()) + } +} + +pub fn getsockopt<T: Copy>(sock: &Socket, level: c_int, option_name: c_int) -> io::Result<T> { + unsafe { + let mut option_value: T = mem::zeroed(); + let mut option_len = mem::size_of::<T>() as c::socklen_t; + cvt(c::getsockopt( + sock.as_raw(), + level, + option_name, + &mut option_value as *mut T as *mut _, + &mut option_len, + ))?; + Ok(option_value) + } +} + +fn sockname<F>(f: F) -> io::Result<SocketAddr> +where + F: FnOnce(*mut c::sockaddr, *mut c::socklen_t) -> c_int, +{ + unsafe { + let mut storage: c::sockaddr_storage = mem::zeroed(); + let mut len = mem::size_of_val(&storage) as c::socklen_t; + cvt(f(&mut storage as *mut _ as *mut _, &mut len))?; + sockaddr_to_addr(&storage, len as usize) + } +} + +pub fn sockaddr_to_addr(storage: &c::sockaddr_storage, len: usize) -> io::Result<SocketAddr> { + match storage.ss_family as c_int { + c::AF_INET => { + assert!(len as usize >= mem::size_of::<c::sockaddr_in>()); + Ok(SocketAddr::V4(FromInner::from_inner(unsafe { + *(storage as *const _ as *const c::sockaddr_in) + }))) + } + c::AF_INET6 => { + assert!(len as usize >= mem::size_of::<c::sockaddr_in6>()); + Ok(SocketAddr::V6(FromInner::from_inner(unsafe { + *(storage as *const _ as *const c::sockaddr_in6) + }))) + } + _ => Err(io::const_io_error!(ErrorKind::InvalidInput, "invalid argument")), + } +} + +#[cfg(target_os = "android")] +fn to_ipv6mr_interface(value: u32) -> c_int { + value as c_int +} + +#[cfg(not(target_os = "android"))] +fn to_ipv6mr_interface(value: u32) -> libc::c_uint { + value as libc::c_uint +} + +//////////////////////////////////////////////////////////////////////////////// +// get_host_addresses +//////////////////////////////////////////////////////////////////////////////// + +pub struct LookupHost { + original: *mut c::addrinfo, + cur: *mut c::addrinfo, + port: u16, +} + +impl LookupHost { + pub fn port(&self) -> u16 { + self.port + } +} + +impl Iterator for LookupHost { + type Item = SocketAddr; + fn next(&mut self) -> Option<SocketAddr> { + loop { + unsafe { + let cur = self.cur.as_ref()?; + self.cur = cur.ai_next; + match sockaddr_to_addr(mem::transmute(cur.ai_addr), cur.ai_addrlen as usize) { + Ok(addr) => return Some(addr), + Err(_) => continue, + } + } + } + } +} + +unsafe impl Sync for LookupHost {} +unsafe impl Send for LookupHost {} + +impl Drop for LookupHost { + fn drop(&mut self) { + unsafe { c::freeaddrinfo(self.original) } + } +} + +impl TryFrom<&str> for LookupHost { + type Error = io::Error; + + fn try_from(s: &str) -> io::Result<LookupHost> { + macro_rules! 
try_opt { + ($e:expr, $msg:expr) => { + match $e { + Some(r) => r, + None => return Err(io::const_io_error!(io::ErrorKind::InvalidInput, $msg)), + } + }; + } + + // split the string by ':' and convert the second part to u16 + let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address"); + let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value"); + (host, port).try_into() + } +} + +impl<'a> TryFrom<(&'a str, u16)> for LookupHost { + type Error = io::Error; + + fn try_from((host, port): (&'a str, u16)) -> io::Result<LookupHost> { + init(); + + let c_host = CString::new(host)?; + let mut hints: c::addrinfo = unsafe { mem::zeroed() }; + hints.ai_socktype = c::SOCK_STREAM; + let mut res = ptr::null_mut(); + unsafe { + cvt_gai(c::getaddrinfo(c_host.as_ptr(), ptr::null(), &hints, &mut res)) + .map(|_| LookupHost { original: res, cur: res, port }) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// TCP streams +//////////////////////////////////////////////////////////////////////////////// + +pub struct TcpStream { + inner: Socket, +} + +impl TcpStream { + pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result<TcpStream> { + let addr = addr?; + + init(); + + let sock = Socket::new(addr, c::SOCK_STREAM)?; + + let (addr, len) = addr.into_inner(); + cvt_r(|| unsafe { c::connect(sock.as_raw(), addr.as_ptr(), len) })?; + Ok(TcpStream { inner: sock }) + } + + pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> { + init(); + + let sock = Socket::new(addr, c::SOCK_STREAM)?; + sock.connect_timeout(addr, timeout)?; + Ok(TcpStream { inner: sock }) + } + + pub fn socket(&self) -> &Socket { + &self.inner + } + + pub fn into_socket(self) -> Socket { + self.inner + } + + pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + self.inner.set_timeout(dur, c::SO_RCVTIMEO) + } + + pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + self.inner.set_timeout(dur, c::SO_SNDTIMEO) + } + + pub fn read_timeout(&self) -> io::Result<Option<Duration>> { + self.inner.timeout(c::SO_RCVTIMEO) + } + + pub fn write_timeout(&self) -> io::Result<Option<Duration>> { + self.inner.timeout(c::SO_SNDTIMEO) + } + + pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.peek(buf) + } + + pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.read(buf) + } + + pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.read_vectored(bufs) + } + + #[inline] + pub fn is_read_vectored(&self) -> bool { + self.inner.is_read_vectored() + } + + pub fn write(&self, buf: &[u8]) -> io::Result<usize> { + let len = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t; + let ret = cvt(unsafe { + c::send(self.inner.as_raw(), buf.as_ptr() as *const c_void, len, MSG_NOSIGNAL) + })?; + Ok(ret as usize) + } + + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { + self.inner.write_vectored(bufs) + } + + #[inline] + pub fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + sockname(|buf, len| unsafe { c::getpeername(self.inner.as_raw(), buf, len) }) + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) }) + } + + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.inner.shutdown(how) + } + + pub fn duplicate(&self) -> io::Result<TcpStream> 
{ + self.inner.duplicate().map(|s| TcpStream { inner: s }) + } + + pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> { + self.inner.set_linger(linger) + } + + pub fn linger(&self) -> io::Result<Option<Duration>> { + self.inner.linger() + } + + pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { + self.inner.set_nodelay(nodelay) + } + + pub fn nodelay(&self) -> io::Result<bool> { + self.inner.nodelay() + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) + } + + pub fn ttl(&self) -> io::Result<u32> { + let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?; + Ok(raw as u32) + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.inner.take_error() + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } +} + +impl FromInner<Socket> for TcpStream { + fn from_inner(socket: Socket) -> TcpStream { + TcpStream { inner: socket } + } +} + +impl fmt::Debug for TcpStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = f.debug_struct("TcpStream"); + + if let Ok(addr) = self.socket_addr() { + res.field("addr", &addr); + } + + if let Ok(peer) = self.peer_addr() { + res.field("peer", &peer); + } + + let name = if cfg!(windows) { "socket" } else { "fd" }; + res.field(name, &self.inner.as_raw()).finish() + } +} + +//////////////////////////////////////////////////////////////////////////////// +// TCP listeners +//////////////////////////////////////////////////////////////////////////////// + +pub struct TcpListener { + inner: Socket, +} + +impl TcpListener { + pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result<TcpListener> { + let addr = addr?; + + init(); + + let sock = Socket::new(addr, c::SOCK_STREAM)?; + + // On platforms with Berkeley-derived sockets, this allows to quickly + // rebind a socket, without needing to wait for the OS to clean up the + // previous one. + // + // On Windows, this allows rebinding sockets which are actively in use, + // which allows “socket hijacking”, so we explicitly don't set it here. + // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + #[cfg(not(windows))] + setsockopt(&sock, c::SOL_SOCKET, c::SO_REUSEADDR, 1 as c_int)?; + + // Bind our new socket + let (addr, len) = addr.into_inner(); + cvt(unsafe { c::bind(sock.as_raw(), addr.as_ptr(), len as _) })?; + + cfg_if::cfg_if! { + if #[cfg(target_os = "horizon")] { + // The 3DS doesn't support a big connection backlog. Sometimes + // it allows up to about 37, but other times it doesn't even + // accept 32. There may be a global limitation causing this. 
+ let backlog = 20; + } else { + // The default for all other platforms + let backlog = 128; + } + } + + // Start listening + cvt(unsafe { c::listen(sock.as_raw(), backlog) })?; + Ok(TcpListener { inner: sock }) + } + + pub fn socket(&self) -> &Socket { + &self.inner + } + + pub fn into_socket(self) -> Socket { + self.inner + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) }) + } + + pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as c::socklen_t; + let sock = self.inner.accept(&mut storage as *mut _ as *mut _, &mut len)?; + let addr = sockaddr_to_addr(&storage, len as usize)?; + Ok((TcpStream { inner: sock }, addr)) + } + + pub fn duplicate(&self) -> io::Result<TcpListener> { + self.inner.duplicate().map(|s| TcpListener { inner: s }) + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) + } + + pub fn ttl(&self) -> io::Result<u32> { + let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?; + Ok(raw as u32) + } + + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int) + } + + pub fn only_v6(&self) -> io::Result<bool> { + let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY)?; + Ok(raw != 0) + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.inner.take_error() + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } +} + +impl FromInner<Socket> for TcpListener { + fn from_inner(socket: Socket) -> TcpListener { + TcpListener { inner: socket } + } +} + +impl fmt::Debug for TcpListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = f.debug_struct("TcpListener"); + + if let Ok(addr) = self.socket_addr() { + res.field("addr", &addr); + } + + let name = if cfg!(windows) { "socket" } else { "fd" }; + res.field(name, &self.inner.as_raw()).finish() + } +} + +//////////////////////////////////////////////////////////////////////////////// +// UDP +//////////////////////////////////////////////////////////////////////////////// + +pub struct UdpSocket { + inner: Socket, +} + +impl UdpSocket { + pub fn bind(addr: io::Result<&SocketAddr>) -> io::Result<UdpSocket> { + let addr = addr?; + + init(); + + let sock = Socket::new(addr, c::SOCK_DGRAM)?; + let (addr, len) = addr.into_inner(); + cvt(unsafe { c::bind(sock.as_raw(), addr.as_ptr(), len as _) })?; + Ok(UdpSocket { inner: sock }) + } + + pub fn socket(&self) -> &Socket { + &self.inner + } + + pub fn into_socket(self) -> Socket { + self.inner + } + + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + sockname(|buf, len| unsafe { c::getpeername(self.inner.as_raw(), buf, len) }) + } + + pub fn socket_addr(&self) -> io::Result<SocketAddr> { + sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) }) + } + + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.inner.recv_from(buf) + } + + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.inner.peek_from(buf) + } + + pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> { + let len = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t; + let (dst, dstlen) = dst.into_inner(); + let ret 
= cvt(unsafe { + c::sendto( + self.inner.as_raw(), + buf.as_ptr() as *const c_void, + len, + MSG_NOSIGNAL, + dst.as_ptr(), + dstlen, + ) + })?; + Ok(ret as usize) + } + + pub fn duplicate(&self) -> io::Result<UdpSocket> { + self.inner.duplicate().map(|s| UdpSocket { inner: s }) + } + + pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + self.inner.set_timeout(dur, c::SO_RCVTIMEO) + } + + pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { + self.inner.set_timeout(dur, c::SO_SNDTIMEO) + } + + pub fn read_timeout(&self) -> io::Result<Option<Duration>> { + self.inner.timeout(c::SO_RCVTIMEO) + } + + pub fn write_timeout(&self) -> io::Result<Option<Duration>> { + self.inner.timeout(c::SO_SNDTIMEO) + } + + pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> { + setsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST, broadcast as c_int) + } + + pub fn broadcast(&self) -> io::Result<bool> { + let raw: c_int = getsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST)?; + Ok(raw != 0) + } + + pub fn set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()> { + setsockopt( + &self.inner, + c::IPPROTO_IP, + c::IP_MULTICAST_LOOP, + multicast_loop_v4 as IpV4MultiCastType, + ) + } + + pub fn multicast_loop_v4(&self) -> io::Result<bool> { + let raw: IpV4MultiCastType = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_LOOP)?; + Ok(raw != 0) + } + + pub fn set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()> { + setsockopt( + &self.inner, + c::IPPROTO_IP, + c::IP_MULTICAST_TTL, + multicast_ttl_v4 as IpV4MultiCastType, + ) + } + + pub fn multicast_ttl_v4(&self) -> io::Result<u32> { + let raw: IpV4MultiCastType = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_TTL)?; + Ok(raw as u32) + } + + pub fn set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_MULTICAST_LOOP, multicast_loop_v6 as c_int) + } + + pub fn multicast_loop_v6(&self) -> io::Result<bool> { + let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_MULTICAST_LOOP)?; + Ok(raw != 0) + } + + pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { + let mreq = c::ip_mreq { + imr_multiaddr: multiaddr.into_inner(), + imr_interface: interface.into_inner(), + }; + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_ADD_MEMBERSHIP, mreq) + } + + pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + let mreq = c::ipv6_mreq { + ipv6mr_multiaddr: multiaddr.into_inner(), + ipv6mr_interface: to_ipv6mr_interface(interface), + }; + setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, mreq) + } + + pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { + let mreq = c::ip_mreq { + imr_multiaddr: multiaddr.into_inner(), + imr_interface: interface.into_inner(), + }; + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_DROP_MEMBERSHIP, mreq) + } + + pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + let mreq = c::ipv6_mreq { + ipv6mr_multiaddr: multiaddr.into_inner(), + ipv6mr_interface: to_ipv6mr_interface(interface), + }; + setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP, mreq) + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) + } + + pub fn ttl(&self) -> io::Result<u32> { + let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?; + 
Ok(raw as u32) + } + + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.inner.take_error() + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } + + pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.read(buf) + } + + pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.peek(buf) + } + + pub fn send(&self, buf: &[u8]) -> io::Result<usize> { + let len = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t; + let ret = cvt(unsafe { + c::send(self.inner.as_raw(), buf.as_ptr() as *const c_void, len, MSG_NOSIGNAL) + })?; + Ok(ret as usize) + } + + pub fn connect(&self, addr: io::Result<&SocketAddr>) -> io::Result<()> { + let (addr, len) = addr?.into_inner(); + cvt_r(|| unsafe { c::connect(self.inner.as_raw(), addr.as_ptr(), len) }).map(drop) + } +} + +impl FromInner<Socket> for UdpSocket { + fn from_inner(socket: Socket) -> UdpSocket { + UdpSocket { inner: socket } + } +} + +impl fmt::Debug for UdpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = f.debug_struct("UdpSocket"); + + if let Ok(addr) = self.socket_addr() { + res.field("addr", &addr); + } + + let name = if cfg!(windows) { "socket" } else { "fd" }; + res.field(name, &self.inner.as_raw()).finish() + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Converting SocketAddr to libc representation +//////////////////////////////////////////////////////////////////////////////// + +/// A type with the same memory layout as `c::sockaddr`. Used in converting Rust level +/// SocketAddr* types into their system representation. The benefit of this specific +/// type over using `c::sockaddr_storage` is that this type is exactly as large as it +/// needs to be and not a lot larger. And it can be initialized more cleanly from Rust. 
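Editor's note: the size claim in the comment above is easy to check against stand-ins for the C layouts. The structs below are hypothetical mirrors of the common Linux definitions (the real types come from `sys::net::netc`), giving 16 and 28 bytes respectively:

```rust
use std::mem::size_of;

// Hypothetical stand-ins for c::sockaddr_in / c::sockaddr_in6.
#[repr(C)]
#[derive(Clone, Copy)]
struct SockAddrIn { family: u16, port: u16, addr: u32, zero: [u8; 8] } // 16 bytes

#[repr(C)]
#[derive(Clone, Copy)]
struct SockAddrIn6 { family: u16, port: u16, flowinfo: u32, addr: [u8; 16], scope_id: u32 } // 28 bytes

#[repr(C)]
union SocketAddrCRepr {
    v4: SockAddrIn,
    v6: SockAddrIn6,
}

fn main() {
    // The union is exactly as large as its biggest variant...
    assert_eq!(size_of::<SocketAddrCRepr>(), size_of::<SockAddrIn6>());
    // ...while c::sockaddr_storage is typically 128 bytes on these platforms.
    assert_eq!(size_of::<SockAddrIn6>(), 28);
}
```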
+#[repr(C)] +pub(crate) union SocketAddrCRepr { + v4: c::sockaddr_in, + v6: c::sockaddr_in6, +} + +impl SocketAddrCRepr { + pub fn as_ptr(&self) -> *const c::sockaddr { + self as *const _ as *const c::sockaddr + } +} + +impl<'a> IntoInner<(SocketAddrCRepr, c::socklen_t)> for &'a SocketAddr { + fn into_inner(self) -> (SocketAddrCRepr, c::socklen_t) { + match *self { + SocketAddr::V4(ref a) => { + let sockaddr = SocketAddrCRepr { v4: a.into_inner() }; + (sockaddr, mem::size_of::<c::sockaddr_in>() as c::socklen_t) + } + SocketAddr::V6(ref a) => { + let sockaddr = SocketAddrCRepr { v6: a.into_inner() }; + (sockaddr, mem::size_of::<c::sockaddr_in6>() as c::socklen_t) + } + } + } +} diff --git a/library/std/src/sys_common/net/tests.rs b/library/std/src/sys_common/net/tests.rs new file mode 100644 index 000000000..ac75d9ebf --- /dev/null +++ b/library/std/src/sys_common/net/tests.rs @@ -0,0 +1,19 @@ +use super::*; +use crate::collections::HashMap; + +#[test] +fn no_lookup_host_duplicates() { + let mut addrs = HashMap::new(); + let lh = match LookupHost::try_from(("localhost", 0)) { + Ok(lh) => lh, + Err(e) => panic!("couldn't resolve `localhost': {e}"), + }; + for sa in lh { + *addrs.entry(sa).or_insert(0) += 1; + } + assert_eq!( + addrs.iter().filter(|&(_, &v)| v > 1).collect::<Vec<_>>(), + vec![], + "There should be no duplicate localhost entries" + ); +} diff --git a/library/std/src/sys_common/process.rs b/library/std/src/sys_common/process.rs new file mode 100644 index 000000000..9f978789a --- /dev/null +++ b/library/std/src/sys_common/process.rs @@ -0,0 +1,119 @@ +#![allow(dead_code)] +#![unstable(feature = "process_internals", issue = "none")] + +use crate::collections::BTreeMap; +use crate::env; +use crate::ffi::{OsStr, OsString}; +use crate::sys::process::EnvKey; + +// Stores a set of changes to an environment +#[derive(Clone, Debug)] +pub struct CommandEnv { + clear: bool, + saw_path: bool, + vars: BTreeMap<EnvKey, Option<OsString>>, +} + +impl Default for CommandEnv { + fn default() -> Self { + CommandEnv { clear: false, saw_path: false, vars: Default::default() } + } +} + +impl CommandEnv { + // Capture the current environment with these changes applied + pub fn capture(&self) -> BTreeMap<EnvKey, OsString> { + let mut result = BTreeMap::<EnvKey, OsString>::new(); + if !self.clear { + for (k, v) in env::vars_os() { + result.insert(k.into(), v); + } + } + for (k, maybe_v) in &self.vars { + if let &Some(ref v) = maybe_v { + result.insert(k.clone(), v.clone()); + } else { + result.remove(k); + } + } + result + } + + pub fn is_unchanged(&self) -> bool { + !self.clear && self.vars.is_empty() + } + + pub fn capture_if_changed(&self) -> Option<BTreeMap<EnvKey, OsString>> { + if self.is_unchanged() { None } else { Some(self.capture()) } + } + + // The following functions build up changes + pub fn set(&mut self, key: &OsStr, value: &OsStr) { + let key = EnvKey::from(key); + self.maybe_saw_path(&key); + self.vars.insert(key, Some(value.to_owned())); + } + + pub fn remove(&mut self, key: &OsStr) { + let key = EnvKey::from(key); + self.maybe_saw_path(&key); + if self.clear { + self.vars.remove(&key); + } else { + self.vars.insert(key, None); + } + } + + pub fn clear(&mut self) { + self.clear = true; + self.vars.clear(); + } + + pub fn have_changed_path(&self) -> bool { + self.saw_path || self.clear + } + + fn maybe_saw_path(&mut self, key: &EnvKey) { + if !self.saw_path && key == "PATH" { + self.saw_path = true; + } + } + + pub fn iter(&self) -> CommandEnvs<'_> { + let iter = self.vars.iter(); + 
CommandEnvs { iter } + } +} + +/// An iterator over the command environment variables. +/// +/// This struct is created by +/// [`Command::get_envs`][crate::process::Command::get_envs]. See its +/// documentation for more. +#[must_use = "iterators are lazy and do nothing unless consumed"] +#[stable(feature = "command_access", since = "1.57.0")] +#[derive(Debug)] +pub struct CommandEnvs<'a> { + iter: crate::collections::btree_map::Iter<'a, EnvKey, Option<OsString>>, +} + +#[stable(feature = "command_access", since = "1.57.0")] +impl<'a> Iterator for CommandEnvs<'a> { + type Item = (&'a OsStr, Option<&'a OsStr>); + fn next(&mut self) -> Option<Self::Item> { + self.iter.next().map(|(key, value)| (key.as_ref(), value.as_deref())) + } + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +#[stable(feature = "command_access", since = "1.57.0")] +impl<'a> ExactSizeIterator for CommandEnvs<'a> { + fn len(&self) -> usize { + self.iter.len() + } + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} diff --git a/library/std/src/sys_common/remutex.rs b/library/std/src/sys_common/remutex.rs new file mode 100644 index 000000000..8921af311 --- /dev/null +++ b/library/std/src/sys_common/remutex.rs @@ -0,0 +1,200 @@ +#[cfg(all(test, not(target_os = "emscripten")))] +mod tests; + +use crate::cell::UnsafeCell; +use crate::marker::PhantomPinned; +use crate::ops::Deref; +use crate::panic::{RefUnwindSafe, UnwindSafe}; +use crate::pin::Pin; +use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use crate::sys::locks as sys; + +/// A re-entrant mutual exclusion +/// +/// This mutex will block *other* threads waiting for the lock to become +/// available. The thread which has already locked the mutex can lock it +/// multiple times without blocking, preventing a common source of deadlocks. +/// +/// This is used by stdout().lock() and friends. +/// +/// ## Implementation details +/// +/// The 'owner' field tracks which thread has locked the mutex. +/// +/// We use current_thread_unique_ptr() as the thread identifier, +/// which is just the address of a thread local variable. +/// +/// If `owner` is set to the identifier of the current thread, +/// we assume the mutex is already locked and instead of locking it again, +/// we increment `lock_count`. +/// +/// When unlocking, we decrement `lock_count`, and only unlock the mutex when +/// it reaches zero. +/// +/// `lock_count` is protected by the mutex and only accessed by the thread that has +/// locked the mutex, so needs no synchronization. +/// +/// `owner` can be checked by other threads that want to see if they already +/// hold the lock, so needs to be atomic. If it compares equal, we're on the +/// same thread that holds the mutex and memory access can use relaxed ordering +/// since we're not dealing with multiple threads. If it compares unequal, +/// synchronization is left to the mutex, making relaxed memory ordering for +/// the `owner` field fine in all cases. +pub struct ReentrantMutex<T> { + mutex: sys::Mutex, + owner: AtomicUsize, + lock_count: UnsafeCell<u32>, + data: T, + _pinned: PhantomPinned, +} + +unsafe impl<T: Send> Send for ReentrantMutex<T> {} +unsafe impl<T: Send> Sync for ReentrantMutex<T> {} + +impl<T> UnwindSafe for ReentrantMutex<T> {} +impl<T> RefUnwindSafe for ReentrantMutex<T> {} + +/// An RAII implementation of a "scoped lock" of a mutex. When this structure is +/// dropped (falls out of scope), the lock will be unlocked. 
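Editor's note: the reentrancy described above is observable from safe code, because (as the comment says) `stdout().lock()` uses this type. Re-locking on the owning thread just bumps `lock_count` instead of deadlocking:

```rust
use std::io::Write;

fn main() {
    // StdoutLock holds the ReentrantMutex described above; println! takes it
    // again on the same thread, incrementing lock_count rather than blocking.
    let mut out = std::io::stdout().lock();
    writeln!(out, "outer lock held").unwrap();
    println!("inner, re-entrant lock"); // would deadlock with a plain mutex
    writeln!(out, "outer lock still held").unwrap();
}
```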
+/// +/// The data protected by the mutex can be accessed through this guard via its +/// Deref implementation. +/// +/// # Mutability +/// +/// Unlike `MutexGuard`, `ReentrantMutexGuard` does not implement `DerefMut`, +/// because implementation of the trait would violate Rust’s reference aliasing +/// rules. Use interior mutability (usually `RefCell`) in order to mutate the +/// guarded data. +#[must_use = "if unused the ReentrantMutex will immediately unlock"] +pub struct ReentrantMutexGuard<'a, T: 'a> { + lock: Pin<&'a ReentrantMutex<T>>, +} + +impl<T> !Send for ReentrantMutexGuard<'_, T> {} + +impl<T> ReentrantMutex<T> { + /// Creates a new reentrant mutex in an unlocked state. + /// + /// # Unsafety + /// + /// This function is unsafe because it is required that `init` is called + /// once this mutex is in its final resting place, and only then are the + /// lock/unlock methods safe. + pub const unsafe fn new(t: T) -> ReentrantMutex<T> { + ReentrantMutex { + mutex: sys::Mutex::new(), + owner: AtomicUsize::new(0), + lock_count: UnsafeCell::new(0), + data: t, + _pinned: PhantomPinned, + } + } + + /// Initializes this mutex so it's ready for use. + /// + /// # Unsafety + /// + /// Unsafe to call more than once, and must be called after this will no + /// longer move in memory. + pub unsafe fn init(self: Pin<&mut Self>) { + self.get_unchecked_mut().mutex.init() + } + + /// Acquires a mutex, blocking the current thread until it is able to do so. + /// + /// This function will block the caller until it is available to acquire the mutex. + /// Upon returning, the thread is the only thread with the mutex held. When the thread + /// calling this method already holds the lock, the call shall succeed without + /// blocking. + /// + /// # Errors + /// + /// If another user of this mutex panicked while holding the mutex, then + /// this call will return failure if the mutex would otherwise be + /// acquired. + pub fn lock(self: Pin<&Self>) -> ReentrantMutexGuard<'_, T> { + let this_thread = current_thread_unique_ptr(); + // Safety: We only touch lock_count when we own the lock, + // and since self is pinned we can safely call the lock() on the mutex. + unsafe { + if self.owner.load(Relaxed) == this_thread { + self.increment_lock_count(); + } else { + self.mutex.lock(); + self.owner.store(this_thread, Relaxed); + debug_assert_eq!(*self.lock_count.get(), 0); + *self.lock_count.get() = 1; + } + } + ReentrantMutexGuard { lock: self } + } + + /// Attempts to acquire this lock. + /// + /// If the lock could not be acquired at this time, then `Err` is returned. + /// Otherwise, an RAII guard is returned. + /// + /// This function does not block. + /// + /// # Errors + /// + /// If another user of this mutex panicked while holding the mutex, then + /// this call will return failure if the mutex would otherwise be + /// acquired. + pub fn try_lock(self: Pin<&Self>) -> Option<ReentrantMutexGuard<'_, T>> { + let this_thread = current_thread_unique_ptr(); + // Safety: We only touch lock_count when we own the lock, + // and since self is pinned we can safely call the try_lock on the mutex. 
+ unsafe { + if self.owner.load(Relaxed) == this_thread { + self.increment_lock_count(); + Some(ReentrantMutexGuard { lock: self }) + } else if self.mutex.try_lock() { + self.owner.store(this_thread, Relaxed); + debug_assert_eq!(*self.lock_count.get(), 0); + *self.lock_count.get() = 1; + Some(ReentrantMutexGuard { lock: self }) + } else { + None + } + } + } + + unsafe fn increment_lock_count(&self) { + *self.lock_count.get() = (*self.lock_count.get()) + .checked_add(1) + .expect("lock count overflow in reentrant mutex"); + } +} + +impl<T> Deref for ReentrantMutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.lock.data + } +} + +impl<T> Drop for ReentrantMutexGuard<'_, T> { + #[inline] + fn drop(&mut self) { + // Safety: We own the lock, and the lock is pinned. + unsafe { + *self.lock.lock_count.get() -= 1; + if *self.lock.lock_count.get() == 0 { + self.lock.owner.store(0, Relaxed); + self.lock.mutex.unlock(); + } + } + } +} + +/// Get an address that is unique per running thread. +/// +/// This can be used as a non-null usize-sized ID. +pub fn current_thread_unique_ptr() -> usize { + // Use a non-drop type to make sure it's still available during thread destruction. + thread_local! { static X: u8 = const { 0 } } + X.with(|x| <*const _>::addr(x)) +} diff --git a/library/std/src/sys_common/remutex/tests.rs b/library/std/src/sys_common/remutex/tests.rs new file mode 100644 index 000000000..64873b850 --- /dev/null +++ b/library/std/src/sys_common/remutex/tests.rs @@ -0,0 +1,77 @@ +use crate::boxed::Box; +use crate::cell::RefCell; +use crate::pin::Pin; +use crate::sync::Arc; +use crate::sys_common::remutex::{ReentrantMutex, ReentrantMutexGuard}; +use crate::thread; + +#[test] +fn smoke() { + let m = unsafe { + let mut m = Box::pin(ReentrantMutex::new(())); + m.as_mut().init(); + m + }; + let m = m.as_ref(); + { + let a = m.lock(); + { + let b = m.lock(); + { + let c = m.lock(); + assert_eq!(*c, ()); + } + assert_eq!(*b, ()); + } + assert_eq!(*a, ()); + } +} + +#[test] +fn is_mutex() { + let m = unsafe { + // FIXME: Simplify this if Arc gets an Arc::get_pin_mut. + let mut m = Arc::new(ReentrantMutex::new(RefCell::new(0))); + Pin::new_unchecked(Arc::get_mut_unchecked(&mut m)).init(); + Pin::new_unchecked(m) + }; + let m2 = m.clone(); + let lock = m.as_ref().lock(); + let child = thread::spawn(move || { + let lock = m2.as_ref().lock(); + assert_eq!(*lock.borrow(), 4950); + }); + for i in 0..100 { + let lock = m.as_ref().lock(); + *lock.borrow_mut() += i; + } + drop(lock); + child.join().unwrap(); +} + +#[test] +fn trylock_works() { + let m = unsafe { + // FIXME: Simplify this if Arc gets an Arc::get_pin_mut. + let mut m = Arc::new(ReentrantMutex::new(())); + Pin::new_unchecked(Arc::get_mut_unchecked(&mut m)).init(); + Pin::new_unchecked(m) + }; + let m2 = m.clone(); + let _lock = m.as_ref().try_lock(); + let _lock2 = m.as_ref().try_lock(); + thread::spawn(move || { + let lock = m2.as_ref().try_lock(); + assert!(lock.is_none()); + }) + .join() + .unwrap(); + let _lock3 = m.as_ref().try_lock(); +} + +pub struct Answer<'a>(pub ReentrantMutexGuard<'a, RefCell<u32>>); +impl Drop for Answer<'_> { + fn drop(&mut self) { + *self.0.borrow_mut() = 42; + } +} diff --git a/library/std/src/sys_common/rwlock.rs b/library/std/src/sys_common/rwlock.rs new file mode 100644 index 000000000..ba56f3a8f --- /dev/null +++ b/library/std/src/sys_common/rwlock.rs @@ -0,0 +1,130 @@ +use crate::sys::locks as imp; + +/// An OS-based reader-writer lock, meant for use in static variables. 
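Editor's note: the `current_thread_unique_ptr` trick above — deriving a thread ID from the address of a thread-local — can be reproduced with stable APIs. This sketch uses a plain pointer cast instead of the strict-provenance `addr`, and `thread_id_addr` is a hypothetical name:

```rust
fn thread_id_addr() -> usize {
    // The address of a thread-local is stable within a thread and distinct
    // from that of any other thread that is alive at the same time.
    thread_local! { static X: u8 = const { 0 } }
    X.with(|x| x as *const u8 as usize)
}

fn main() {
    let main_id = thread_id_addr();
    assert_eq!(main_id, thread_id_addr()); // stable within one thread
    let other = std::thread::spawn(thread_id_addr).join().unwrap();
    assert_ne!(main_id, other); // the two threads were alive concurrently
}
```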
+///
+/// This rwlock does not implement poisoning.
+///
+/// This rwlock has a const constructor ([`StaticRwLock::new`]) and does not
+/// implement `Drop` to clean up resources.
+pub struct StaticRwLock(imp::RwLock);
+
+impl StaticRwLock {
+ /// Creates a new rwlock for use.
+ #[inline]
+ pub const fn new() -> Self {
+ Self(imp::RwLock::new())
+ }
+
+ /// Acquires shared access to the underlying lock, blocking the current
+ /// thread to do so.
+ ///
+ /// The lock is automatically unlocked when the returned guard is dropped.
+ #[inline]
+ pub fn read(&'static self) -> StaticRwLockReadGuard {
+ unsafe { self.0.read() };
+ StaticRwLockReadGuard(&self.0)
+ }
+
+ /// Acquires write access to the underlying lock, blocking the current thread
+ /// to do so.
+ ///
+ /// The lock is automatically unlocked when the returned guard is dropped.
+ #[inline]
+ pub fn write(&'static self) -> StaticRwLockWriteGuard {
+ unsafe { self.0.write() };
+ StaticRwLockWriteGuard(&self.0)
+ }
+}
+
+#[must_use]
+pub struct StaticRwLockReadGuard(&'static imp::RwLock);
+
+impl Drop for StaticRwLockReadGuard {
+ #[inline]
+ fn drop(&mut self) {
+ unsafe {
+ self.0.read_unlock();
+ }
+ }
+}
+
+#[must_use]
+pub struct StaticRwLockWriteGuard(&'static imp::RwLock);
+
+impl Drop for StaticRwLockWriteGuard {
+ #[inline]
+ fn drop(&mut self) {
+ unsafe {
+ self.0.write_unlock();
+ }
+ }
+}
+
+/// An OS-based reader-writer lock.
+///
+/// This rwlock cleans up its resources in its `Drop` implementation and may
+/// safely be moved (when not borrowed).
+///
+/// This rwlock does not implement poisoning.
+///
+/// This is either a wrapper around `LazyBox<imp::RwLock>` or `imp::RwLock`,
+/// depending on the platform. It is boxed on platforms where `imp::RwLock` may
+/// not be moved.
+pub struct MovableRwLock(imp::MovableRwLock);
+
+impl MovableRwLock {
+ /// Creates a new reader-writer lock for use.
+ #[inline]
+ pub const fn new() -> Self {
+ Self(imp::MovableRwLock::new())
+ }
+
+ /// Acquires shared access to the underlying lock, blocking the current
+ /// thread to do so.
+ #[inline]
+ pub fn read(&self) {
+ unsafe { self.0.read() }
+ }
+
+ /// Attempts to acquire shared access to this lock, returning whether it
+ /// succeeded or not.
+ ///
+ /// This function does not block the current thread.
+ #[inline]
+ pub fn try_read(&self) -> bool {
+ unsafe { self.0.try_read() }
+ }
+
+ /// Acquires write access to the underlying lock, blocking the current thread
+ /// to do so.
+ #[inline]
+ pub fn write(&self) {
+ unsafe { self.0.write() }
+ }
+
+ /// Attempts to acquire exclusive access to this lock, returning whether it
+ /// succeeded or not.
+ ///
+ /// This function does not block the current thread.
+ #[inline]
+ pub fn try_write(&self) -> bool {
+ unsafe { self.0.try_write() }
+ }
+
+ /// Unlocks previously acquired shared access to this lock.
+ ///
+ /// Behavior is undefined if the current thread does not have shared access.
+ #[inline]
+ pub unsafe fn read_unlock(&self) {
+ self.0.read_unlock()
+ }
+
+ /// Unlocks previously acquired exclusive access to this lock.
+ ///
+ /// Behavior is undefined if the current thread does not currently have
+ /// exclusive access.
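// Usage sketch for the raw unlock methods (illustrative only, not from the
// upstream file): each successful `read()` must be paired with exactly one
// `read_unlock()`, and likewise for `write()`/`write_unlock()`:
//
//     let lock = MovableRwLock::new();
//     lock.read();                      // acquire shared access
//     assert!(!lock.try_write());       // writers are excluded while a reader exists
//     unsafe { lock.read_unlock() };    // UB if shared access were not held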
+ #[inline]
+ pub unsafe fn write_unlock(&self) {
+ self.0.write_unlock()
+ }
+}
diff --git a/library/std/src/sys_common/tests.rs b/library/std/src/sys_common/tests.rs
new file mode 100644
index 000000000..1b6446db5
--- /dev/null
+++ b/library/std/src/sys_common/tests.rs
@@ -0,0 +1,6 @@
+use super::mul_div_u64;
+
+#[test]
+fn test_muldiv() {
+ assert_eq!(mul_div_u64(1_000_000_000_001, 1_000_000_000, 1_000_000), 1_000_000_000_001_000);
+}
diff --git a/library/std/src/sys_common/thread.rs b/library/std/src/sys_common/thread.rs
new file mode 100644
index 000000000..76466b2b3
--- /dev/null
+++ b/library/std/src/sys_common/thread.rs
@@ -0,0 +1,18 @@
+use crate::env;
+use crate::sync::atomic::{self, Ordering};
+use crate::sys::thread as imp;
+
+pub fn min_stack() -> usize {
+ static MIN: atomic::AtomicUsize = atomic::AtomicUsize::new(0);
+ match MIN.load(Ordering::Relaxed) {
+ 0 => {}
+ n => return n - 1,
+ }
+ let amt = env::var("RUST_MIN_STACK").ok().and_then(|s| s.parse().ok());
+ let amt = amt.unwrap_or(imp::DEFAULT_MIN_STACK_SIZE);
+
+ // 0 is our sentinel value, so ensure that we'll never see 0 after
+ // initialization has run
+ MIN.store(amt + 1, Ordering::Relaxed);
+ amt
+}
diff --git a/library/std/src/sys_common/thread_info.rs b/library/std/src/sys_common/thread_info.rs
new file mode 100644
index 000000000..38c9e5000
--- /dev/null
+++ b/library/std/src/sys_common/thread_info.rs
@@ -0,0 +1,47 @@
+#![allow(dead_code)] // stack_guard isn't used right now on all platforms
+#![allow(unused_unsafe)] // thread_local with `const {}` triggers this lint
+
+use crate::cell::RefCell;
+use crate::sys::thread::guard::Guard;
+use crate::thread::Thread;
+
+struct ThreadInfo {
+ stack_guard: Option<Guard>,
+ thread: Thread,
+}
+
+thread_local! { static THREAD_INFO: RefCell<Option<ThreadInfo>> = const { RefCell::new(None) } }
+
+impl ThreadInfo {
+ fn with<R, F>(f: F) -> Option<R>
+ where
+ F: FnOnce(&mut ThreadInfo) -> R,
+ {
+ THREAD_INFO
+ .try_with(move |thread_info| {
+ let mut thread_info = thread_info.borrow_mut();
+ let thread_info = thread_info.get_or_insert_with(|| ThreadInfo {
+ stack_guard: None,
+ thread: Thread::new(None),
+ });
+ f(thread_info)
+ })
+ .ok()
+ }
+}
+
+pub fn current_thread() -> Option<Thread> {
+ ThreadInfo::with(|info| info.thread.clone())
+}
+
+pub fn stack_guard() -> Option<Guard> {
+ ThreadInfo::with(|info| info.stack_guard.clone()).and_then(|o| o)
+}
+
+pub fn set(stack_guard: Option<Guard>, thread: Thread) {
+ THREAD_INFO.with(move |thread_info| {
+ let mut thread_info = thread_info.borrow_mut();
+ rtassert!(thread_info.is_none());
+ *thread_info = Some(ThreadInfo { stack_guard, thread });
+ });
+}
diff --git a/library/std/src/sys_common/thread_local_dtor.rs b/library/std/src/sys_common/thread_local_dtor.rs
new file mode 100644
index 000000000..1d13a7171
--- /dev/null
+++ b/library/std/src/sys_common/thread_local_dtor.rs
@@ -0,0 +1,49 @@
+//! Thread-local destructor
+//!
+//! Besides thread-local "keys" (pointer-sized non-addressable thread-local store
+//! with an associated destructor), many platforms also provide thread-local
+//! destructors that are not associated with any particular data. These are
+//! often more efficient.
+//!
+//! This module provides a fallback implementation for that interface, based
+//! on the less efficient thread-local "keys". Each platform provides
+//! a `thread_local_dtor` module which will either re-export the fallback,
+//! or implement something more efficient.
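// How the fallback is used (a sketch; `register_dtor_fallback` is defined
// below, while the caller and `drop_value` shown here are hypothetical):
//
//     unsafe extern "C" fn drop_value(p: *mut u8) {
//         drop(Box::from_raw(p as *mut String));    // reclaim the boxed value
//     }
//
//     let data = Box::into_raw(Box::new(String::from("per-thread"))) as *mut u8;
//     unsafe { register_dtor_fallback(data, drop_value) };
//     // At thread exit, `run_dtors` drains the registered list and calls
//     // `drop_value(data)`.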
+ +#![unstable(feature = "thread_local_internals", issue = "none")] +#![allow(dead_code)] + +use crate::ptr; +use crate::sys_common::thread_local_key::StaticKey; + +pub unsafe fn register_dtor_fallback(t: *mut u8, dtor: unsafe extern "C" fn(*mut u8)) { + // The fallback implementation uses a vanilla OS-based TLS key to track + // the list of destructors that need to be run for this thread. The key + // then has its own destructor which runs all the other destructors. + // + // The destructor for DTORS is a little special in that it has a `while` + // loop to continuously drain the list of registered destructors. It + // *should* be the case that this loop always terminates because we + // provide the guarantee that a TLS key cannot be set after it is + // flagged for destruction. + + static DTORS: StaticKey = StaticKey::new(Some(run_dtors)); + type List = Vec<(*mut u8, unsafe extern "C" fn(*mut u8))>; + if DTORS.get().is_null() { + let v: Box<List> = box Vec::new(); + DTORS.set(Box::into_raw(v) as *mut u8); + } + let list: &mut List = &mut *(DTORS.get() as *mut List); + list.push((t, dtor)); + + unsafe extern "C" fn run_dtors(mut ptr: *mut u8) { + while !ptr.is_null() { + let list: Box<List> = Box::from_raw(ptr as *mut List); + for (ptr, dtor) in list.into_iter() { + dtor(ptr); + } + ptr = DTORS.get(); + DTORS.set(ptr::null_mut()); + } + } +} diff --git a/library/std/src/sys_common/thread_local_key.rs b/library/std/src/sys_common/thread_local_key.rs new file mode 100644 index 000000000..70beebe86 --- /dev/null +++ b/library/std/src/sys_common/thread_local_key.rs @@ -0,0 +1,237 @@ +//! OS-based thread local storage +//! +//! This module provides an implementation of OS-based thread local storage, +//! using the native OS-provided facilities (think `TlsAlloc` or +//! `pthread_setspecific`). The interface of this differs from the other types +//! of thread-local-storage provided in this crate in that OS-based TLS can only +//! get/set pointer-sized data, possibly with an associated destructor. +//! +//! This module also provides two flavors of TLS. One is intended for static +//! initialization, and does not contain a `Drop` implementation to deallocate +//! the OS-TLS key. The other is a type which does implement `Drop` and hence +//! has a safe interface. +//! +//! # Usage +//! +//! This module should likely not be used directly unless other primitives are +//! being built on. Types such as `thread_local::spawn::Key` are likely much +//! more useful in practice than this OS-based version which likely requires +//! unsafe code to interoperate with. +//! +//! # Examples +//! +//! Using a dynamically allocated TLS key. Note that this key can be shared +//! among many threads via an `Arc`. +//! +//! ```ignore (cannot-doctest-private-modules) +//! let key = Key::new(None); +//! assert!(key.get().is_null()); +//! key.set(1 as *mut u8); +//! assert!(!key.get().is_null()); +//! +//! drop(key); // deallocate this TLS slot. +//! ``` +//! +//! Sometimes a statically allocated key is either required or easier to work +//! with, however. +//! +//! ```ignore (cannot-doctest-private-modules) +//! static KEY: StaticKey = INIT; +//! +//! unsafe { +//! assert!(KEY.get().is_null()); +//! KEY.set(1 as *mut u8); +//! } +//! 
``` + +#![allow(non_camel_case_types)] +#![unstable(feature = "thread_local_internals", issue = "none")] +#![allow(dead_code)] + +#[cfg(test)] +mod tests; + +use crate::sync::atomic::{self, AtomicUsize, Ordering}; +use crate::sys::thread_local_key as imp; +use crate::sys_common::mutex::StaticMutex; + +/// A type for TLS keys that are statically allocated. +/// +/// This type is entirely `unsafe` to use as it does not protect against +/// use-after-deallocation or use-during-deallocation. +/// +/// The actual OS-TLS key is lazily allocated when this is used for the first +/// time. The key is also deallocated when the Rust runtime exits or `destroy` +/// is called, whichever comes first. +/// +/// # Examples +/// +/// ```ignore (cannot-doctest-private-modules) +/// use tls::os::{StaticKey, INIT}; +/// +/// static KEY: StaticKey = INIT; +/// +/// unsafe { +/// assert!(KEY.get().is_null()); +/// KEY.set(1 as *mut u8); +/// } +/// ``` +pub struct StaticKey { + /// Inner static TLS key (internals). + key: AtomicUsize, + /// Destructor for the TLS value. + /// + /// See `Key::new` for information about when the destructor runs and how + /// it runs. + dtor: Option<unsafe extern "C" fn(*mut u8)>, +} + +/// A type for a safely managed OS-based TLS slot. +/// +/// This type allocates an OS TLS key when it is initialized and will deallocate +/// the key when it falls out of scope. When compared with `StaticKey`, this +/// type is entirely safe to use. +/// +/// Implementations will likely, however, contain unsafe code as this type only +/// operates on `*mut u8`, a raw pointer. +/// +/// # Examples +/// +/// ```ignore (cannot-doctest-private-modules) +/// use tls::os::Key; +/// +/// let key = Key::new(None); +/// assert!(key.get().is_null()); +/// key.set(1 as *mut u8); +/// assert!(!key.get().is_null()); +/// +/// drop(key); // deallocate this TLS slot. +/// ``` +pub struct Key { + key: imp::Key, +} + +/// Constant initialization value for static TLS keys. +/// +/// This value specifies no destructor by default. +pub const INIT: StaticKey = StaticKey::new(None); + +impl StaticKey { + #[rustc_const_unstable(feature = "thread_local_internals", issue = "none")] + pub const fn new(dtor: Option<unsafe extern "C" fn(*mut u8)>) -> StaticKey { + StaticKey { key: atomic::AtomicUsize::new(0), dtor } + } + + /// Gets the value associated with this TLS key + /// + /// This will lazily allocate a TLS key from the OS if one has not already + /// been allocated. + #[inline] + pub unsafe fn get(&self) -> *mut u8 { + imp::get(self.key()) + } + + /// Sets this TLS key to a new value. + /// + /// This will lazily allocate a TLS key from the OS if one has not already + /// been allocated. + #[inline] + pub unsafe fn set(&self, val: *mut u8) { + imp::set(self.key(), val) + } + + #[inline] + unsafe fn key(&self) -> imp::Key { + match self.key.load(Ordering::Relaxed) { + 0 => self.lazy_init() as imp::Key, + n => n as imp::Key, + } + } + + unsafe fn lazy_init(&self) -> usize { + // Currently the Windows implementation of TLS is pretty hairy, and + // it greatly simplifies creation if we just synchronize everything. + // + // Additionally a 0-index of a tls key hasn't been seen on windows, so + // we just simplify the whole branch. + if imp::requires_synchronized_create() { + // We never call `INIT_LOCK.init()`, so it is UB to attempt to + // acquire this mutex reentrantly! 
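 // What follows is double-checked initialization: `key` is re-read under
 // INIT_LOCK so two racing threads cannot both call `imp::create` and leak a
 // key. In outline (a restatement of the code below, not new behavior):
 //
 //     lock INIT_LOCK
 //     if key == 0 { key = imp::create(dtor); publish it with SeqCst }
 //     unlock when the guard drops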
+ static INIT_LOCK: StaticMutex = StaticMutex::new(); + let _guard = INIT_LOCK.lock(); + let mut key = self.key.load(Ordering::SeqCst); + if key == 0 { + key = imp::create(self.dtor) as usize; + self.key.store(key, Ordering::SeqCst); + } + rtassert!(key != 0); + return key; + } + + // POSIX allows the key created here to be 0, but the compare_exchange + // below relies on using 0 as a sentinel value to check who won the + // race to set the shared TLS key. As far as I know, there is no + // guaranteed value that cannot be returned as a posix_key_create key, + // so there is no value we can initialize the inner key with to + // prove that it has not yet been set. As such, we'll continue using a + // value of 0, but with some gyrations to make sure we have a non-0 + // value returned from the creation routine. + // FIXME: this is clearly a hack, and should be cleaned up. + let key1 = imp::create(self.dtor); + let key = if key1 != 0 { + key1 + } else { + let key2 = imp::create(self.dtor); + imp::destroy(key1); + key2 + }; + rtassert!(key != 0); + match self.key.compare_exchange(0, key as usize, Ordering::SeqCst, Ordering::SeqCst) { + // The CAS succeeded, so we've created the actual key + Ok(_) => key as usize, + // If someone beat us to the punch, use their key instead + Err(n) => { + imp::destroy(key); + n + } + } + } +} + +impl Key { + /// Creates a new managed OS TLS key. + /// + /// This key will be deallocated when the key falls out of scope. + /// + /// The argument provided is an optionally-specified destructor for the + /// value of this TLS key. When a thread exits and the value for this key + /// is non-null the destructor will be invoked. The TLS value will be reset + /// to null before the destructor is invoked. + /// + /// Note that the destructor will not be run when the `Key` goes out of + /// scope. + #[inline] + pub fn new(dtor: Option<unsafe extern "C" fn(*mut u8)>) -> Key { + Key { key: unsafe { imp::create(dtor) } } + } + + /// See StaticKey::get + #[inline] + pub fn get(&self) -> *mut u8 { + unsafe { imp::get(self.key) } + } + + /// See StaticKey::set + #[inline] + pub fn set(&self, val: *mut u8) { + unsafe { imp::set(self.key, val) } + } +} + +impl Drop for Key { + fn drop(&mut self) { + // Right now Windows doesn't support TLS key destruction, but this also + // isn't used anywhere other than tests, so just leak the TLS key. 
+ // unsafe { imp::destroy(self.key) } + } +} diff --git a/library/std/src/sys_common/thread_local_key/tests.rs b/library/std/src/sys_common/thread_local_key/tests.rs new file mode 100644 index 000000000..968738a41 --- /dev/null +++ b/library/std/src/sys_common/thread_local_key/tests.rs @@ -0,0 +1,34 @@ +use super::{Key, StaticKey}; + +fn assert_sync<T: Sync>() {} +fn assert_send<T: Send>() {} + +#[test] +fn smoke() { + assert_sync::<Key>(); + assert_send::<Key>(); + + let k1 = Key::new(None); + let k2 = Key::new(None); + assert!(k1.get().is_null()); + assert!(k2.get().is_null()); + k1.set(1 as *mut _); + k2.set(2 as *mut _); + assert_eq!(k1.get() as usize, 1); + assert_eq!(k2.get() as usize, 2); +} + +#[test] +fn statik() { + static K1: StaticKey = StaticKey::new(None); + static K2: StaticKey = StaticKey::new(None); + + unsafe { + assert!(K1.get().is_null()); + assert!(K2.get().is_null()); + K1.set(1 as *mut _); + K2.set(2 as *mut _); + assert_eq!(K1.get() as usize, 1); + assert_eq!(K2.get() as usize, 2); + } +} diff --git a/library/std/src/sys_common/thread_parker/futex.rs b/library/std/src/sys_common/thread_parker/futex.rs new file mode 100644 index 000000000..d9e2f39e3 --- /dev/null +++ b/library/std/src/sys_common/thread_parker/futex.rs @@ -0,0 +1,97 @@ +use crate::pin::Pin; +use crate::sync::atomic::AtomicU32; +use crate::sync::atomic::Ordering::{Acquire, Release}; +use crate::sys::futex::{futex_wait, futex_wake}; +use crate::time::Duration; + +const PARKED: u32 = u32::MAX; +const EMPTY: u32 = 0; +const NOTIFIED: u32 = 1; + +pub struct Parker { + state: AtomicU32, +} + +// Notes about memory ordering: +// +// Memory ordering is only relevant for the relative ordering of operations +// between different variables. Even Ordering::Relaxed guarantees a +// monotonic/consistent order when looking at just a single atomic variable. +// +// So, since this parker is just a single atomic variable, we only need to look +// at the ordering guarantees we need to provide to the 'outside world'. +// +// The only memory ordering guarantee that parking and unparking provide, is +// that things which happened before unpark() are visible on the thread +// returning from park() afterwards. Otherwise, it was effectively unparked +// before unpark() was called while still consuming the 'token'. +// +// In other words, unpark() needs to synchronize with the part of park() that +// consumes the token and returns. +// +// This is done with a release-acquire synchronization, by using +// Ordering::Release when writing NOTIFIED (the 'token') in unpark(), and using +// Ordering::Acquire when checking for this state in park(). +impl Parker { + /// Construct the futex parker. The UNIX parker implementation + /// requires this to happen in-place. + pub unsafe fn new(parker: *mut Parker) { + parker.write(Self { state: AtomicU32::new(EMPTY) }); + } + + // Assumes this is only called by the thread that owns the Parker, + // which means that `self.state != PARKED`. + pub unsafe fn park(self: Pin<&Self>) { + // Change NOTIFIED=>EMPTY or EMPTY=>PARKED, and directly return in the + // first case. + if self.state.fetch_sub(1, Acquire) == NOTIFIED { + return; + } + loop { + // Wait for something to happen, assuming it's still set to PARKED. + futex_wait(&self.state, PARKED, None); + // Change NOTIFIED=>EMPTY and return in that case. + if self.state.compare_exchange(NOTIFIED, EMPTY, Acquire, Acquire).is_ok() { + return; + } else { + // Spurious wake up. We loop to try again. 
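 // Why the single `fetch_sub(1, Acquire)` above covers both state changes
 // (a note derived from the constants at the top of this file): NOTIFIED is
 // 1 and EMPTY is 0, so subtracting one maps NOTIFIED to EMPTY, while EMPTY
 // wraps the u32 around to u32::MAX, which is exactly PARKED.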
+ } + } + } + + // Assumes this is only called by the thread that owns the Parker, + // which means that `self.state != PARKED`. This implementation doesn't + // require `Pin`, but other implementations do. + pub unsafe fn park_timeout(self: Pin<&Self>, timeout: Duration) { + // Change NOTIFIED=>EMPTY or EMPTY=>PARKED, and directly return in the + // first case. + if self.state.fetch_sub(1, Acquire) == NOTIFIED { + return; + } + // Wait for something to happen, assuming it's still set to PARKED. + futex_wait(&self.state, PARKED, Some(timeout)); + // This is not just a store, because we need to establish a + // release-acquire ordering with unpark(). + if self.state.swap(EMPTY, Acquire) == NOTIFIED { + // Woke up because of unpark(). + } else { + // Timeout or spurious wake up. + // We return either way, because we can't easily tell if it was the + // timeout or not. + } + } + + // This implementation doesn't require `Pin`, but other implementations do. + #[inline] + pub fn unpark(self: Pin<&Self>) { + // Change PARKED=>NOTIFIED, EMPTY=>NOTIFIED, or NOTIFIED=>NOTIFIED, and + // wake the thread in the first case. + // + // Note that even NOTIFIED=>NOTIFIED results in a write. This is on + // purpose, to make sure every unpark() has a release-acquire ordering + // with park(). + if self.state.swap(NOTIFIED, Release) == PARKED { + futex_wake(&self.state); + } + } +} diff --git a/library/std/src/sys_common/thread_parker/generic.rs b/library/std/src/sys_common/thread_parker/generic.rs new file mode 100644 index 000000000..f3d8b34d3 --- /dev/null +++ b/library/std/src/sys_common/thread_parker/generic.rs @@ -0,0 +1,125 @@ +//! Parker implementation based on a Mutex and Condvar. + +use crate::pin::Pin; +use crate::sync::atomic::AtomicUsize; +use crate::sync::atomic::Ordering::SeqCst; +use crate::sync::{Condvar, Mutex}; +use crate::time::Duration; + +const EMPTY: usize = 0; +const PARKED: usize = 1; +const NOTIFIED: usize = 2; + +pub struct Parker { + state: AtomicUsize, + lock: Mutex<()>, + cvar: Condvar, +} + +impl Parker { + /// Construct the generic parker. The UNIX parker implementation + /// requires this to happen in-place. + pub unsafe fn new(parker: *mut Parker) { + parker.write(Parker { + state: AtomicUsize::new(EMPTY), + lock: Mutex::new(()), + cvar: Condvar::new(), + }); + } + + // This implementation doesn't require `unsafe` and `Pin`, but other implementations do. + pub unsafe fn park(self: Pin<&Self>) { + // If we were previously notified then we consume this notification and + // return quickly. + if self.state.compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst).is_ok() { + return; + } + + // Otherwise we need to coordinate going to sleep + let mut m = self.lock.lock().unwrap(); + match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read here, even though we know it will be `NOTIFIED`. + // This is because `unpark` may have been called again since we read + // `NOTIFIED` in the `compare_exchange` above. We must perform an + // acquire operation that synchronizes with that `unpark` to observe + // any writes it made before the call to unpark. To do that we must + // read from the write it made to `state`. + let old = self.state.swap(EMPTY, SeqCst); + assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + return; + } // should consume this notification, so prohibit spurious wakeups in next park. 
+ Err(_) => panic!("inconsistent park state"), + } + loop { + m = self.cvar.wait(m).unwrap(); + match self.state.compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) { + Ok(_) => return, // got a notification + Err(_) => {} // spurious wakeup, go back to sleep + } + } + } + + // This implementation doesn't require `unsafe` and `Pin`, but other implementations do. + pub unsafe fn park_timeout(self: Pin<&Self>, dur: Duration) { + // Like `park` above we have a fast path for an already-notified thread, and + // afterwards we start coordinating for a sleep. + // return quickly. + if self.state.compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst).is_ok() { + return; + } + let m = self.lock.lock().unwrap(); + match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { + Ok(_) => {} + Err(NOTIFIED) => { + // We must read again here, see `park`. + let old = self.state.swap(EMPTY, SeqCst); + assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + return; + } // should consume this notification, so prohibit spurious wakeups in next park. + Err(_) => panic!("inconsistent park_timeout state"), + } + + // Wait with a timeout, and if we spuriously wake up or otherwise wake up + // from a notification we just want to unconditionally set the state back to + // empty, either consuming a notification or un-flagging ourselves as + // parked. + let (_m, _result) = self.cvar.wait_timeout(m, dur).unwrap(); + match self.state.swap(EMPTY, SeqCst) { + NOTIFIED => {} // got a notification, hurray! + PARKED => {} // no notification, alas + n => panic!("inconsistent park_timeout state: {n}"), + } + } + + // This implementation doesn't require `Pin`, but other implementations do. + pub fn unpark(self: Pin<&Self>) { + // To ensure the unparked thread will observe any writes we made + // before this call, we must perform a release operation that `park` + // can synchronize with. To do that we must write `NOTIFIED` even if + // `state` is already `NOTIFIED`. That is why this must be a swap + // rather than a compare-and-swap that returns if it reads `NOTIFIED` + // on failure. + match self.state.swap(NOTIFIED, SeqCst) { + EMPTY => return, // no one was waiting + NOTIFIED => return, // already unparked + PARKED => {} // gotta go wake someone up + _ => panic!("inconsistent state in unpark"), + } + + // There is a period between when the parked thread sets `state` to + // `PARKED` (or last checked `state` in the case of a spurious wake + // up) and when it actually waits on `cvar`. If we were to notify + // during this period it would be ignored and then when the parked + // thread went to sleep it would never wake up. Fortunately, it has + // `lock` locked at this stage so we can acquire `lock` to wait until + // it is ready to receive the notification. + // + // Releasing `lock` before the call to `notify_one` means that when the + // parked thread wakes it doesn't get woken only to have to wait for us + // to release `lock`. + drop(self.lock.lock().unwrap()); + self.cvar.notify_one() + } +} diff --git a/library/std/src/sys_common/thread_parker/mod.rs b/library/std/src/sys_common/thread_parker/mod.rs new file mode 100644 index 000000000..cbd7832eb --- /dev/null +++ b/library/std/src/sys_common/thread_parker/mod.rs @@ -0,0 +1,22 @@ +cfg_if::cfg_if! 
{ + if #[cfg(any( + target_os = "linux", + target_os = "android", + all(target_arch = "wasm32", target_feature = "atomics"), + target_os = "freebsd", + target_os = "openbsd", + target_os = "dragonfly", + target_os = "fuchsia", + ))] { + mod futex; + pub use futex::Parker; + } else if #[cfg(target_os = "solid_asp3")] { + mod wait_flag; + pub use wait_flag::Parker; + } else if #[cfg(any(windows, target_family = "unix"))] { + pub use crate::sys::thread_parker::Parker; + } else { + mod generic; + pub use generic::Parker; + } +} diff --git a/library/std/src/sys_common/thread_parker/wait_flag.rs b/library/std/src/sys_common/thread_parker/wait_flag.rs new file mode 100644 index 000000000..6561c1866 --- /dev/null +++ b/library/std/src/sys_common/thread_parker/wait_flag.rs @@ -0,0 +1,102 @@ +//! A wait-flag-based thread parker. +//! +//! Some operating systems provide low-level parking primitives like wait counts, +//! event flags or semaphores which are not susceptible to race conditions (meaning +//! the wakeup can occur before the wait operation). To implement the `std` thread +//! parker on top of these primitives, we only have to ensure that parking is fast +//! when the thread token is available, the atomic ordering guarantees are maintained +//! and spurious wakeups are minimized. +//! +//! To achieve this, this parker uses an atomic variable with three states: `EMPTY`, +//! `PARKED` and `NOTIFIED`: +//! * `EMPTY` means the token has not been made available, but the thread is not +//! currently waiting on it. +//! * `PARKED` means the token is not available and the thread is parked. +//! * `NOTIFIED` means the token is available. +//! +//! `park` and `park_timeout` change the state from `EMPTY` to `PARKED` and from +//! `NOTIFIED` to `EMPTY`. If the state was `NOTIFIED`, the thread was unparked and +//! execution can continue without calling into the OS. If the state was `EMPTY`, +//! the token is not available and the thread waits on the primitive (here called +//! "wait flag"). +//! +//! `unpark` changes the state to `NOTIFIED`. If the state was `PARKED`, the thread +//! is or will be sleeping on the wait flag, so we raise it. + +use crate::pin::Pin; +use crate::sync::atomic::AtomicI8; +use crate::sync::atomic::Ordering::{Acquire, Relaxed, Release}; +use crate::sys::wait_flag::WaitFlag; +use crate::time::Duration; + +const EMPTY: i8 = 0; +const PARKED: i8 = -1; +const NOTIFIED: i8 = 1; + +pub struct Parker { + state: AtomicI8, + wait_flag: WaitFlag, +} + +impl Parker { + /// Construct a parker for the current thread. The UNIX parker + /// implementation requires this to happen in-place. + pub unsafe fn new(parker: *mut Parker) { + parker.write(Parker { state: AtomicI8::new(EMPTY), wait_flag: WaitFlag::new() }) + } + + // This implementation doesn't require `unsafe` and `Pin`, but other implementations do. + pub unsafe fn park(self: Pin<&Self>) { + match self.state.fetch_sub(1, Acquire) { + // NOTIFIED => EMPTY + NOTIFIED => return, + // EMPTY => PARKED + EMPTY => (), + _ => panic!("inconsistent park state"), + } + + // Avoid waking up from spurious wakeups (these are quite likely, see below). + loop { + self.wait_flag.wait(); + + match self.state.compare_exchange(NOTIFIED, EMPTY, Acquire, Relaxed) { + Ok(_) => return, + Err(PARKED) => (), + Err(_) => panic!("inconsistent park state"), + } + } + } + + // This implementation doesn't require `unsafe` and `Pin`, but other implementations do. 
+ pub unsafe fn park_timeout(self: Pin<&Self>, dur: Duration) { + match self.state.fetch_sub(1, Acquire) { + NOTIFIED => return, + EMPTY => (), + _ => panic!("inconsistent park state"), + } + + self.wait_flag.wait_timeout(dur); + + // Either a wakeup or a timeout occurred. Wakeups may be spurious, as there can be + // a race condition when `unpark` is performed between receiving the timeout and + // resetting the state, resulting in the eventflag being set unnecessarily. `park` + // is protected against this by looping until the token is actually given, but + // here we cannot easily tell. + + // Use `swap` to provide acquire ordering. + match self.state.swap(EMPTY, Acquire) { + NOTIFIED => (), + PARKED => (), + _ => panic!("inconsistent park state"), + } + } + + // This implementation doesn't require `Pin`, but other implementations do. + pub fn unpark(self: Pin<&Self>) { + let state = self.state.swap(NOTIFIED, Release); + + if state == PARKED { + self.wait_flag.raise(); + } + } +} diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs new file mode 100644 index 000000000..57fa49893 --- /dev/null +++ b/library/std/src/sys_common/wtf8.rs @@ -0,0 +1,926 @@ +//! Implementation of [the WTF-8 encoding](https://simonsapin.github.io/wtf-8/). +//! +//! This library uses Rust’s type system to maintain +//! [well-formedness](https://simonsapin.github.io/wtf-8/#well-formed), +//! like the `String` and `&str` types do for UTF-8. +//! +//! Since [WTF-8 must not be used +//! for interchange](https://simonsapin.github.io/wtf-8/#intended-audience), +//! this library deliberately does not provide access to the underlying bytes +//! of WTF-8 strings, +//! nor can it decode WTF-8 from arbitrary bytes. +//! WTF-8 strings can be obtained from UTF-8, UTF-16, or code points. + +// this module is imported from @SimonSapin's repo and has tons of dead code on +// unix (it's mostly used on windows), so don't worry about dead code here. +#![allow(dead_code)] + +#[cfg(test)] +mod tests; + +use core::str::next_code_point; + +use crate::borrow::Cow; +use crate::char; +use crate::collections::TryReserveError; +use crate::fmt; +use crate::hash::{Hash, Hasher}; +use crate::iter::FusedIterator; +use crate::mem; +use crate::ops; +use crate::rc::Rc; +use crate::slice; +use crate::str; +use crate::sync::Arc; +use crate::sys_common::AsInner; + +const UTF8_REPLACEMENT_CHARACTER: &str = "\u{FFFD}"; + +/// A Unicode code point: from U+0000 to U+10FFFF. +/// +/// Compares with the `char` type, +/// which represents a Unicode scalar value: +/// a code point that is not a surrogate (U+D800 to U+DFFF). +#[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +pub struct CodePoint { + value: u32, +} + +/// Format the code point as `U+` followed by four to six hexadecimal digits. +/// Example: `U+1F4A9` +impl fmt::Debug for CodePoint { + #[inline] + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "U+{:04X}", self.value) + } +} + +impl CodePoint { + /// Unsafely creates a new `CodePoint` without checking the value. + /// + /// Only use when `value` is known to be less than or equal to 0x10FFFF. + #[inline] + pub unsafe fn from_u32_unchecked(value: u32) -> CodePoint { + CodePoint { value } + } + + /// Creates a new `CodePoint` if the value is a valid code point. + /// + /// Returns `None` if `value` is above 0x10FFFF. 
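// A quick illustration of the `CodePoint`/`char` distinction (a sketch
// mirroring the unit tests at the end of this file):
//
//     let lone = CodePoint::from_u32(0xD800).unwrap();  // surrogate: a valid CodePoint
//     assert_eq!(lone.to_char(), None);                 // ...but not a valid char
//     assert_eq!(lone.to_char_lossy(), '\u{FFFD}');     // lossy conversion substitutes U+FFFD
//     assert!(CodePoint::from_u32(0x110000).is_none()); // beyond U+10FFFF is rejected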
+ #[inline] + pub fn from_u32(value: u32) -> Option<CodePoint> { + match value { + 0..=0x10FFFF => Some(CodePoint { value }), + _ => None, + } + } + + /// Creates a new `CodePoint` from a `char`. + /// + /// Since all Unicode scalar values are code points, this always succeeds. + #[inline] + pub fn from_char(value: char) -> CodePoint { + CodePoint { value: value as u32 } + } + + /// Returns the numeric value of the code point. + #[inline] + pub fn to_u32(&self) -> u32 { + self.value + } + + /// Optionally returns a Unicode scalar value for the code point. + /// + /// Returns `None` if the code point is a surrogate (from U+D800 to U+DFFF). + #[inline] + pub fn to_char(&self) -> Option<char> { + match self.value { + 0xD800..=0xDFFF => None, + _ => Some(unsafe { char::from_u32_unchecked(self.value) }), + } + } + + /// Returns a Unicode scalar value for the code point. + /// + /// Returns `'\u{FFFD}'` (the replacement character “�”) + /// if the code point is a surrogate (from U+D800 to U+DFFF). + #[inline] + pub fn to_char_lossy(&self) -> char { + self.to_char().unwrap_or('\u{FFFD}') + } +} + +/// An owned, growable string of well-formed WTF-8 data. +/// +/// Similar to `String`, but can additionally contain surrogate code points +/// if they’re not in a surrogate pair. +#[derive(Eq, PartialEq, Ord, PartialOrd, Clone)] +pub struct Wtf8Buf { + bytes: Vec<u8>, +} + +impl ops::Deref for Wtf8Buf { + type Target = Wtf8; + + fn deref(&self) -> &Wtf8 { + self.as_slice() + } +} + +impl ops::DerefMut for Wtf8Buf { + fn deref_mut(&mut self) -> &mut Wtf8 { + self.as_mut_slice() + } +} + +/// Format the string with double quotes, +/// and surrogates as `\u` followed by four hexadecimal digits. +/// Example: `"a\u{D800}"` for a string with code points [U+0061, U+D800] +impl fmt::Debug for Wtf8Buf { + #[inline] + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, formatter) + } +} + +impl Wtf8Buf { + /// Creates a new, empty WTF-8 string. + #[inline] + pub fn new() -> Wtf8Buf { + Wtf8Buf { bytes: Vec::new() } + } + + /// Creates a new, empty WTF-8 string with pre-allocated capacity for `capacity` bytes. + #[inline] + pub fn with_capacity(capacity: usize) -> Wtf8Buf { + Wtf8Buf { bytes: Vec::with_capacity(capacity) } + } + + /// Creates a WTF-8 string from a UTF-8 `String`. + /// + /// This takes ownership of the `String` and does not copy. + /// + /// Since WTF-8 is a superset of UTF-8, this always succeeds. + #[inline] + pub fn from_string(string: String) -> Wtf8Buf { + Wtf8Buf { bytes: string.into_bytes() } + } + + /// Creates a WTF-8 string from a UTF-8 `&str` slice. + /// + /// This copies the content of the slice. + /// + /// Since WTF-8 is a superset of UTF-8, this always succeeds. + #[inline] + pub fn from_str(str: &str) -> Wtf8Buf { + Wtf8Buf { bytes: <[_]>::to_vec(str.as_bytes()) } + } + + pub fn clear(&mut self) { + self.bytes.clear() + } + + /// Creates a WTF-8 string from a potentially ill-formed UTF-16 slice of 16-bit code units. + /// + /// This is lossless: calling `.encode_wide()` on the resulting string + /// will always return the original code units. + pub fn from_wide(v: &[u16]) -> Wtf8Buf { + let mut string = Wtf8Buf::with_capacity(v.len()); + for item in char::decode_utf16(v.iter().cloned()) { + match item { + Ok(ch) => string.push_char(ch), + Err(surrogate) => { + let surrogate = surrogate.unpaired_surrogate(); + // Surrogates are known to be in the code point range. 
+ let code_point = unsafe { CodePoint::from_u32_unchecked(surrogate as u32) }; + // Skip the WTF-8 concatenation check, + // surrogate pairs are already decoded by decode_utf16 + string.push_code_point_unchecked(code_point) + } + } + } + string + } + + /// Copied from String::push + /// This does **not** include the WTF-8 concatenation check. + fn push_code_point_unchecked(&mut self, code_point: CodePoint) { + let mut bytes = [0; 4]; + let bytes = char::encode_utf8_raw(code_point.value, &mut bytes); + self.bytes.extend_from_slice(bytes) + } + + #[inline] + pub fn as_slice(&self) -> &Wtf8 { + unsafe { Wtf8::from_bytes_unchecked(&self.bytes) } + } + + #[inline] + pub fn as_mut_slice(&mut self) -> &mut Wtf8 { + unsafe { Wtf8::from_mut_bytes_unchecked(&mut self.bytes) } + } + + /// Reserves capacity for at least `additional` more bytes to be inserted + /// in the given `Wtf8Buf`. + /// The collection may reserve more space to avoid frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new capacity overflows `usize`. + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.bytes.reserve(additional) + } + + /// Tries to reserve capacity for at least `additional` more length units + /// in the given `Wtf8Buf`. The `Wtf8Buf` may reserve more space to avoid + /// frequent reallocations. After calling `try_reserve`, capacity will be + /// greater than or equal to `self.len() + additional`. Does nothing if + /// capacity is already sufficient. + /// + /// # Errors + /// + /// If the capacity overflows, or the allocator reports a failure, then an error + /// is returned. + #[inline] + pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> { + self.bytes.try_reserve(additional) + } + + #[inline] + pub fn reserve_exact(&mut self, additional: usize) { + self.bytes.reserve_exact(additional) + } + + /// Tries to reserve the minimum capacity for exactly `additional` + /// length units in the given `Wtf8Buf`. After calling + /// `try_reserve_exact`, capacity will be greater than or equal to + /// `self.len() + additional` if it returns `Ok(())`. + /// Does nothing if the capacity is already sufficient. + /// + /// Note that the allocator may give the `Wtf8Buf` more space than it + /// requests. Therefore, capacity can not be relied upon to be precisely + /// minimal. Prefer [`try_reserve`] if future insertions are expected. + /// + /// [`try_reserve`]: Wtf8Buf::try_reserve + /// + /// # Errors + /// + /// If the capacity overflows, or the allocator reports a failure, then an error + /// is returned. + #[inline] + pub fn try_reserve_exact(&mut self, additional: usize) -> Result<(), TryReserveError> { + self.bytes.try_reserve_exact(additional) + } + + #[inline] + pub fn shrink_to_fit(&mut self) { + self.bytes.shrink_to_fit() + } + + #[inline] + pub fn shrink_to(&mut self, min_capacity: usize) { + self.bytes.shrink_to(min_capacity) + } + + /// Returns the number of bytes that this string buffer can hold without reallocating. + #[inline] + pub fn capacity(&self) -> usize { + self.bytes.capacity() + } + + /// Append a UTF-8 slice at the end of the string. + #[inline] + pub fn push_str(&mut self, other: &str) { + self.bytes.extend_from_slice(other.as_bytes()) + } + + /// Append a WTF-8 slice at the end of the string. + /// + /// This replaces newly paired surrogates at the boundary + /// with a supplementary code point, + /// like concatenating ill-formed UTF-16 strings effectively would. 
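// Concretely (a sketch matching the `wtf8buf_push_wtf8` tests below, where
// `w` is the tests' helper for unchecked WTF-8 bytes): a lead surrogate is
// encoded as the three bytes ED A0..AF xx and a trail as ED B0..BF xx, so
// when one string ends with a lead and the next starts with a trail, the
// pair is re-encoded as a single four-byte scalar value:
//
//     string.push_wtf8(w(b"\xED\xA0\xBD"));          // lone lead surrogate U+D83D
//     string.push_wtf8(w(b"\xED\xB2\xA9"));          // lone trail surrogate U+DCA9
//     assert_eq!(string.bytes, b"\xF0\x9F\x92\xA9"); // joined into U+1F4A9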
+ #[inline] + pub fn push_wtf8(&mut self, other: &Wtf8) { + match ((&*self).final_lead_surrogate(), other.initial_trail_surrogate()) { + // Replace newly paired surrogates by a supplementary code point. + (Some(lead), Some(trail)) => { + let len_without_lead_surrogate = self.len() - 3; + self.bytes.truncate(len_without_lead_surrogate); + let other_without_trail_surrogate = &other.bytes[3..]; + // 4 bytes for the supplementary code point + self.bytes.reserve(4 + other_without_trail_surrogate.len()); + self.push_char(decode_surrogate_pair(lead, trail)); + self.bytes.extend_from_slice(other_without_trail_surrogate); + } + _ => self.bytes.extend_from_slice(&other.bytes), + } + } + + /// Append a Unicode scalar value at the end of the string. + #[inline] + pub fn push_char(&mut self, c: char) { + self.push_code_point_unchecked(CodePoint::from_char(c)) + } + + /// Append a code point at the end of the string. + /// + /// This replaces newly paired surrogates at the boundary + /// with a supplementary code point, + /// like concatenating ill-formed UTF-16 strings effectively would. + #[inline] + pub fn push(&mut self, code_point: CodePoint) { + if let trail @ 0xDC00..=0xDFFF = code_point.to_u32() { + if let Some(lead) = (&*self).final_lead_surrogate() { + let len_without_lead_surrogate = self.len() - 3; + self.bytes.truncate(len_without_lead_surrogate); + self.push_char(decode_surrogate_pair(lead, trail as u16)); + return; + } + } + + // No newly paired surrogates at the boundary. + self.push_code_point_unchecked(code_point) + } + + /// Shortens a string to the specified length. + /// + /// # Panics + /// + /// Panics if `new_len` > current length, + /// or if `new_len` is not a code point boundary. + #[inline] + pub fn truncate(&mut self, new_len: usize) { + assert!(is_code_point_boundary(self, new_len)); + self.bytes.truncate(new_len) + } + + /// Consumes the WTF-8 string and tries to convert it to UTF-8. + /// + /// This does not copy the data. + /// + /// If the contents are not well-formed UTF-8 + /// (that is, if the string contains surrogates), + /// the original WTF-8 string is returned instead. + pub fn into_string(self) -> Result<String, Wtf8Buf> { + match self.next_surrogate(0) { + None => Ok(unsafe { String::from_utf8_unchecked(self.bytes) }), + Some(_) => Err(self), + } + } + + /// Consumes the WTF-8 string and converts it lossily to UTF-8. + /// + /// This does not copy the data (but may overwrite parts of it in place). + /// + /// Surrogates are replaced with `"\u{FFFD}"` (the replacement character “�”) + pub fn into_string_lossy(mut self) -> String { + let mut pos = 0; + loop { + match self.next_surrogate(pos) { + Some((surrogate_pos, _)) => { + pos = surrogate_pos + 3; + self.bytes[surrogate_pos..pos] + .copy_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes()); + } + None => return unsafe { String::from_utf8_unchecked(self.bytes) }, + } + } + } + + /// Converts this `Wtf8Buf` into a boxed `Wtf8`. + #[inline] + pub fn into_box(self) -> Box<Wtf8> { + unsafe { mem::transmute(self.bytes.into_boxed_slice()) } + } + + /// Converts a `Box<Wtf8>` into a `Wtf8Buf`. + pub fn from_box(boxed: Box<Wtf8>) -> Wtf8Buf { + let bytes: Box<[u8]> = unsafe { mem::transmute(boxed) }; + Wtf8Buf { bytes: bytes.into_vec() } + } +} + +/// Creates a new WTF-8 string from an iterator of code points. +/// +/// This replaces surrogate code point pairs with supplementary code points, +/// like concatenating ill-formed UTF-16 strings effectively would. 
+impl FromIterator<CodePoint> for Wtf8Buf {
+ fn from_iter<T: IntoIterator<Item = CodePoint>>(iter: T) -> Wtf8Buf {
+ let mut string = Wtf8Buf::new();
+ string.extend(iter);
+ string
+ }
+}
+
+/// Append code points from an iterator to the string.
+///
+/// This replaces surrogate code point pairs with supplementary code points,
+/// like concatenating ill-formed UTF-16 strings effectively would.
+impl Extend<CodePoint> for Wtf8Buf {
+ fn extend<T: IntoIterator<Item = CodePoint>>(&mut self, iter: T) {
+ let iterator = iter.into_iter();
+ let (low, _high) = iterator.size_hint();
+ // Lower bound of one byte per code point (ASCII only)
+ self.bytes.reserve(low);
+ iterator.for_each(move |code_point| self.push(code_point));
+ }
+
+ #[inline]
+ fn extend_one(&mut self, code_point: CodePoint) {
+ self.push(code_point);
+ }
+
+ #[inline]
+ fn extend_reserve(&mut self, additional: usize) {
+ // Lower bound of one byte per code point (ASCII only)
+ self.bytes.reserve(additional);
+ }
+}
+
+/// A borrowed slice of well-formed WTF-8 data.
+///
+/// Similar to `&str`, but can additionally contain surrogate code points
+/// if they’re not in a surrogate pair.
+#[derive(Eq, Ord, PartialEq, PartialOrd)]
+pub struct Wtf8 {
+ bytes: [u8],
+}
+
+impl AsInner<[u8]> for Wtf8 {
+ fn as_inner(&self) -> &[u8] {
+ &self.bytes
+ }
+}
+
+/// Format the slice with double quotes,
+/// and surrogates as `\u` followed by four hexadecimal digits.
+/// Example: `"a\u{D800}"` for a slice with code points [U+0061, U+D800]
+impl fmt::Debug for Wtf8 {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fn write_str_escaped(f: &mut fmt::Formatter<'_>, s: &str) -> fmt::Result {
+ use crate::fmt::Write;
+ for c in s.chars().flat_map(|c| c.escape_debug()) {
+ f.write_char(c)?
+ }
+ Ok(())
+ }
+
+ formatter.write_str("\"")?;
+ let mut pos = 0;
+ while let Some((surrogate_pos, surrogate)) = self.next_surrogate(pos) {
+ write_str_escaped(formatter, unsafe {
+ str::from_utf8_unchecked(&self.bytes[pos..surrogate_pos])
+ })?;
+ write!(formatter, "\\u{{{:x}}}", surrogate)?;
+ pos = surrogate_pos + 3;
+ }
+ write_str_escaped(formatter, unsafe { str::from_utf8_unchecked(&self.bytes[pos..]) })?;
+ formatter.write_str("\"")
+ }
+}
+
+impl fmt::Display for Wtf8 {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let wtf8_bytes = &self.bytes;
+ let mut pos = 0;
+ loop {
+ match self.next_surrogate(pos) {
+ Some((surrogate_pos, _)) => {
+ formatter.write_str(unsafe {
+ str::from_utf8_unchecked(&wtf8_bytes[pos..surrogate_pos])
+ })?;
+ formatter.write_str(UTF8_REPLACEMENT_CHARACTER)?;
+ pos = surrogate_pos + 3;
+ }
+ None => {
+ let s = unsafe { str::from_utf8_unchecked(&wtf8_bytes[pos..]) };
+ if pos == 0 { return s.fmt(formatter) } else { return formatter.write_str(s) }
+ }
+ }
+ }
+ }
+}
+
+impl Wtf8 {
+ /// Creates a WTF-8 slice from a UTF-8 `&str` slice.
+ ///
+ /// Since WTF-8 is a superset of UTF-8, this always succeeds.
+ #[inline]
+ pub fn from_str(value: &str) -> &Wtf8 {
+ unsafe { Wtf8::from_bytes_unchecked(value.as_bytes()) }
+ }
+
+ /// Creates a WTF-8 slice from a WTF-8 byte slice.
+ ///
+ /// Since the byte slice is not checked for valid WTF-8, this function is
+ /// marked unsafe.
+ #[inline]
+ unsafe fn from_bytes_unchecked(value: &[u8]) -> &Wtf8 {
+ mem::transmute(value)
+ }
+
+ /// Creates a mutable WTF-8 slice from a mutable WTF-8 byte slice.
+ ///
+ /// Since the byte slice is not checked for valid WTF-8, this function is
+ /// marked unsafe.
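// Safety rationale for the two transmutes above: as the comment at
// `slice_unchecked` near the end of this file notes, `&[u8]` and `&Wtf8`
// share the same memory layout, `Wtf8` being a newtype over `[u8]`.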
+ #[inline]
+ unsafe fn from_mut_bytes_unchecked(value: &mut [u8]) -> &mut Wtf8 {
+ mem::transmute(value)
+ }
+
+ /// Returns the length, in WTF-8 bytes.
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.bytes.len()
+ }
+
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.bytes.is_empty()
+ }
+
+ /// Returns the code point at `position` if it is in the ASCII range,
+ /// or `b'\xFF'` otherwise.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `position` is beyond the end of the string.
+ #[inline]
+ pub fn ascii_byte_at(&self, position: usize) -> u8 {
+ match self.bytes[position] {
+ ascii_byte @ 0x00..=0x7F => ascii_byte,
+ _ => 0xFF,
+ }
+ }
+
+ /// Returns an iterator for the string’s code points.
+ #[inline]
+ pub fn code_points(&self) -> Wtf8CodePoints<'_> {
+ Wtf8CodePoints { bytes: self.bytes.iter() }
+ }
+
+ /// Tries to convert the string to UTF-8 and return a `&str` slice.
+ ///
+ /// Returns `None` if the string contains surrogates.
+ ///
+ /// This does not copy the data.
+ #[inline]
+ pub fn as_str(&self) -> Option<&str> {
+ // Well-formed WTF-8 is also well-formed UTF-8
+ // if and only if it contains no surrogate.
+ match self.next_surrogate(0) {
+ None => Some(unsafe { str::from_utf8_unchecked(&self.bytes) }),
+ Some(_) => None,
+ }
+ }
+
+ /// Lossily converts the string to UTF-8.
+ /// Returns a UTF-8 `&str` slice if the contents are well-formed in UTF-8.
+ ///
+ /// Surrogates are replaced with `"\u{FFFD}"` (the replacement character “�”).
+ ///
+ /// This only copies the data if necessary (if it contains any surrogate).
+ pub fn to_string_lossy(&self) -> Cow<'_, str> {
+ let surrogate_pos = match self.next_surrogate(0) {
+ None => return Cow::Borrowed(unsafe { str::from_utf8_unchecked(&self.bytes) }),
+ Some((pos, _)) => pos,
+ };
+ let wtf8_bytes = &self.bytes;
+ let mut utf8_bytes = Vec::with_capacity(self.len());
+ utf8_bytes.extend_from_slice(&wtf8_bytes[..surrogate_pos]);
+ utf8_bytes.extend_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes());
+ let mut pos = surrogate_pos + 3;
+ loop {
+ match self.next_surrogate(pos) {
+ Some((surrogate_pos, _)) => {
+ utf8_bytes.extend_from_slice(&wtf8_bytes[pos..surrogate_pos]);
+ utf8_bytes.extend_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes());
+ pos = surrogate_pos + 3;
+ }
+ None => {
+ utf8_bytes.extend_from_slice(&wtf8_bytes[pos..]);
+ return Cow::Owned(unsafe { String::from_utf8_unchecked(utf8_bytes) });
+ }
+ }
+ }
+ }
+
+ /// Converts the WTF-8 string to potentially ill-formed UTF-16
+ /// and returns an iterator of 16-bit code units.
+ ///
+ /// This is lossless:
+ /// calling `Wtf8Buf::from_wide` on the resulting code units
+ /// would always return the original WTF-8 string.
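// Round-trip sketch (illustrative only, mirroring `Wtf8Buf::from_wide`
// defined earlier in this file):
//
//     let wtf8 = Wtf8Buf::from_wide(&[0x61, 0xD800]);  // 'a' plus a lone lead surrogate
//     let wide: Vec<u16> = wtf8.encode_wide().collect();
//     assert_eq!(wide, [0x61, 0xD800]);                // the original units come back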
+ #[inline] + pub fn encode_wide(&self) -> EncodeWide<'_> { + EncodeWide { code_points: self.code_points(), extra: 0 } + } + + #[inline] + fn next_surrogate(&self, mut pos: usize) -> Option<(usize, u16)> { + let mut iter = self.bytes[pos..].iter(); + loop { + let b = *iter.next()?; + if b < 0x80 { + pos += 1; + } else if b < 0xE0 { + iter.next(); + pos += 2; + } else if b == 0xED { + match (iter.next(), iter.next()) { + (Some(&b2), Some(&b3)) if b2 >= 0xA0 => { + return Some((pos, decode_surrogate(b2, b3))); + } + _ => pos += 3, + } + } else if b < 0xF0 { + iter.next(); + iter.next(); + pos += 3; + } else { + iter.next(); + iter.next(); + iter.next(); + pos += 4; + } + } + } + + #[inline] + fn final_lead_surrogate(&self) -> Option<u16> { + match self.bytes { + [.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)), + _ => None, + } + } + + #[inline] + fn initial_trail_surrogate(&self) -> Option<u16> { + match self.bytes { + [0xED, b2 @ 0xB0..=0xBF, b3, ..] => Some(decode_surrogate(b2, b3)), + _ => None, + } + } + + pub fn clone_into(&self, buf: &mut Wtf8Buf) { + self.bytes.clone_into(&mut buf.bytes) + } + + /// Boxes this `Wtf8`. + #[inline] + pub fn into_box(&self) -> Box<Wtf8> { + let boxed: Box<[u8]> = self.bytes.into(); + unsafe { mem::transmute(boxed) } + } + + /// Creates a boxed, empty `Wtf8`. + pub fn empty_box() -> Box<Wtf8> { + let boxed: Box<[u8]> = Default::default(); + unsafe { mem::transmute(boxed) } + } + + #[inline] + pub fn into_arc(&self) -> Arc<Wtf8> { + let arc: Arc<[u8]> = Arc::from(&self.bytes); + unsafe { Arc::from_raw(Arc::into_raw(arc) as *const Wtf8) } + } + + #[inline] + pub fn into_rc(&self) -> Rc<Wtf8> { + let rc: Rc<[u8]> = Rc::from(&self.bytes); + unsafe { Rc::from_raw(Rc::into_raw(rc) as *const Wtf8) } + } + + #[inline] + pub fn make_ascii_lowercase(&mut self) { + self.bytes.make_ascii_lowercase() + } + + #[inline] + pub fn make_ascii_uppercase(&mut self) { + self.bytes.make_ascii_uppercase() + } + + #[inline] + pub fn to_ascii_lowercase(&self) -> Wtf8Buf { + Wtf8Buf { bytes: self.bytes.to_ascii_lowercase() } + } + + #[inline] + pub fn to_ascii_uppercase(&self) -> Wtf8Buf { + Wtf8Buf { bytes: self.bytes.to_ascii_uppercase() } + } + + #[inline] + pub fn is_ascii(&self) -> bool { + self.bytes.is_ascii() + } + + #[inline] + pub fn eq_ignore_ascii_case(&self, other: &Self) -> bool { + self.bytes.eq_ignore_ascii_case(&other.bytes) + } +} + +/// Returns a slice of the given string for the byte range \[`begin`..`end`). +/// +/// # Panics +/// +/// Panics when `begin` and `end` do not point to code point boundaries, +/// or point beyond the end of the string. +impl ops::Index<ops::Range<usize>> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::Range<usize>) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if range.start <= range.end + && is_code_point_boundary(self, range.start) + && is_code_point_boundary(self, range.end) + { + unsafe { slice_unchecked(self, range.start, range.end) } + } else { + slice_error_fail(self, range.start, range.end) + } + } +} + +/// Returns a slice of the given string from byte `begin` to its end. +/// +/// # Panics +/// +/// Panics when `begin` is not at a code point boundary, +/// or is beyond the end of the string. 
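// Indexing sketch (illustrative): slicing follows the same boundary rules as
// `str`, applied to the WTF-8 bytes:
//
//     let s = Wtf8::from_str("aé");   // bytes: 61 C3 A9
//     assert_eq!(s[..1].len(), 1);    // 0 and 1 are code point boundaries
//     // `&s[..2]` would panic: index 2 falls inside the two-byte 'é'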
+impl ops::Index<ops::RangeFrom<usize>> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::RangeFrom<usize>) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if is_code_point_boundary(self, range.start) { + unsafe { slice_unchecked(self, range.start, self.len()) } + } else { + slice_error_fail(self, range.start, self.len()) + } + } +} + +/// Returns a slice of the given string from its beginning to byte `end`. +/// +/// # Panics +/// +/// Panics when `end` is not at a code point boundary, +/// or is beyond the end of the string. +impl ops::Index<ops::RangeTo<usize>> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::RangeTo<usize>) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if is_code_point_boundary(self, range.end) { + unsafe { slice_unchecked(self, 0, range.end) } + } else { + slice_error_fail(self, 0, range.end) + } + } +} + +impl ops::Index<ops::RangeFull> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, _range: ops::RangeFull) -> &Wtf8 { + self + } +} + +#[inline] +fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 { + // The first byte is assumed to be 0xED + 0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F +} + +#[inline] +fn decode_surrogate_pair(lead: u16, trail: u16) -> char { + let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32); + unsafe { char::from_u32_unchecked(code_point) } +} + +/// Copied from core::str::StrPrelude::is_char_boundary +#[inline] +pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { + if index == slice.len() { + return true; + } + match slice.bytes.get(index) { + None => false, + Some(&b) => b < 128 || b >= 192, + } +} + +/// Copied from core::str::raw::slice_unchecked +#[inline] +pub unsafe fn slice_unchecked(s: &Wtf8, begin: usize, end: usize) -> &Wtf8 { + // memory layout of a &[u8] and &Wtf8 are the same + Wtf8::from_bytes_unchecked(slice::from_raw_parts(s.bytes.as_ptr().add(begin), end - begin)) +} + +/// Copied from core::str::raw::slice_error_fail +#[inline(never)] +pub fn slice_error_fail(s: &Wtf8, begin: usize, end: usize) -> ! { + assert!(begin <= end); + panic!("index {begin} and/or {end} in `{s:?}` do not lie on character boundary"); +} + +/// Iterator for the code points of a WTF-8 string. +/// +/// Created with the method `.code_points()`. +#[derive(Clone)] +pub struct Wtf8CodePoints<'a> { + bytes: slice::Iter<'a, u8>, +} + +impl<'a> Iterator for Wtf8CodePoints<'a> { + type Item = CodePoint; + + #[inline] + fn next(&mut self) -> Option<CodePoint> { + // SAFETY: `self.bytes` has been created from a WTF-8 string + unsafe { next_code_point(&mut self.bytes).map(|c| CodePoint { value: c }) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + let len = self.bytes.len(); + (len.saturating_add(3) / 4, Some(len)) + } +} + +/// Generates a wide character sequence for potentially ill-formed UTF-16. 
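// Worked example for `decode_surrogate_pair` above (numbers only, no new
// behavior): with lead 0xD83D and trail 0xDCA9,
//
//     0x10000 + (((0xD83D - 0xD800) as u32) << 10 | (0xDCA9 - 0xDC00) as u32)
//   = 0x10000 + (0x3D << 10 | 0xA9)
//   = 0x10000 + 0xF4A9
//   = 0x1F4A9, i.e. '💩', matching the tests below.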
+#[stable(feature = "rust1", since = "1.0.0")] +#[derive(Clone)] +pub struct EncodeWide<'a> { + code_points: Wtf8CodePoints<'a>, + extra: u16, +} + +// Copied from libunicode/u_str.rs +#[stable(feature = "rust1", since = "1.0.0")] +impl<'a> Iterator for EncodeWide<'a> { + type Item = u16; + + #[inline] + fn next(&mut self) -> Option<u16> { + if self.extra != 0 { + let tmp = self.extra; + self.extra = 0; + return Some(tmp); + } + + let mut buf = [0; 2]; + self.code_points.next().map(|code_point| { + let n = char::encode_utf16_raw(code_point.value, &mut buf).len(); + if n == 2 { + self.extra = buf[1]; + } + buf[0] + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + let (low, high) = self.code_points.size_hint(); + let ext = (self.extra != 0) as usize; + // every code point gets either one u16 or two u16, + // so this iterator is between 1 or 2 times as + // long as the underlying iterator. + (low + ext, high.and_then(|n| n.checked_mul(2)).and_then(|n| n.checked_add(ext))) + } +} + +#[stable(feature = "encode_wide_fused_iterator", since = "1.62.0")] +impl FusedIterator for EncodeWide<'_> {} + +impl Hash for CodePoint { + #[inline] + fn hash<H: Hasher>(&self, state: &mut H) { + self.value.hash(state) + } +} + +impl Hash for Wtf8Buf { + #[inline] + fn hash<H: Hasher>(&self, state: &mut H) { + state.write(&self.bytes); + 0xfeu8.hash(state) + } +} + +impl Hash for Wtf8 { + #[inline] + fn hash<H: Hasher>(&self, state: &mut H) { + state.write(&self.bytes); + 0xfeu8.hash(state) + } +} diff --git a/library/std/src/sys_common/wtf8/tests.rs b/library/std/src/sys_common/wtf8/tests.rs new file mode 100644 index 000000000..931996791 --- /dev/null +++ b/library/std/src/sys_common/wtf8/tests.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::borrow::Cow; + +#[test] +fn code_point_from_u32() { + assert!(CodePoint::from_u32(0).is_some()); + assert!(CodePoint::from_u32(0xD800).is_some()); + assert!(CodePoint::from_u32(0x10FFFF).is_some()); + assert!(CodePoint::from_u32(0x110000).is_none()); +} + +#[test] +fn code_point_to_u32() { + fn c(value: u32) -> CodePoint { + CodePoint::from_u32(value).unwrap() + } + assert_eq!(c(0).to_u32(), 0); + assert_eq!(c(0xD800).to_u32(), 0xD800); + assert_eq!(c(0x10FFFF).to_u32(), 0x10FFFF); +} + +#[test] +fn code_point_from_char() { + assert_eq!(CodePoint::from_char('a').to_u32(), 0x61); + assert_eq!(CodePoint::from_char('💩').to_u32(), 0x1F4A9); +} + +#[test] +fn code_point_to_string() { + assert_eq!(format!("{:?}", CodePoint::from_char('a')), "U+0061"); + assert_eq!(format!("{:?}", CodePoint::from_char('💩')), "U+1F4A9"); +} + +#[test] +fn code_point_to_char() { + fn c(value: u32) -> CodePoint { + CodePoint::from_u32(value).unwrap() + } + assert_eq!(c(0x61).to_char(), Some('a')); + assert_eq!(c(0x1F4A9).to_char(), Some('💩')); + assert_eq!(c(0xD800).to_char(), None); +} + +#[test] +fn code_point_to_char_lossy() { + fn c(value: u32) -> CodePoint { + CodePoint::from_u32(value).unwrap() + } + assert_eq!(c(0x61).to_char_lossy(), 'a'); + assert_eq!(c(0x1F4A9).to_char_lossy(), '💩'); + assert_eq!(c(0xD800).to_char_lossy(), '\u{FFFD}'); +} + +#[test] +fn wtf8buf_new() { + assert_eq!(Wtf8Buf::new().bytes, b""); +} + +#[test] +fn wtf8buf_from_str() { + assert_eq!(Wtf8Buf::from_str("").bytes, b""); + assert_eq!(Wtf8Buf::from_str("aé 💩").bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); +} + +#[test] +fn wtf8buf_from_string() { + assert_eq!(Wtf8Buf::from_string(String::from("")).bytes, b""); + assert_eq!(Wtf8Buf::from_string(String::from("aé 💩")).bytes, b"a\xC3\xA9 
\xF0\x9F\x92\xA9"); +} + +#[test] +fn wtf8buf_from_wide() { + assert_eq!(Wtf8Buf::from_wide(&[]).bytes, b""); + assert_eq!( + Wtf8Buf::from_wide(&[0x61, 0xE9, 0x20, 0xD83D, 0xD83D, 0xDCA9]).bytes, + b"a\xC3\xA9 \xED\xA0\xBD\xF0\x9F\x92\xA9" + ); +} + +#[test] +fn wtf8buf_push_str() { + let mut string = Wtf8Buf::new(); + assert_eq!(string.bytes, b""); + string.push_str("aé 💩"); + assert_eq!(string.bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); +} + +#[test] +fn wtf8buf_push_char() { + let mut string = Wtf8Buf::from_str("aé "); + assert_eq!(string.bytes, b"a\xC3\xA9 "); + string.push_char('💩'); + assert_eq!(string.bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); +} + +#[test] +fn wtf8buf_push() { + let mut string = Wtf8Buf::from_str("aé "); + assert_eq!(string.bytes, b"a\xC3\xA9 "); + string.push(CodePoint::from_char('💩')); + assert_eq!(string.bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); + + fn c(value: u32) -> CodePoint { + CodePoint::from_u32(value).unwrap() + } + + let mut string = Wtf8Buf::new(); + string.push(c(0xD83D)); // lead + string.push(c(0xDCA9)); // trail + assert_eq!(string.bytes, b"\xF0\x9F\x92\xA9"); // Magic! + + let mut string = Wtf8Buf::new(); + string.push(c(0xD83D)); // lead + string.push(c(0x20)); // not surrogate + string.push(c(0xDCA9)); // trail + assert_eq!(string.bytes, b"\xED\xA0\xBD \xED\xB2\xA9"); + + let mut string = Wtf8Buf::new(); + string.push(c(0xD800)); // lead + string.push(c(0xDBFF)); // lead + assert_eq!(string.bytes, b"\xED\xA0\x80\xED\xAF\xBF"); + + let mut string = Wtf8Buf::new(); + string.push(c(0xD800)); // lead + string.push(c(0xE000)); // not surrogate + assert_eq!(string.bytes, b"\xED\xA0\x80\xEE\x80\x80"); + + let mut string = Wtf8Buf::new(); + string.push(c(0xD7FF)); // not surrogate + string.push(c(0xDC00)); // trail + assert_eq!(string.bytes, b"\xED\x9F\xBF\xED\xB0\x80"); + + let mut string = Wtf8Buf::new(); + string.push(c(0x61)); // not surrogate, < 3 bytes + string.push(c(0xDC00)); // trail + assert_eq!(string.bytes, b"\x61\xED\xB0\x80"); + + let mut string = Wtf8Buf::new(); + string.push(c(0xDC00)); // trail + assert_eq!(string.bytes, b"\xED\xB0\x80"); +} + +#[test] +fn wtf8buf_push_wtf8() { + let mut string = Wtf8Buf::from_str("aé"); + assert_eq!(string.bytes, b"a\xC3\xA9"); + string.push_wtf8(Wtf8::from_str(" 💩")); + assert_eq!(string.bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); + + fn w(v: &[u8]) -> &Wtf8 { + unsafe { Wtf8::from_bytes_unchecked(v) } + } + + let mut string = Wtf8Buf::new(); + string.push_wtf8(w(b"\xED\xA0\xBD")); // lead + string.push_wtf8(w(b"\xED\xB2\xA9")); // trail + assert_eq!(string.bytes, b"\xF0\x9F\x92\xA9"); // Magic! 
+#[test]
+fn wtf8buf_push_wtf8() {
+    let mut string = Wtf8Buf::from_str("aé");
+    assert_eq!(string.bytes, b"a\xC3\xA9");
+    string.push_wtf8(Wtf8::from_str(" 💩"));
+    assert_eq!(string.bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9");
+
+    fn w(v: &[u8]) -> &Wtf8 {
+        unsafe { Wtf8::from_bytes_unchecked(v) }
+    }
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\xA0\xBD")); // lead
+    string.push_wtf8(w(b"\xED\xB2\xA9")); // trail
+    assert_eq!(string.bytes, b"\xF0\x9F\x92\xA9"); // Magic!
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\xA0\xBD")); // lead
+    string.push_wtf8(w(b" ")); // not surrogate
+    string.push_wtf8(w(b"\xED\xB2\xA9")); // trail
+    assert_eq!(string.bytes, b"\xED\xA0\xBD \xED\xB2\xA9");
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\xA0\x80")); // lead
+    string.push_wtf8(w(b"\xED\xAF\xBF")); // lead
+    assert_eq!(string.bytes, b"\xED\xA0\x80\xED\xAF\xBF");
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\xA0\x80")); // lead
+    string.push_wtf8(w(b"\xEE\x80\x80")); // not surrogate
+    assert_eq!(string.bytes, b"\xED\xA0\x80\xEE\x80\x80");
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\x9F\xBF")); // not surrogate
+    string.push_wtf8(w(b"\xED\xB0\x80")); // trail
+    assert_eq!(string.bytes, b"\xED\x9F\xBF\xED\xB0\x80");
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"a")); // not surrogate, < 3 bytes
+    string.push_wtf8(w(b"\xED\xB0\x80")); // trail
+    assert_eq!(string.bytes, b"\x61\xED\xB0\x80");
+
+    let mut string = Wtf8Buf::new();
+    string.push_wtf8(w(b"\xED\xB0\x80")); // trail
+    assert_eq!(string.bytes, b"\xED\xB0\x80");
+}
+
+#[test]
+fn wtf8buf_truncate() {
+    let mut string = Wtf8Buf::from_str("aé");
+    string.truncate(1);
+    assert_eq!(string.bytes, b"a");
+}
+
+#[test]
+#[should_panic]
+fn wtf8buf_truncate_fail_code_point_boundary() {
+    let mut string = Wtf8Buf::from_str("aé");
+    string.truncate(2);
+}
+
+#[test]
+#[should_panic]
+fn wtf8buf_truncate_fail_longer() {
+    let mut string = Wtf8Buf::from_str("aé");
+    string.truncate(4);
+}
+
+#[test]
+fn wtf8buf_into_string() {
+    let mut string = Wtf8Buf::from_str("aé 💩");
+    assert_eq!(string.clone().into_string(), Ok(String::from("aé 💩")));
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    assert_eq!(string.clone().into_string(), Err(string));
+}
+
+#[test]
+fn wtf8buf_into_string_lossy() {
+    let mut string = Wtf8Buf::from_str("aé 💩");
+    assert_eq!(string.clone().into_string_lossy(), String::from("aé 💩"));
+    string.push(CodePoint::from_u32(0xD800).unwrap());
+    assert_eq!(string.clone().into_string_lossy(), String::from("aé 💩�"));
+}
+
+#[test]
+fn wtf8buf_from_iterator() {
+    fn f(values: &[u32]) -> Wtf8Buf {
+        values.iter().map(|&c| CodePoint::from_u32(c).unwrap()).collect::<Wtf8Buf>()
+    }
+    assert_eq!(f(&[0x61, 0xE9, 0x20, 0x1F4A9]).bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9");
+
+    assert_eq!(f(&[0xD83D, 0xDCA9]).bytes, b"\xF0\x9F\x92\xA9"); // Magic!
+    assert_eq!(f(&[0xD83D, 0x20, 0xDCA9]).bytes, b"\xED\xA0\xBD \xED\xB2\xA9");
+    assert_eq!(f(&[0xD800, 0xDBFF]).bytes, b"\xED\xA0\x80\xED\xAF\xBF");
+    assert_eq!(f(&[0xD800, 0xE000]).bytes, b"\xED\xA0\x80\xEE\x80\x80");
+    assert_eq!(f(&[0xD7FF, 0xDC00]).bytes, b"\xED\x9F\xBF\xED\xB0\x80");
+    assert_eq!(f(&[0x61, 0xDC00]).bytes, b"\x61\xED\xB0\x80");
+    assert_eq!(f(&[0xDC00]).bytes, b"\xED\xB0\x80");
+}
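The `into_string` and `into_string_lossy` tests above rest on the invariant that well-formed WTF-8 is valid UTF-8 exactly when it contains no 3-byte surrogate encoding (0xED followed by 0xA0..=0xBF). A sketch of that invariant using a hypothetical linear scan; the real conversion in the module is implemented differently:

// Hypothetical scan: report the first surrogate encoding in a WTF-8 buffer.
fn first_surrogate_pos(bytes: &[u8]) -> Option<usize> {
    bytes.windows(2).position(|w| w[0] == 0xED && (0xA0..=0xBF).contains(&w[1]))
}

fn main() {
    let well_formed = b"a\xC3\xA9 \xF0\x9F\x92\xA9"; // "aé 💩", no surrogates
    assert_eq!(first_surrogate_pos(well_formed), None);
    assert!(std::str::from_utf8(well_formed).is_ok());

    let ill_formed = b"a\xED\xA0\x80"; // 'a' followed by a lone U+D800
    assert_eq!(first_surrogate_pos(ill_formed), Some(1));
    assert!(std::str::from_utf8(ill_formed).is_err());
}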
+ assert_eq!(e(&[0xD83D, 0x20], &[0xDCA9]).bytes, b"\xED\xA0\xBD \xED\xB2\xA9"); + assert_eq!(e(&[0xD800], &[0xDBFF]).bytes, b"\xED\xA0\x80\xED\xAF\xBF"); + assert_eq!(e(&[0xD800], &[0xE000]).bytes, b"\xED\xA0\x80\xEE\x80\x80"); + assert_eq!(e(&[0xD7FF], &[0xDC00]).bytes, b"\xED\x9F\xBF\xED\xB0\x80"); + assert_eq!(e(&[0x61], &[0xDC00]).bytes, b"\x61\xED\xB0\x80"); + assert_eq!(e(&[], &[0xDC00]).bytes, b"\xED\xB0\x80"); +} + +#[test] +fn wtf8buf_show() { + let mut string = Wtf8Buf::from_str("a\té \u{7f}💩\r"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + assert_eq!(format!("{string:?}"), "\"a\\té \\u{7f}\u{1f4a9}\\r\\u{d800}\""); +} + +#[test] +fn wtf8buf_as_slice() { + assert_eq!(Wtf8Buf::from_str("aé").as_slice(), Wtf8::from_str("aé")); +} + +#[test] +fn wtf8buf_show_str() { + let text = "a\té 💩\r"; + let string = Wtf8Buf::from_str(text); + assert_eq!(format!("{text:?}"), format!("{string:?}")); +} + +#[test] +fn wtf8_from_str() { + assert_eq!(&Wtf8::from_str("").bytes, b""); + assert_eq!(&Wtf8::from_str("aé 💩").bytes, b"a\xC3\xA9 \xF0\x9F\x92\xA9"); +} + +#[test] +fn wtf8_len() { + assert_eq!(Wtf8::from_str("").len(), 0); + assert_eq!(Wtf8::from_str("aé 💩").len(), 8); +} + +#[test] +fn wtf8_slice() { + assert_eq!(&Wtf8::from_str("aé 💩")[1..4].bytes, b"\xC3\xA9 "); +} + +#[test] +#[should_panic] +fn wtf8_slice_not_code_point_boundary() { + let _ = &Wtf8::from_str("aé 💩")[2..4]; +} + +#[test] +fn wtf8_slice_from() { + assert_eq!(&Wtf8::from_str("aé 💩")[1..].bytes, b"\xC3\xA9 \xF0\x9F\x92\xA9"); +} + +#[test] +#[should_panic] +fn wtf8_slice_from_not_code_point_boundary() { + let _ = &Wtf8::from_str("aé 💩")[2..]; +} + +#[test] +fn wtf8_slice_to() { + assert_eq!(&Wtf8::from_str("aé 💩")[..4].bytes, b"a\xC3\xA9 "); +} + +#[test] +#[should_panic] +fn wtf8_slice_to_not_code_point_boundary() { + let _ = &Wtf8::from_str("aé 💩")[5..]; +} + +#[test] +fn wtf8_ascii_byte_at() { + let slice = Wtf8::from_str("aé 💩"); + assert_eq!(slice.ascii_byte_at(0), b'a'); + assert_eq!(slice.ascii_byte_at(1), b'\xFF'); + assert_eq!(slice.ascii_byte_at(2), b'\xFF'); + assert_eq!(slice.ascii_byte_at(3), b' '); + assert_eq!(slice.ascii_byte_at(4), b'\xFF'); +} + +#[test] +fn wtf8_code_points() { + fn c(value: u32) -> CodePoint { + CodePoint::from_u32(value).unwrap() + } + fn cp(string: &Wtf8Buf) -> Vec<Option<char>> { + string.code_points().map(|c| c.to_char()).collect::<Vec<_>>() + } + let mut string = Wtf8Buf::from_str("é "); + assert_eq!(cp(&string), [Some('é'), Some(' ')]); + string.push(c(0xD83D)); + assert_eq!(cp(&string), [Some('é'), Some(' '), None]); + string.push(c(0xDCA9)); + assert_eq!(cp(&string), [Some('é'), Some(' '), Some('💩')]); +} + +#[test] +fn wtf8_as_str() { + assert_eq!(Wtf8::from_str("").as_str(), Some("")); + assert_eq!(Wtf8::from_str("aé 💩").as_str(), Some("aé 💩")); + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + assert_eq!(string.as_str(), None); +} + +#[test] +fn wtf8_to_string_lossy() { + assert_eq!(Wtf8::from_str("").to_string_lossy(), Cow::Borrowed("")); + assert_eq!(Wtf8::from_str("aé 💩").to_string_lossy(), Cow::Borrowed("aé 💩")); + let mut string = Wtf8Buf::from_str("aé 💩"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + let expected: Cow<'_, str> = Cow::Owned(String::from("aé 💩�")); + assert_eq!(string.to_string_lossy(), expected); +} + +#[test] +fn wtf8_display() { + fn d(b: &[u8]) -> String { + (&unsafe { Wtf8::from_bytes_unchecked(b) }).to_string() + } + + assert_eq!("", d("".as_bytes())); + assert_eq!("aé 💩", d("aé 
💩".as_bytes())); + + let mut string = Wtf8Buf::from_str("aé 💩"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + assert_eq!("aé 💩�", d(string.as_inner())); +} + +#[test] +fn wtf8_encode_wide() { + let mut string = Wtf8Buf::from_str("aé "); + string.push(CodePoint::from_u32(0xD83D).unwrap()); + string.push_char('💩'); + assert_eq!( + string.encode_wide().collect::<Vec<_>>(), + vec![0x61, 0xE9, 0x20, 0xD83D, 0xD83D, 0xDCA9] + ); +} + +#[test] +fn wtf8_encode_wide_size_hint() { + let string = Wtf8Buf::from_str("\u{12345}"); + let mut iter = string.encode_wide(); + assert_eq!((1, Some(8)), iter.size_hint()); + iter.next().unwrap(); + assert_eq!((1, Some(1)), iter.size_hint()); + iter.next().unwrap(); + assert_eq!((0, Some(0)), iter.size_hint()); + assert!(iter.next().is_none()); +} |