diff options
Diffstat (limited to 'third_party/rust/rand/src')
26 files changed, 8607 insertions, 0 deletions
diff --git a/third_party/rust/rand/src/distributions/bernoulli.rs b/third_party/rust/rand/src/distributions/bernoulli.rs new file mode 100644 index 0000000000..226db79fa9 --- /dev/null +++ b/third_party/rust/rand/src/distributions/bernoulli.rs @@ -0,0 +1,219 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Bernoulli distribution. + +use crate::distributions::Distribution; +use crate::Rng; +use core::{fmt, u64}; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; +/// The Bernoulli distribution. +/// +/// This is a special case of the Binomial distribution where `n = 1`. +/// +/// # Example +/// +/// ```rust +/// use rand::distributions::{Bernoulli, Distribution}; +/// +/// let d = Bernoulli::new(0.3).unwrap(); +/// let v = d.sample(&mut rand::thread_rng()); +/// println!("{} is from a Bernoulli distribution", v); +/// ``` +/// +/// # Precision +/// +/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`), +/// so only probabilities that are multiples of 2<sup>-64</sup> can be +/// represented. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct Bernoulli { + /// Probability of success, relative to the maximal integer. + p_int: u64, +} + +// To sample from the Bernoulli distribution we use a method that compares a +// random `u64` value `v < (p * 2^64)`. +// +// If `p == 1.0`, the integer `v` to compare against can not represented as a +// `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64). +// Note that value of `p < 1.0` can never result in `u64::MAX`, because an +// `f64` only has 53 bits of precision, and the next largest value of `p` will +// result in `2^64 - 2048`. +// +// Also there is a 100% theoretical concern: if someone consistently wants to +// generate `true` using the Bernoulli distribution (i.e. by using a probability +// of `1.0`), just using `u64::MAX` is not enough. On average it would return +// false once every 2^64 iterations. Some people apparently care about this +// case. +// +// That is why we special-case `u64::MAX` to always return `true`, without using +// the RNG, and pay the performance price for all uses that *are* reasonable. +// Luckily, if `new()` and `sample` are close, the compiler can optimize out the +// extra check. +const ALWAYS_TRUE: u64 = u64::MAX; + +// This is just `2.0.powi(64)`, but written this way because it is not available +// in `no_std` mode. +const SCALE: f64 = 2.0 * (1u64 << 63) as f64; + +/// Error type returned from `Bernoulli::new`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BernoulliError { + /// `p < 0` or `p > 1`. + InvalidProbability, +} + +impl fmt::Display for BernoulliError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + BernoulliError::InvalidProbability => "p is outside [0, 1] in Bernoulli distribution", + }) + } +} + +#[cfg(feature = "std")] +impl ::std::error::Error for BernoulliError {} + +impl Bernoulli { + /// Construct a new `Bernoulli` with the given probability of success `p`. + /// + /// # Precision + /// + /// For `p = 1.0`, the resulting distribution will always generate true. + /// For `p = 0.0`, the resulting distribution will always generate false. + /// + /// This method is accurate for any input `p` in the range `[0, 1]` which is + /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of + /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.) + #[inline] + pub fn new(p: f64) -> Result<Bernoulli, BernoulliError> { + if !(0.0..1.0).contains(&p) { + if p == 1.0 { + return Ok(Bernoulli { p_int: ALWAYS_TRUE }); + } + return Err(BernoulliError::InvalidProbability); + } + Ok(Bernoulli { + p_int: (p * SCALE) as u64, + }) + } + + /// Construct a new `Bernoulli` with the probability of success of + /// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return + /// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`. + /// + /// return `true`. If `numerator == 0` it will always return `false`. + /// For `numerator > denominator` and `denominator == 0`, this returns an + /// error. Otherwise, for `numerator == denominator`, samples are always + /// true; for `numerator == 0` samples are always false. + #[inline] + pub fn from_ratio(numerator: u32, denominator: u32) -> Result<Bernoulli, BernoulliError> { + if numerator > denominator || denominator == 0 { + return Err(BernoulliError::InvalidProbability); + } + if numerator == denominator { + return Ok(Bernoulli { p_int: ALWAYS_TRUE }); + } + let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64; + Ok(Bernoulli { p_int }) + } +} + +impl Distribution<bool> for Bernoulli { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool { + // Make sure to always return true for p = 1.0. + if self.p_int == ALWAYS_TRUE { + return true; + } + let v: u64 = rng.gen(); + v < self.p_int + } +} + +#[cfg(test)] +mod test { + use super::Bernoulli; + use crate::distributions::Distribution; + use crate::Rng; + + #[test] + #[cfg(feature="serde1")] + fn test_serializing_deserializing_bernoulli() { + let coin_flip = Bernoulli::new(0.5).unwrap(); + let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); + + assert_eq!(coin_flip.p_int, de_coin_flip.p_int); + } + + #[test] + fn test_trivial() { + // We prefer to be explicit here. + #![allow(clippy::bool_assert_comparison)] + + let mut r = crate::test::rng(1); + let always_false = Bernoulli::new(0.0).unwrap(); + let always_true = Bernoulli::new(1.0).unwrap(); + for _ in 0..5 { + assert_eq!(r.sample::<bool, _>(&always_false), false); + assert_eq!(r.sample::<bool, _>(&always_true), true); + assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false); + assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_average() { + const P: f64 = 0.3; + const NUM: u32 = 3; + const DENOM: u32 = 10; + let d1 = Bernoulli::new(P).unwrap(); + let d2 = Bernoulli::from_ratio(NUM, DENOM).unwrap(); + const N: u32 = 100_000; + + let mut sum1: u32 = 0; + let mut sum2: u32 = 0; + let mut rng = crate::test::rng(2); + for _ in 0..N { + if d1.sample(&mut rng) { + sum1 += 1; + } + if d2.sample(&mut rng) { + sum2 += 1; + } + } + let avg1 = (sum1 as f64) / (N as f64); + assert!((avg1 - P).abs() < 5e-3); + + let avg2 = (sum2 as f64) / (N as f64); + assert!((avg2 - (NUM as f64) / (DENOM as f64)).abs() < 5e-3); + } + + #[test] + fn value_stability() { + let mut rng = crate::test::rng(3); + let distr = Bernoulli::new(0.4532).unwrap(); + let mut buf = [false; 10]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, [ + true, false, false, true, false, false, true, true, true, true + ]); + } + + #[test] + fn bernoulli_distributions_can_be_compared() { + assert_eq!(Bernoulli::new(1.0), Bernoulli::new(1.0)); + } +} diff --git a/third_party/rust/rand/src/distributions/distribution.rs b/third_party/rust/rand/src/distributions/distribution.rs new file mode 100644 index 0000000000..c5cf6a607b --- /dev/null +++ b/third_party/rust/rand/src/distributions/distribution.rs @@ -0,0 +1,272 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Distribution trait and associates + +use crate::Rng; +use core::iter; +#[cfg(feature = "alloc")] +use alloc::string::String; + +/// Types (distributions) that can be used to create a random instance of `T`. +/// +/// It is possible to sample from a distribution through both the +/// `Distribution` and [`Rng`] traits, via `distr.sample(&mut rng)` and +/// `rng.sample(distr)`. They also both offer the [`sample_iter`] method, which +/// produces an iterator that samples from the distribution. +/// +/// All implementations are expected to be immutable; this has the significant +/// advantage of not needing to consider thread safety, and for most +/// distributions efficient state-less sampling algorithms are available. +/// +/// Implementations are typically expected to be portable with reproducible +/// results when used with a PRNG with fixed seed; see the +/// [portability chapter](https://rust-random.github.io/book/portability.html) +/// of The Rust Rand Book. In some cases this does not apply, e.g. the `usize` +/// type requires different sampling on 32-bit and 64-bit machines. +/// +/// [`sample_iter`]: Distribution::sample_iter +pub trait Distribution<T> { + /// Generate a random value of `T`, using `rng` as the source of randomness. + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T; + + /// Create an iterator that generates random values of `T`, using `rng` as + /// the source of randomness. + /// + /// Note that this function takes `self` by value. This works since + /// `Distribution<T>` is impl'd for `&D` where `D: Distribution<T>`, + /// however borrowing is not automatic hence `distr.sample_iter(...)` may + /// need to be replaced with `(&distr).sample_iter(...)` to borrow or + /// `(&*distr).sample_iter(...)` to reborrow an existing reference. + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard}; + /// + /// let mut rng = thread_rng(); + /// + /// // Vec of 16 x f32: + /// let v: Vec<f32> = Standard.sample_iter(&mut rng).take(16).collect(); + /// + /// // String: + /// let s: String = Alphanumeric + /// .sample_iter(&mut rng) + /// .take(7) + /// .map(char::from) + /// .collect(); + /// + /// // Dice-rolling: + /// let die_range = Uniform::new_inclusive(1, 6); + /// let mut roll_die = die_range.sample_iter(&mut rng); + /// while roll_die.next().unwrap() != 6 { + /// println!("Not a 6; rolling again!"); + /// } + /// ``` + fn sample_iter<R>(self, rng: R) -> DistIter<Self, R, T> + where + R: Rng, + Self: Sized, + { + DistIter { + distr: self, + rng, + phantom: ::core::marker::PhantomData, + } + } + + /// Create a distribution of values of 'S' by mapping the output of `Self` + /// through the closure `F` + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::distributions::{Distribution, Uniform}; + /// + /// let mut rng = thread_rng(); + /// + /// let die = Uniform::new_inclusive(1, 6); + /// let even_number = die.map(|num| num % 2 == 0); + /// while !even_number.sample(&mut rng) { + /// println!("Still odd; rolling again!"); + /// } + /// ``` + fn map<F, S>(self, func: F) -> DistMap<Self, F, T, S> + where + F: Fn(T) -> S, + Self: Sized, + { + DistMap { + distr: self, + func, + phantom: ::core::marker::PhantomData, + } + } +} + +impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { + (*self).sample(rng) + } +} + +/// An iterator that generates random values of `T` with distribution `D`, +/// using `R` as the source of randomness. +/// +/// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. +/// See its documentation for more. +/// +/// [`sample_iter`]: Distribution::sample_iter +#[derive(Debug)] +pub struct DistIter<D, R, T> { + distr: D, + rng: R, + phantom: ::core::marker::PhantomData<T>, +} + +impl<D, R, T> Iterator for DistIter<D, R, T> +where + D: Distribution<T>, + R: Rng, +{ + type Item = T; + + #[inline(always)] + fn next(&mut self) -> Option<T> { + // Here, self.rng may be a reference, but we must take &mut anyway. + // Even if sample could take an R: Rng by value, we would need to do this + // since Rng is not copyable and we cannot enforce that this is "reborrowable". + Some(self.distr.sample(&mut self.rng)) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + (usize::max_value(), None) + } +} + +impl<D, R, T> iter::FusedIterator for DistIter<D, R, T> +where + D: Distribution<T>, + R: Rng, +{ +} + +#[cfg(features = "nightly")] +impl<D, R, T> iter::TrustedLen for DistIter<D, R, T> +where + D: Distribution<T>, + R: Rng, +{ +} + +/// A distribution of values of type `S` derived from the distribution `D` +/// by mapping its output of type `T` through the closure `F`. +/// +/// This `struct` is created by the [`Distribution::map`] method. +/// See its documentation for more. +#[derive(Debug)] +pub struct DistMap<D, F, T, S> { + distr: D, + func: F, + phantom: ::core::marker::PhantomData<fn(T) -> S>, +} + +impl<D, F, T, S> Distribution<S> for DistMap<D, F, T, S> +where + D: Distribution<T>, + F: Fn(T) -> S, +{ + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S { + (self.func)(self.distr.sample(rng)) + } +} + +/// `String` sampler +/// +/// Sampling a `String` of random characters is not quite the same as collecting +/// a sequence of chars. This trait contains some helpers. +#[cfg(feature = "alloc")] +pub trait DistString { + /// Append `len` random chars to `string` + fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize); + + /// Generate a `String` of `len` random chars + #[inline] + fn sample_string<R: Rng + ?Sized>(&self, rng: &mut R, len: usize) -> String { + let mut s = String::new(); + self.append_string(rng, &mut s, len); + s + } +} + +#[cfg(test)] +mod tests { + use crate::distributions::{Distribution, Uniform}; + use crate::Rng; + + #[test] + fn test_distributions_iter() { + use crate::distributions::Open01; + let mut rng = crate::test::rng(210); + let distr = Open01; + let mut iter = Distribution::<f32>::sample_iter(distr, &mut rng); + let mut sum: f32 = 0.; + for _ in 0..100 { + sum += iter.next().unwrap(); + } + assert!(0. < sum && sum < 100.); + } + + #[test] + fn test_distributions_map() { + let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15); + + let mut rng = crate::test::rng(212); + let val = dist.sample(&mut rng); + assert!((15..=20).contains(&val)); + } + + #[test] + fn test_make_an_iter() { + fn ten_dice_rolls_other_than_five<R: Rng>( + rng: &mut R, + ) -> impl Iterator<Item = i32> + '_ { + Uniform::new_inclusive(1, 6) + .sample_iter(rng) + .filter(|x| *x != 5) + .take(10) + } + + let mut rng = crate::test::rng(211); + let mut count = 0; + for val in ten_dice_rolls_other_than_five(&mut rng) { + assert!((1..=6).contains(&val) && val != 5); + count += 1; + } + assert_eq!(count, 10); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_dist_string() { + use core::str; + use crate::distributions::{Alphanumeric, DistString, Standard}; + let mut rng = crate::test::rng(213); + + let s1 = Alphanumeric.sample_string(&mut rng, 20); + assert_eq!(s1.len(), 20); + assert_eq!(str::from_utf8(s1.as_bytes()), Ok(s1.as_str())); + + let s2 = Standard.sample_string(&mut rng, 20); + assert_eq!(s2.chars().count(), 20); + assert_eq!(str::from_utf8(s2.as_bytes()), Ok(s2.as_str())); + } +} diff --git a/third_party/rust/rand/src/distributions/float.rs b/third_party/rust/rand/src/distributions/float.rs new file mode 100644 index 0000000000..ce5946f7f0 --- /dev/null +++ b/third_party/rust/rand/src/distributions/float.rs @@ -0,0 +1,312 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Basic floating-point number distributions + +use crate::distributions::utils::FloatSIMDUtils; +use crate::distributions::{Distribution, Standard}; +use crate::Rng; +use core::mem; +#[cfg(feature = "simd_support")] use packed_simd::*; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; + +/// A distribution to sample floating point numbers uniformly in the half-open +/// interval `(0, 1]`, i.e. including 1 but not 0. +/// +/// All values that can be generated are of the form `n * ฮต/2`. For `f32` +/// the 24 most significant random bits of a `u32` are used and for `f64` the +/// 53 most significant bits of a `u64` are used. The conversion uses the +/// multiplicative method. +/// +/// See also: [`Standard`] which samples from `[0, 1)`, [`Open01`] +/// which samples from `(0, 1)` and [`Uniform`] which samples from arbitrary +/// ranges. +/// +/// # Example +/// ``` +/// use rand::{thread_rng, Rng}; +/// use rand::distributions::OpenClosed01; +/// +/// let val: f32 = thread_rng().sample(OpenClosed01); +/// println!("f32 from (0, 1): {}", val); +/// ``` +/// +/// [`Standard`]: crate::distributions::Standard +/// [`Open01`]: crate::distributions::Open01 +/// [`Uniform`]: crate::distributions::uniform::Uniform +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct OpenClosed01; + +/// A distribution to sample floating point numbers uniformly in the open +/// interval `(0, 1)`, i.e. not including either endpoint. +/// +/// All values that can be generated are of the form `n * ฮต + ฮต/2`. For `f32` +/// the 23 most significant random bits of an `u32` are used, for `f64` 52 from +/// an `u64`. The conversion uses a transmute-based method. +/// +/// See also: [`Standard`] which samples from `[0, 1)`, [`OpenClosed01`] +/// which samples from `(0, 1]` and [`Uniform`] which samples from arbitrary +/// ranges. +/// +/// # Example +/// ``` +/// use rand::{thread_rng, Rng}; +/// use rand::distributions::Open01; +/// +/// let val: f32 = thread_rng().sample(Open01); +/// println!("f32 from (0, 1): {}", val); +/// ``` +/// +/// [`Standard`]: crate::distributions::Standard +/// [`OpenClosed01`]: crate::distributions::OpenClosed01 +/// [`Uniform`]: crate::distributions::uniform::Uniform +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct Open01; + + +// This trait is needed by both this lib and rand_distr hence is a hidden export +#[doc(hidden)] +pub trait IntoFloat { + type F; + + /// Helper method to combine the fraction and a constant exponent into a + /// float. + /// + /// Only the least significant bits of `self` may be set, 23 for `f32` and + /// 52 for `f64`. + /// The resulting value will fall in a range that depends on the exponent. + /// As an example the range with exponent 0 will be + /// [2<sup>0</sup>..2<sup>1</sup>), which is [1..2). + fn into_float_with_exponent(self, exponent: i32) -> Self::F; +} + +macro_rules! float_impls { + ($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, + $fraction_bits:expr, $exponent_bias:expr) => { + impl IntoFloat for $uty { + type F = $ty; + #[inline(always)] + fn into_float_with_exponent(self, exponent: i32) -> $ty { + // The exponent is encoded using an offset-binary representation + let exponent_bits: $u_scalar = + (($exponent_bias + exponent) as $u_scalar) << $fraction_bits; + $ty::from_bits(self | exponent_bits) + } + } + + impl Distribution<$ty> for Standard { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + // Multiply-based method; 24/53 random bits; [0, 1) interval. + // We use the most significant bits because for simple RNGs + // those are usually more random. + let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let precision = $fraction_bits + 1; + let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); + + let value: $uty = rng.gen(); + let value = value >> (float_size - precision); + scale * $ty::cast_from_int(value) + } + } + + impl Distribution<$ty> for OpenClosed01 { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + // Multiply-based method; 24/53 random bits; (0, 1] interval. + // We use the most significant bits because for simple RNGs + // those are usually more random. + let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let precision = $fraction_bits + 1; + let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); + + let value: $uty = rng.gen(); + let value = value >> (float_size - precision); + // Add 1 to shift up; will not overflow because of right-shift: + scale * $ty::cast_from_int(value + 1) + } + } + + impl Distribution<$ty> for Open01 { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + // Transmute-based method; 23/52 random bits; (0, 1) interval. + // We use the most significant bits because for simple RNGs + // those are usually more random. + use core::$f_scalar::EPSILON; + let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + + let value: $uty = rng.gen(); + let fraction = value >> (float_size - $fraction_bits); + fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0) + } + } + } +} + +float_impls! { f32, u32, f32, u32, 23, 127 } +float_impls! { f64, u64, f64, u64, 52, 1023 } + +#[cfg(feature = "simd_support")] +float_impls! { f32x2, u32x2, f32, u32, 23, 127 } +#[cfg(feature = "simd_support")] +float_impls! { f32x4, u32x4, f32, u32, 23, 127 } +#[cfg(feature = "simd_support")] +float_impls! { f32x8, u32x8, f32, u32, 23, 127 } +#[cfg(feature = "simd_support")] +float_impls! { f32x16, u32x16, f32, u32, 23, 127 } + +#[cfg(feature = "simd_support")] +float_impls! { f64x2, u64x2, f64, u64, 52, 1023 } +#[cfg(feature = "simd_support")] +float_impls! { f64x4, u64x4, f64, u64, 52, 1023 } +#[cfg(feature = "simd_support")] +float_impls! { f64x8, u64x8, f64, u64, 52, 1023 } + + +#[cfg(test)] +mod tests { + use super::*; + use crate::rngs::mock::StepRng; + + const EPSILON32: f32 = ::core::f32::EPSILON; + const EPSILON64: f64 = ::core::f64::EPSILON; + + macro_rules! test_f32 { + ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { + #[test] + fn $fnn() { + // Standard + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.gen::<$ty>(), $ZERO); + let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); + assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + + // OpenClosed01 + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); + assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + + // Open01 + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); + assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + } + }; + } + test_f32! { f32_edge_cases, f32, 0.0, EPSILON32 } + #[cfg(feature = "simd_support")] + test_f32! { f32x2_edge_cases, f32x2, f32x2::splat(0.0), f32x2::splat(EPSILON32) } + #[cfg(feature = "simd_support")] + test_f32! { f32x4_edge_cases, f32x4, f32x4::splat(0.0), f32x4::splat(EPSILON32) } + #[cfg(feature = "simd_support")] + test_f32! { f32x8_edge_cases, f32x8, f32x8::splat(0.0), f32x8::splat(EPSILON32) } + #[cfg(feature = "simd_support")] + test_f32! { f32x16_edge_cases, f32x16, f32x16::splat(0.0), f32x16::splat(EPSILON32) } + + macro_rules! test_f64 { + ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { + #[test] + fn $fnn() { + // Standard + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.gen::<$ty>(), $ZERO); + let mut one = StepRng::new(1 << 11, 0); + assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + + // OpenClosed01 + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + let mut one = StepRng::new(1 << 11, 0); + assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + + // Open01 + let mut zeros = StepRng::new(0, 0); + assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + let mut one = StepRng::new(1 << 12, 0); + assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + let mut max = StepRng::new(!0, 0); + assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + } + }; + } + test_f64! { f64_edge_cases, f64, 0.0, EPSILON64 } + #[cfg(feature = "simd_support")] + test_f64! { f64x2_edge_cases, f64x2, f64x2::splat(0.0), f64x2::splat(EPSILON64) } + #[cfg(feature = "simd_support")] + test_f64! { f64x4_edge_cases, f64x4, f64x4::splat(0.0), f64x4::splat(EPSILON64) } + #[cfg(feature = "simd_support")] + test_f64! { f64x8_edge_cases, f64x8, f64x8::splat(0.0), f64x8::splat(EPSILON64) } + + #[test] + fn value_stability() { + fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>( + distr: &D, zero: T, expected: &[T], + ) { + let mut rng = crate::test::rng(0x6f44f5646c2a7334); + let mut buf = [zero; 3]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + } + + test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]); + test_samples(&Standard, 0f64, &[ + 0.7346051961657583, + 0.20298547462974248, + 0.8166436635290655, + ]); + + test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]); + test_samples(&OpenClosed01, 0f64, &[ + 0.7346051961657584, + 0.2029854746297426, + 0.8166436635290656, + ]); + + test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]); + test_samples(&Open01, 0f64, &[ + 0.7346051961657584, + 0.20298547462974248, + 0.8166436635290656, + ]); + + #[cfg(feature = "simd_support")] + { + // We only test a sub-set of types here. Values are identical to + // non-SIMD types; we assume this pattern continues across all + // SIMD types. + + test_samples(&Standard, f32x2::new(0.0, 0.0), &[ + f32x2::new(0.0035963655, 0.7346052), + f32x2::new(0.09778172, 0.20298547), + f32x2::new(0.34296435, 0.81664366), + ]); + + test_samples(&Standard, f64x2::new(0.0, 0.0), &[ + f64x2::new(0.7346051961657583, 0.20298547462974248), + f64x2::new(0.8166436635290655, 0.7423708925400552), + f64x2::new(0.16387782224016323, 0.9087068770169618), + ]); + } + } +} diff --git a/third_party/rust/rand/src/distributions/integer.rs b/third_party/rust/rand/src/distributions/integer.rs new file mode 100644 index 0000000000..19ce71599c --- /dev/null +++ b/third_party/rust/rand/src/distributions/integer.rs @@ -0,0 +1,274 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `Standard` distribution for integer types. + +use crate::distributions::{Distribution, Standard}; +use crate::Rng; +#[cfg(all(target_arch = "x86", feature = "simd_support"))] +use core::arch::x86::{__m128i, __m256i}; +#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] +use core::arch::x86_64::{__m128i, __m256i}; +use core::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize, + NonZeroU128}; +#[cfg(feature = "simd_support")] use packed_simd::*; + +impl Distribution<u8> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u8 { + rng.next_u32() as u8 + } +} + +impl Distribution<u16> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u16 { + rng.next_u32() as u16 + } +} + +impl Distribution<u32> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u32 { + rng.next_u32() + } +} + +impl Distribution<u64> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { + rng.next_u64() + } +} + +impl Distribution<u128> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 { + // Use LE; we explicitly generate one value before the next. + let x = u128::from(rng.next_u64()); + let y = u128::from(rng.next_u64()); + (y << 64) | x + } +} + +impl Distribution<usize> for Standard { + #[inline] + #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { + rng.next_u32() as usize + } + + #[inline] + #[cfg(target_pointer_width = "64")] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { + rng.next_u64() as usize + } +} + +macro_rules! impl_int_from_uint { + ($ty:ty, $uty:ty) => { + impl Distribution<$ty> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + rng.gen::<$uty>() as $ty + } + } + }; +} + +impl_int_from_uint! { i8, u8 } +impl_int_from_uint! { i16, u16 } +impl_int_from_uint! { i32, u32 } +impl_int_from_uint! { i64, u64 } +impl_int_from_uint! { i128, u128 } +impl_int_from_uint! { isize, usize } + +macro_rules! impl_nzint { + ($ty:ty, $new:path) => { + impl Distribution<$ty> for Standard { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + loop { + if let Some(nz) = $new(rng.gen()) { + break nz; + } + } + } + } + }; +} + +impl_nzint!(NonZeroU8, NonZeroU8::new); +impl_nzint!(NonZeroU16, NonZeroU16::new); +impl_nzint!(NonZeroU32, NonZeroU32::new); +impl_nzint!(NonZeroU64, NonZeroU64::new); +impl_nzint!(NonZeroU128, NonZeroU128::new); +impl_nzint!(NonZeroUsize, NonZeroUsize::new); + +#[cfg(feature = "simd_support")] +macro_rules! simd_impl { + ($(($intrinsic:ident, $vec:ty),)+) => {$( + impl Distribution<$intrinsic> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $intrinsic { + $intrinsic::from_bits(rng.gen::<$vec>()) + } + } + )+}; + + ($bits:expr,) => {}; + ($bits:expr, $ty:ty, $($ty_more:ty,)*) => { + simd_impl!($bits, $($ty_more,)*); + + impl Distribution<$ty> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + let mut vec: $ty = Default::default(); + unsafe { + let ptr = &mut vec; + let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]); + rng.fill_bytes(b_ptr); + } + vec.to_le() + } + } + }; +} + +#[cfg(feature = "simd_support")] +simd_impl!(16, u8x2, i8x2,); +#[cfg(feature = "simd_support")] +simd_impl!(32, u8x4, i8x4, u16x2, i16x2,); +#[cfg(feature = "simd_support")] +simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,); +#[cfg(feature = "simd_support")] +simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,); +#[cfg(feature = "simd_support")] +simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,); +#[cfg(feature = "simd_support")] +simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,); +#[cfg(all( + feature = "simd_support", + any(target_arch = "x86", target_arch = "x86_64") +))] +simd_impl!((__m128i, u8x16), (__m256i, u8x32),); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_integers() { + let mut rng = crate::test::rng(806); + + rng.sample::<isize, _>(Standard); + rng.sample::<i8, _>(Standard); + rng.sample::<i16, _>(Standard); + rng.sample::<i32, _>(Standard); + rng.sample::<i64, _>(Standard); + rng.sample::<i128, _>(Standard); + + rng.sample::<usize, _>(Standard); + rng.sample::<u8, _>(Standard); + rng.sample::<u16, _>(Standard); + rng.sample::<u32, _>(Standard); + rng.sample::<u64, _>(Standard); + rng.sample::<u128, _>(Standard); + } + + #[test] + fn value_stability() { + fn test_samples<T: Copy + core::fmt::Debug + PartialEq>(zero: T, expected: &[T]) + where Standard: Distribution<T> { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 3]; + for x in &mut buf { + *x = rng.sample(Standard); + } + assert_eq!(&buf, expected); + } + + test_samples(0u8, &[9, 247, 111]); + test_samples(0u16, &[32265, 42999, 38255]); + test_samples(0u32, &[2220326409, 2575017975, 2018088303]); + test_samples(0u64, &[ + 11059617991457472009, + 16096616328739788143, + 1487364411147516184, + ]); + test_samples(0u128, &[ + 296930161868957086625409848350820761097, + 145644820879247630242265036535529306392, + 111087889832015897993126088499035356354, + ]); + #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] + test_samples(0usize, &[2220326409, 2575017975, 2018088303]); + #[cfg(target_pointer_width = "64")] + test_samples(0usize, &[ + 11059617991457472009, + 16096616328739788143, + 1487364411147516184, + ]); + + test_samples(0i8, &[9, -9, 111]); + // Skip further i* types: they are simple reinterpretation of u* samples + + #[cfg(feature = "simd_support")] + { + // We only test a sub-set of types here and make assumptions about the rest. + + test_samples(u8x2::default(), &[ + u8x2::new(9, 126), + u8x2::new(247, 167), + u8x2::new(111, 149), + ]); + test_samples(u8x4::default(), &[ + u8x4::new(9, 126, 87, 132), + u8x4::new(247, 167, 123, 153), + u8x4::new(111, 149, 73, 120), + ]); + test_samples(u8x8::default(), &[ + u8x8::new(9, 126, 87, 132, 247, 167, 123, 153), + u8x8::new(111, 149, 73, 120, 68, 171, 98, 223), + u8x8::new(24, 121, 1, 50, 13, 46, 164, 20), + ]); + + test_samples(i64x8::default(), &[ + i64x8::new( + -7387126082252079607, + -2350127744969763473, + 1487364411147516184, + 7895421560427121838, + 602190064936008898, + 6022086574635100741, + -5080089175222015595, + -4066367846667249123, + ), + i64x8::new( + 9180885022207963908, + 3095981199532211089, + 6586075293021332726, + 419343203796414657, + 3186951873057035255, + 5287129228749947252, + 444726432079249540, + -1587028029513790706, + ), + i64x8::new( + 6075236523189346388, + 1351763722368165432, + -6192309979959753740, + -7697775502176768592, + -4482022114172078123, + 7522501477800909500, + -1837258847956201231, + -586926753024886735, + ), + ]); + } + } +} diff --git a/third_party/rust/rand/src/distributions/mod.rs b/third_party/rust/rand/src/distributions/mod.rs new file mode 100644 index 0000000000..05ca80606b --- /dev/null +++ b/third_party/rust/rand/src/distributions/mod.rs @@ -0,0 +1,218 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating random samples from probability distributions +//! +//! This module is the home of the [`Distribution`] trait and several of its +//! implementations. It is the workhorse behind some of the convenient +//! functionality of the [`Rng`] trait, e.g. [`Rng::gen`] and of course +//! [`Rng::sample`]. +//! +//! Abstractly, a [probability distribution] describes the probability of +//! occurrence of each value in its sample space. +//! +//! More concretely, an implementation of `Distribution<T>` for type `X` is an +//! algorithm for choosing values from the sample space (a subset of `T`) +//! according to the distribution `X` represents, using an external source of +//! randomness (an RNG supplied to the `sample` function). +//! +//! A type `X` may implement `Distribution<T>` for multiple types `T`. +//! Any type implementing [`Distribution`] is stateless (i.e. immutable), +//! but it may have internal parameters set at construction time (for example, +//! [`Uniform`] allows specification of its sample space as a range within `T`). +//! +//! +//! # The `Standard` distribution +//! +//! The [`Standard`] distribution is important to mention. This is the +//! distribution used by [`Rng::gen`] and represents the "default" way to +//! produce a random value for many different types, including most primitive +//! types, tuples, arrays, and a few derived types. See the documentation of +//! [`Standard`] for more details. +//! +//! Implementing `Distribution<T>` for [`Standard`] for user types `T` makes it +//! possible to generate type `T` with [`Rng::gen`], and by extension also +//! with the [`random`] function. +//! +//! ## Random characters +//! +//! [`Alphanumeric`] is a simple distribution to sample random letters and +//! numbers of the `char` type; in contrast [`Standard`] may sample any valid +//! `char`. +//! +//! +//! # Uniform numeric ranges +//! +//! The [`Uniform`] distribution is more flexible than [`Standard`], but also +//! more specialised: it supports fewer target types, but allows the sample +//! space to be specified as an arbitrary range within its target type `T`. +//! Both [`Standard`] and [`Uniform`] are in some sense uniform distributions. +//! +//! Values may be sampled from this distribution using [`Rng::sample(Range)`] or +//! by creating a distribution object with [`Uniform::new`], +//! [`Uniform::new_inclusive`] or `From<Range>`. When the range limits are not +//! known at compile time it is typically faster to reuse an existing +//! `Uniform` object than to call [`Rng::sample(Range)`]. +//! +//! User types `T` may also implement `Distribution<T>` for [`Uniform`], +//! although this is less straightforward than for [`Standard`] (see the +//! documentation in the [`uniform`] module). Doing so enables generation of +//! values of type `T` with [`Rng::sample(Range)`]. +//! +//! ## Open and half-open ranges +//! +//! There are surprisingly many ways to uniformly generate random floats. A +//! range between 0 and 1 is standard, but the exact bounds (open vs closed) +//! and accuracy differ. In addition to the [`Standard`] distribution Rand offers +//! [`Open01`] and [`OpenClosed01`]. See "Floating point implementation" section of +//! [`Standard`] documentation for more details. +//! +//! # Non-uniform sampling +//! +//! Sampling a simple true/false outcome with a given probability has a name: +//! the [`Bernoulli`] distribution (this is used by [`Rng::gen_bool`]). +//! +//! For weighted sampling from a sequence of discrete values, use the +//! [`WeightedIndex`] distribution. +//! +//! This crate no longer includes other non-uniform distributions; instead +//! it is recommended that you use either [`rand_distr`] or [`statrs`]. +//! +//! +//! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +//! [`random`]: crate::random +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +mod bernoulli; +mod distribution; +mod float; +mod integer; +mod other; +mod slice; +mod utils; +#[cfg(feature = "alloc")] +mod weighted_index; + +#[doc(hidden)] +pub mod hidden_export { + pub use super::float::IntoFloat; // used by rand_distr +} +pub mod uniform; +#[deprecated( + since = "0.8.0", + note = "use rand::distributions::{WeightedIndex, WeightedError} instead" +)] +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub mod weighted; + +pub use self::bernoulli::{Bernoulli, BernoulliError}; +pub use self::distribution::{Distribution, DistIter, DistMap}; +#[cfg(feature = "alloc")] +pub use self::distribution::DistString; +pub use self::float::{Open01, OpenClosed01}; +pub use self::other::Alphanumeric; +pub use self::slice::Slice; +#[doc(inline)] +pub use self::uniform::Uniform; +#[cfg(feature = "alloc")] +pub use self::weighted_index::{WeightedError, WeightedIndex}; + +#[allow(unused)] +use crate::Rng; + +/// A generic random value distribution, implemented for many primitive types. +/// Usually generates values with a numerically uniform distribution, and with a +/// range appropriate to the type. +/// +/// ## Provided implementations +/// +/// Assuming the provided `Rng` is well-behaved, these implementations +/// generate values with the following ranges and distributions: +/// +/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed +/// over all values of the type. +/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all +/// code points in the range `0...0x10_FFFF`, except for the range +/// `0xD800...0xDFFF` (the surrogate code points). This includes +/// unassigned/reserved code points. +/// * `bool`: Generates `false` or `true`, each with probability 0.5. +/// * Floating point types (`f32` and `f64`): Uniformly distributed in the +/// half-open range `[0, 1)`. See notes below. +/// * Wrapping integers (`Wrapping<T>`), besides the type identical to their +/// normal integer variants. +/// +/// The `Standard` distribution also supports generation of the following +/// compound types where all component types are supported: +/// +/// * Tuples (up to 12 elements): each element is generated sequentially. +/// * Arrays (up to 32 elements): each element is generated sequentially; +/// see also [`Rng::fill`] which supports arbitrary array length for integer +/// and float types and tends to be faster for `u32` and smaller types. +/// When using `rustc` โฅ 1.51, enable the `min_const_gen` feature to support +/// arrays larger than 32 elements. +/// Note that [`Rng::fill`] and `Standard`'s array support are *not* equivalent: +/// the former is optimised for integer types (using fewer RNG calls for +/// element types smaller than the RNG word size), while the latter supports +/// any element type supported by `Standard`. +/// * `Option<T>` first generates a `bool`, and if true generates and returns +/// `Some(value)` where `value: T`, otherwise returning `None`. +/// +/// ## Custom implementations +/// +/// The [`Standard`] distribution may be implemented for user types as follows: +/// +/// ``` +/// # #![allow(dead_code)] +/// use rand::Rng; +/// use rand::distributions::{Distribution, Standard}; +/// +/// struct MyF32 { +/// x: f32, +/// } +/// +/// impl Distribution<MyF32> for Standard { +/// fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> MyF32 { +/// MyF32 { x: rng.gen() } +/// } +/// } +/// ``` +/// +/// ## Example usage +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::Standard; +/// +/// let val: f32 = StdRng::from_entropy().sample(Standard); +/// println!("f32 from [0, 1): {}", val); +/// ``` +/// +/// # Floating point implementation +/// The floating point implementations for `Standard` generate a random value in +/// the half-open interval `[0, 1)`, i.e. including 0 but not 1. +/// +/// All values that can be generated are of the form `n * ฮต/2`. For `f32` +/// the 24 most significant random bits of a `u32` are used and for `f64` the +/// 53 most significant bits of a `u64` are used. The conversion uses the +/// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ฮต/2)`. +/// +/// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which +/// samples from `(0, 1]` and `Rng::gen_range(0..1)` which also samples from +/// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit +/// less precision but may perform faster on some architectures (on modern Intel +/// CPUs all methods have approximately equal performance). +/// +/// [`Uniform`]: uniform::Uniform +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +pub struct Standard; diff --git a/third_party/rust/rand/src/distributions/other.rs b/third_party/rust/rand/src/distributions/other.rs new file mode 100644 index 0000000000..03802a76d5 --- /dev/null +++ b/third_party/rust/rand/src/distributions/other.rs @@ -0,0 +1,365 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `Standard` distribution for other built-in types. + +use core::char; +use core::num::Wrapping; +#[cfg(feature = "alloc")] +use alloc::string::String; + +use crate::distributions::{Distribution, Standard, Uniform}; +#[cfg(feature = "alloc")] +use crate::distributions::DistString; +use crate::Rng; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; +#[cfg(feature = "min_const_gen")] +use core::mem::{self, MaybeUninit}; + + +// ----- Sampling distributions ----- + +/// Sample a `u8`, uniformly distributed over ASCII letters and numbers: +/// a-z, A-Z and 0-9. +/// +/// # Example +/// +/// ``` +/// use rand::{Rng, thread_rng}; +/// use rand::distributions::Alphanumeric; +/// +/// let mut rng = thread_rng(); +/// let chars: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); +/// println!("Random chars: {}", chars); +/// ``` +/// +/// The [`DistString`] trait provides an easier method of generating +/// a random `String`, and offers more efficient allocation: +/// ``` +/// use rand::distributions::{Alphanumeric, DistString}; +/// let string = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); +/// println!("Random string: {}", string); +/// ``` +/// +/// # Passwords +/// +/// Users sometimes ask whether it is safe to use a string of random characters +/// as a password. In principle, all RNGs in Rand implementing `CryptoRng` are +/// suitable as a source of randomness for generating passwords (if they are +/// properly seeded), but it is more conservative to only use randomness +/// directly from the operating system via the `getrandom` crate, or the +/// corresponding bindings of a crypto library. +/// +/// When generating passwords or keys, it is important to consider the threat +/// model and in some cases the memorability of the password. This is out of +/// scope of the Rand project, and therefore we defer to the following +/// references: +/// +/// - [Wikipedia article on Password Strength](https://en.wikipedia.org/wiki/Password_strength) +/// - [Diceware for generating memorable passwords](https://en.wikipedia.org/wiki/Diceware) +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct Alphanumeric; + + +// ----- Implementations of distributions ----- + +impl Distribution<char> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> char { + // A valid `char` is either in the interval `[0, 0xD800)` or + // `(0xDFFF, 0x11_0000)`. All `char`s must therefore be in + // `[0, 0x11_0000)` but not in the "gap" `[0xD800, 0xDFFF]` which is + // reserved for surrogates. This is the size of that gap. + const GAP_SIZE: u32 = 0xDFFF - 0xD800 + 1; + + // Uniform::new(0, 0x11_0000 - GAP_SIZE) can also be used but it + // seemed slower. + let range = Uniform::new(GAP_SIZE, 0x11_0000); + + let mut n = range.sample(rng); + if n <= 0xDFFF { + n -= GAP_SIZE; + } + unsafe { char::from_u32_unchecked(n) } + } +} + +/// Note: the `String` is potentially left with excess capacity; optionally the +/// user may call `string.shrink_to_fit()` afterwards. +#[cfg(feature = "alloc")] +impl DistString for Standard { + fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, s: &mut String, len: usize) { + // A char is encoded with at most four bytes, thus this reservation is + // guaranteed to be sufficient. We do not shrink_to_fit afterwards so + // that repeated usage on the same `String` buffer does not reallocate. + s.reserve(4 * len); + s.extend(Distribution::<char>::sample_iter(self, rng).take(len)); + } +} + +impl Distribution<u8> for Alphanumeric { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u8 { + const RANGE: u32 = 26 + 26 + 10; + const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ + abcdefghijklmnopqrstuvwxyz\ + 0123456789"; + // We can pick from 62 characters. This is so close to a power of 2, 64, + // that we can do better than `Uniform`. Use a simple bitshift and + // rejection sampling. We do not use a bitmask, because for small RNGs + // the most significant bits are usually of higher quality. + loop { + let var = rng.next_u32() >> (32 - 6); + if var < RANGE { + return GEN_ASCII_STR_CHARSET[var as usize]; + } + } + } +} + +#[cfg(feature = "alloc")] +impl DistString for Alphanumeric { + fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) { + unsafe { + let v = string.as_mut_vec(); + v.extend(self.sample_iter(rng).take(len)); + } + } +} + +impl Distribution<bool> for Standard { + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool { + // We can compare against an arbitrary bit of an u32 to get a bool. + // Because the least significant bits of a lower quality RNG can have + // simple patterns, we compare against the most significant bit. This is + // easiest done using a sign test. + (rng.next_u32() as i32) < 0 + } +} + +macro_rules! tuple_impl { + // use variables to indicate the arity of the tuple + ($($tyvar:ident),* ) => { + // the trailing commas are for the 1 tuple + impl< $( $tyvar ),* > + Distribution<( $( $tyvar ),* , )> + for Standard + where $( Standard: Distribution<$tyvar> ),* + { + #[inline] + fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> ( $( $tyvar ),* , ) { + ( + // use the $tyvar's to get the appropriate number of + // repeats (they're not actually needed) + $( + _rng.gen::<$tyvar>() + ),* + , + ) + } + } + } +} + +impl Distribution<()> for Standard { + #[allow(clippy::unused_unit)] + #[inline] + fn sample<R: Rng + ?Sized>(&self, _: &mut R) -> () { + () + } +} +tuple_impl! {A} +tuple_impl! {A, B} +tuple_impl! {A, B, C} +tuple_impl! {A, B, C, D} +tuple_impl! {A, B, C, D, E} +tuple_impl! {A, B, C, D, E, F} +tuple_impl! {A, B, C, D, E, F, G} +tuple_impl! {A, B, C, D, E, F, G, H} +tuple_impl! {A, B, C, D, E, F, G, H, I} +tuple_impl! {A, B, C, D, E, F, G, H, I, J} +tuple_impl! {A, B, C, D, E, F, G, H, I, J, K} +tuple_impl! {A, B, C, D, E, F, G, H, I, J, K, L} + +#[cfg(feature = "min_const_gen")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "min_const_gen")))] +impl<T, const N: usize> Distribution<[T; N]> for Standard +where Standard: Distribution<T> +{ + #[inline] + fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; N] { + let mut buff: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() }; + + for elem in &mut buff { + *elem = MaybeUninit::new(_rng.gen()); + } + + unsafe { mem::transmute_copy::<_, _>(&buff) } + } +} + +#[cfg(not(feature = "min_const_gen"))] +macro_rules! array_impl { + // recursive, given at least one type parameter: + {$n:expr, $t:ident, $($ts:ident,)*} => { + array_impl!{($n - 1), $($ts,)*} + + impl<T> Distribution<[T; $n]> for Standard where Standard: Distribution<T> { + #[inline] + fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; $n] { + [_rng.gen::<$t>(), $(_rng.gen::<$ts>()),*] + } + } + }; + // empty case: + {$n:expr,} => { + impl<T> Distribution<[T; $n]> for Standard { + fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; $n] { [] } + } + }; +} + +#[cfg(not(feature = "min_const_gen"))] +array_impl! {32, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T,} + +impl<T> Distribution<Option<T>> for Standard +where Standard: Distribution<T> +{ + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<T> { + // UFCS is needed here: https://github.com/rust-lang/rust/issues/24066 + if rng.gen::<bool>() { + Some(rng.gen()) + } else { + None + } + } +} + +impl<T> Distribution<Wrapping<T>> for Standard +where Standard: Distribution<T> +{ + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Wrapping<T> { + Wrapping(rng.gen()) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::RngCore; + #[cfg(feature = "alloc")] use alloc::string::String; + + #[test] + fn test_misc() { + let rng: &mut dyn RngCore = &mut crate::test::rng(820); + + rng.sample::<char, _>(Standard); + rng.sample::<bool, _>(Standard); + } + + #[cfg(feature = "alloc")] + #[test] + fn test_chars() { + use core::iter; + let mut rng = crate::test::rng(805); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let word: String = iter::repeat(()) + .map(|()| rng.gen::<char>()) + .take(1000) + .collect(); + assert!(!word.is_empty()); + } + + #[test] + fn test_alphanumeric() { + let mut rng = crate::test::rng(806); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let mut incorrect = false; + for _ in 0..100 { + let c: char = rng.sample(Alphanumeric).into(); + incorrect |= !(('0'..='9').contains(&c) || + ('A'..='Z').contains(&c) || + ('a'..='z').contains(&c) ); + } + assert!(!incorrect); + } + + #[test] + fn value_stability() { + fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>( + distr: &D, zero: T, expected: &[T], + ) { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 5]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + } + + test_samples(&Standard, 'a', &[ + '\u{8cdac}', + '\u{a346a}', + '\u{80120}', + '\u{ed692}', + '\u{35888}', + ]); + test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); + test_samples(&Standard, false, &[true, true, false, true, false]); + test_samples(&Standard, None as Option<bool>, &[ + Some(true), + None, + Some(false), + None, + Some(false), + ]); + test_samples(&Standard, Wrapping(0i32), &[ + Wrapping(-2074640887), + Wrapping(-1719949321), + Wrapping(2018088303), + Wrapping(-547181756), + Wrapping(838957336), + ]); + + // We test only sub-sets of tuple and array impls + test_samples(&Standard, (), &[(), (), (), (), ()]); + test_samples(&Standard, (false,), &[ + (true,), + (true,), + (false,), + (true,), + (false,), + ]); + test_samples(&Standard, (false, false), &[ + (true, true), + (false, true), + (false, false), + (true, false), + (false, false), + ]); + + test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]); + test_samples(&Standard, [0u8; 3], &[ + [9, 247, 111], + [68, 24, 13], + [174, 19, 194], + [172, 69, 213], + [149, 207, 29], + ]); + } +} diff --git a/third_party/rust/rand/src/distributions/slice.rs b/third_party/rust/rand/src/distributions/slice.rs new file mode 100644 index 0000000000..3302deb2a4 --- /dev/null +++ b/third_party/rust/rand/src/distributions/slice.rs @@ -0,0 +1,117 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::distributions::{Distribution, Uniform}; + +/// A distribution to sample items uniformly from a slice. +/// +/// [`Slice::new`] constructs a distribution referencing a slice and uniformly +/// samples references from the items in the slice. It may do extra work up +/// front to make sampling of multiple values faster; if only one sample from +/// the slice is required, [`SliceRandom::choose`] can be more efficient. +/// +/// Steps are taken to avoid bias which might be present in naive +/// implementations; for example `slice[rng.gen() % slice.len()]` samples from +/// the slice, but may be more likely to select numbers in the low range than +/// other values. +/// +/// This distribution samples with replacement; each sample is independent. +/// Sampling without replacement requires state to be retained, and therefore +/// cannot be handled by a distribution; you should instead consider methods +/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`]. +/// +/// # Example +/// +/// ``` +/// use rand::Rng; +/// use rand::distributions::Slice; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let vowels_dist = Slice::new(&vowels).unwrap(); +/// let rng = rand::thread_rng(); +/// +/// // build a string of 10 vowels +/// let vowel_string: String = rng +/// .sample_iter(&vowels_dist) +/// .take(10) +/// .collect(); +/// +/// println!("{}", vowel_string); +/// assert_eq!(vowel_string.len(), 10); +/// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); +/// ``` +/// +/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose] +/// may be preferred: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let mut rng = rand::thread_rng(); +/// +/// println!("{}", vowels.choose(&mut rng).unwrap()) +/// ``` +/// +/// [`SliceRandom`]: crate::seq::SliceRandom +/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose +/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple +#[derive(Debug, Clone, Copy)] +pub struct Slice<'a, T> { + slice: &'a [T], + range: Uniform<usize>, +} + +impl<'a, T> Slice<'a, T> { + /// Create a new `Slice` instance which samples uniformly from the slice. + /// Returns `Err` if the slice is empty. + pub fn new(slice: &'a [T]) -> Result<Self, EmptySlice> { + match slice.len() { + 0 => Err(EmptySlice), + len => Ok(Self { + slice, + range: Uniform::new(0, len), + }), + } + } +} + +impl<'a, T> Distribution<&'a T> for Slice<'a, T> { + fn sample<R: crate::Rng + ?Sized>(&self, rng: &mut R) -> &'a T { + let idx = self.range.sample(rng); + + debug_assert!( + idx < self.slice.len(), + "Uniform::new(0, {}) somehow returned {}", + self.slice.len(), + idx + ); + + // Safety: at construction time, it was ensured that the slice was + // non-empty, and that the `Uniform` range produces values in range + // for the slice + unsafe { self.slice.get_unchecked(idx) } + } +} + +/// Error type indicating that a [`Slice`] distribution was improperly +/// constructed with an empty slice. +#[derive(Debug, Clone, Copy)] +pub struct EmptySlice; + +impl core::fmt::Display for EmptySlice { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "Tried to create a `distributions::Slice` with an empty slice" + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for EmptySlice {} diff --git a/third_party/rust/rand/src/distributions/uniform.rs b/third_party/rust/rand/src/distributions/uniform.rs new file mode 100644 index 0000000000..261357b245 --- /dev/null +++ b/third_party/rust/rand/src/distributions/uniform.rs @@ -0,0 +1,1658 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A distribution uniformly sampling numbers within a given range. +//! +//! [`Uniform`] is the standard distribution to sample uniformly from a range; +//! e.g. `Uniform::new_inclusive(1, 6)` can sample integers from 1 to 6, like a +//! standard die. [`Rng::gen_range`] supports any type supported by +//! [`Uniform`]. +//! +//! This distribution is provided with support for several primitive types +//! (all integer and floating-point types) as well as [`std::time::Duration`], +//! and supports extension to user-defined types via a type-specific *back-end* +//! implementation. +//! +//! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the +//! back-ends supporting sampling from primitive integer and floating-point +//! ranges as well as from [`std::time::Duration`]; these types do not normally +//! need to be used directly (unless implementing a derived back-end). +//! +//! # Example usage +//! +//! ``` +//! use rand::{Rng, thread_rng}; +//! use rand::distributions::Uniform; +//! +//! let mut rng = thread_rng(); +//! let side = Uniform::new(-10.0, 10.0); +//! +//! // sample between 1 and 10 points +//! for _ in 0..rng.gen_range(1..=10) { +//! // sample a point from the square with sides -10 - 10 in two dimensions +//! let (x, y) = (rng.sample(side), rng.sample(side)); +//! println!("Point: {}, {}", x, y); +//! } +//! ``` +//! +//! # Extending `Uniform` to support a custom type +//! +//! To extend [`Uniform`] to support your own types, write a back-end which +//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] +//! helper trait to "register" your back-end. See the `MyF32` example below. +//! +//! At a minimum, the back-end needs to store any parameters needed for sampling +//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. +//! Those methods should include an assert to check the range is valid (i.e. +//! `low < high`). The example below merely wraps another back-end. +//! +//! The `new`, `new_inclusive` and `sample_single` functions use arguments of +//! type SampleBorrow<X> in order to support passing in values by reference or +//! by value. In the implementation of these functions, you can choose to +//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose +//! to copy or clone the value, whatever is appropriate for your type. +//! +//! ``` +//! use rand::prelude::*; +//! use rand::distributions::uniform::{Uniform, SampleUniform, +//! UniformSampler, UniformFloat, SampleBorrow}; +//! +//! struct MyF32(f32); +//! +//! #[derive(Clone, Copy, Debug)] +//! struct UniformMyF32(UniformFloat<f32>); +//! +//! impl UniformSampler for UniformMyF32 { +//! type X = MyF32; +//! fn new<B1, B2>(low: B1, high: B2) -> Self +//! where B1: SampleBorrow<Self::X> + Sized, +//! B2: SampleBorrow<Self::X> + Sized +//! { +//! UniformMyF32(UniformFloat::<f32>::new(low.borrow().0, high.borrow().0)) +//! } +//! fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self +//! where B1: SampleBorrow<Self::X> + Sized, +//! B2: SampleBorrow<Self::X> + Sized +//! { +//! UniformMyF32(UniformFloat::<f32>::new_inclusive( +//! low.borrow().0, +//! high.borrow().0, +//! )) +//! } +//! fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { +//! MyF32(self.0.sample(rng)) +//! } +//! } +//! +//! impl SampleUniform for MyF32 { +//! type Sampler = UniformMyF32; +//! } +//! +//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); +//! let uniform = Uniform::new(low, high); +//! let x = uniform.sample(&mut thread_rng()); +//! ``` +//! +//! [`SampleUniform`]: crate::distributions::uniform::SampleUniform +//! [`UniformSampler`]: crate::distributions::uniform::UniformSampler +//! [`UniformInt`]: crate::distributions::uniform::UniformInt +//! [`UniformFloat`]: crate::distributions::uniform::UniformFloat +//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration +//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow + +use core::time::Duration; +use core::ops::{Range, RangeInclusive}; + +use crate::distributions::float::IntoFloat; +use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply}; +use crate::distributions::Distribution; +use crate::{Rng, RngCore}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] // rustc doesn't detect that this is actually used +use crate::distributions::utils::Float; + +#[cfg(feature = "simd_support")] use packed_simd::*; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; + +/// Sample values uniformly between two bounds. +/// +/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform +/// distribution sampling from the given range; these functions may do extra +/// work up front to make sampling of multiple values faster. If only one sample +/// from the range is required, [`Rng::gen_range`] can be more efficient. +/// +/// When sampling from a constant range, many calculations can happen at +/// compile-time and all methods should be fast; for floating-point ranges and +/// the full range of integer types this should have comparable performance to +/// the `Standard` distribution. +/// +/// Steps are taken to avoid bias which might be present in naive +/// implementations; for example `rng.gen::<u8>() % 170` samples from the range +/// `[0, 169]` but is twice as likely to select numbers less than 85 than other +/// values. Further, the implementations here give more weight to the high-bits +/// generated by the RNG than the low bits, since with some RNGs the low-bits +/// are of lower quality than the high bits. +/// +/// Implementations must sample in `[low, high)` range for +/// `Uniform::new(low, high)`, i.e., excluding `high`. In particular, care must +/// be taken to ensure that rounding never results values `< low` or `>= high`. +/// +/// # Example +/// +/// ``` +/// use rand::distributions::{Distribution, Uniform}; +/// +/// let between = Uniform::from(10..10000); +/// let mut rng = rand::thread_rng(); +/// let mut sum = 0; +/// for _ in 0..1000 { +/// sum += between.sample(&mut rng); +/// } +/// println!("{}", sum); +/// ``` +/// +/// For a single sample, [`Rng::gen_range`] may be preferred: +/// +/// ``` +/// use rand::Rng; +/// +/// let mut rng = rand::thread_rng(); +/// println!("{}", rng.gen_range(0..10)); +/// ``` +/// +/// [`new`]: Uniform::new +/// [`new_inclusive`]: Uniform::new_inclusive +/// [`Rng::gen_range`]: Rng::gen_range +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] +#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] +pub struct Uniform<X: SampleUniform>(X::Sampler); + +impl<X: SampleUniform> Uniform<X> { + /// Create a new `Uniform` instance which samples uniformly from the half + /// open range `[low, high)` (excluding `high`). Panics if `low >= high`. + pub fn new<B1, B2>(low: B1, high: B2) -> Uniform<X> + where + B1: SampleBorrow<X> + Sized, + B2: SampleBorrow<X> + Sized, + { + Uniform(X::Sampler::new(low, high)) + } + + /// Create a new `Uniform` instance which samples uniformly from the closed + /// range `[low, high]` (inclusive). Panics if `low > high`. + pub fn new_inclusive<B1, B2>(low: B1, high: B2) -> Uniform<X> + where + B1: SampleBorrow<X> + Sized, + B2: SampleBorrow<X> + Sized, + { + Uniform(X::Sampler::new_inclusive(low, high)) + } +} + +impl<X: SampleUniform> Distribution<X> for Uniform<X> { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> X { + self.0.sample(rng) + } +} + +/// Helper trait for creating objects using the correct implementation of +/// [`UniformSampler`] for the sampling type. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// [module documentation]: crate::distributions::uniform +pub trait SampleUniform: Sized { + /// The `UniformSampler` implementation supporting type `X`. + type Sampler: UniformSampler<X = Self>; +} + +/// Helper trait handling actual uniform sampling. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// Implementation of [`sample_single`] is optional, and is only useful when +/// the implementation can be faster than `Self::new(low, high).sample(rng)`. +/// +/// [module documentation]: crate::distributions::uniform +/// [`sample_single`]: UniformSampler::sample_single +pub trait UniformSampler: Sized { + /// The type sampled by this implementation. + type X; + + /// Construct self, with inclusive lower bound and exclusive upper bound + /// `[low, high)`. + /// + /// Usually users should not call this directly but instead use + /// `Uniform::new`, which asserts that `low < high` before calling this. + fn new<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized; + + /// Construct self, with inclusive bounds `[low, high]`. + /// + /// Usually users should not call this directly but instead use + /// `Uniform::new_inclusive`, which asserts that `low <= high` before + /// calling this. + fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized; + + /// Sample a value. + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X; + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and exclusive upper bound `[low, high)`. + /// + /// By default this is implemented using + /// `UniformSampler::new(low, high).sample(rng)`. However, for some types + /// more optimal implementations for single usage may be provided via this + /// method (which is the case for integers and floats). + /// Results may not be identical. + /// + /// Note that to use this method in a generic context, the type needs to be + /// retrieved via `SampleUniform::Sampler` as follows: + /// ``` + /// use rand::{thread_rng, distributions::uniform::{SampleUniform, UniformSampler}}; + /// # #[allow(unused)] + /// fn sample_from_range<T: SampleUniform>(lb: T, ub: T) -> T { + /// let mut rng = thread_rng(); + /// <T as SampleUniform>::Sampler::sample_single(lb, ub, &mut rng) + /// } + /// ``` + fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) -> Self::X + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let uniform: Self = UniformSampler::new(low, high); + uniform.sample(rng) + } + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and inclusive upper bound `[low, high]`. + /// + /// By default this is implemented using + /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for + /// some types more optimal implementations for single usage may be provided + /// via this method. + /// Results may not be identical. + fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) + -> Self::X + where B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized + { + let uniform: Self = UniformSampler::new_inclusive(low, high); + uniform.sample(rng) + } +} + +impl<X: SampleUniform> From<Range<X>> for Uniform<X> { + fn from(r: ::core::ops::Range<X>) -> Uniform<X> { + Uniform::new(r.start, r.end) + } +} + +impl<X: SampleUniform> From<RangeInclusive<X>> for Uniform<X> { + fn from(r: ::core::ops::RangeInclusive<X>) -> Uniform<X> { + Uniform::new_inclusive(r.start(), r.end()) + } +} + + +/// Helper trait similar to [`Borrow`] but implemented +/// only for SampleUniform and references to SampleUniform in +/// order to resolve ambiguity issues. +/// +/// [`Borrow`]: std::borrow::Borrow +pub trait SampleBorrow<Borrowed> { + /// Immutably borrows from an owned value. See [`Borrow::borrow`] + /// + /// [`Borrow::borrow`]: std::borrow::Borrow::borrow + fn borrow(&self) -> &Borrowed; +} +impl<Borrowed> SampleBorrow<Borrowed> for Borrowed +where Borrowed: SampleUniform +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} +impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed +where Borrowed: SampleUniform +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + *self + } +} + +/// Range that supports generating a single sample efficiently. +/// +/// Any type implementing this trait can be used to specify the sampled range +/// for `Rng::gen_range`. +pub trait SampleRange<T> { + /// Generate a sample from the given range. + fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T; + + /// Check whether the range is empty. + fn is_empty(&self) -> bool; +} + +impl<T: SampleUniform + PartialOrd> SampleRange<T> for Range<T> { + #[inline] + fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T { + T::Sampler::sample_single(self.start, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start < self.end) + } +} + +impl<T: SampleUniform + PartialOrd> SampleRange<T> for RangeInclusive<T> { + #[inline] + fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T { + T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start() <= self.end()) + } +} + + +//////////////////////////////////////////////////////////////////////////////// + +// What follows are all back-ends. + + +/// The back-end implementing [`UniformSampler`] for integer types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// For simplicity, we use the same generic struct `UniformInt<X>` for all +/// integer types `X`. This gives us only one field type, `X`; to store unsigned +/// values of this size, we take use the fact that these conversions are no-ops. +/// +/// For a closed range, the number of possible numbers we should generate is +/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of +/// our sample space, `zone`, is a multiple of `range`; other values must be +/// rejected (by replacing with a new random sample). +/// +/// As a special case, we use `range = 0` to represent the full range of the +/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). +/// +/// The optimum `zone` is the largest product of `range` which fits in our +/// (unsigned) target type. We calculate this by calculating how many numbers we +/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) +/// product of `range` will suffice, thus in `sample_single` we multiply by a +/// power of 2 via bit-shifting (faster but may cause more rejections). +/// +/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we +/// use `u32` for our `zone` and samples (because it's not slower and because +/// it reduces the chance of having to reject a sample). In this case we cannot +/// store `zone` in the target type since it is too large, however we know +/// `ints_to_reject < range <= $unsigned::MAX`. +/// +/// An alternative to using a modulus is widening multiply: After a widening +/// multiply by `range`, the result is in the high word. Then comparing the low +/// word against `zone` makes sure our distribution is uniform. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct UniformInt<X> { + low: X, + range: X, + z: X, // either ints_to_reject or zone depending on implementation +} + +macro_rules! uniform_int_impl { + ($ty:ty, $unsigned:ident, $u_large:ident) => { + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + // We play free and fast with unsigned vs signed here + // (when $ty is signed), but that's fine, since the + // contract of this macro is for $ty and $unsigned to be + // "bit-equal", so casting between them is a no-op. + + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low < high, "Uniform::new called with `low >= high`"); + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!( + low <= high, + "Uniform::new_inclusive called with `low > high`" + ); + let unsigned_max = ::core::$u_large::MAX; + + let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; + let ints_to_reject = if range > 0 { + let range = $u_large::from(range); + (unsigned_max - range + 1) % range + } else { + 0 + }; + + UniformInt { + low, + // These are really $unsigned values, but store as $ty: + range: range as $ty, + z: ints_to_reject as $unsigned as $ty, + } + } + + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + let range = self.range as $unsigned as $u_large; + if range > 0 { + let unsigned_max = ::core::$u_large::MAX; + let zone = unsigned_max - (self.z as $unsigned as $u_large); + loop { + let v: $u_large = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return self.low.wrapping_add(hi as $ty); + } + } + } else { + // Sample from the entire integer range. + rng.gen() + } + } + + #[inline] + fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Self::X + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low < high, "UniformSampler::sample_single: low >= high"); + Self::sample_single_inclusive(low, high - 1, rng) + } + + #[inline] + fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Self::X + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low <= high, "UniformSampler::sample_single_inclusive: low > high"); + let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large; + // If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX, + // and any integer will do. + if range == 0 { + return rng.gen(); + } + + let zone = if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned { + // Using a modulus is faster than the approximation for + // i8 and i16. I suppose we trade the cost of one + // modulus for near-perfect branch prediction. + let unsigned_max: $u_large = ::core::$u_large::MAX; + let ints_to_reject = (unsigned_max - range + 1) % range; + unsigned_max - ints_to_reject + } else { + // conservative but fast approximation. `- 1` is necessary to allow the + // same comparison without bias. + (range << range.leading_zeros()).wrapping_sub(1) + }; + + loop { + let v: $u_large = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return low.wrapping_add(hi as $ty); + } + } + } + } + }; +} + +uniform_int_impl! { i8, u8, u32 } +uniform_int_impl! { i16, u16, u32 } +uniform_int_impl! { i32, u32, u32 } +uniform_int_impl! { i64, u64, u64 } +uniform_int_impl! { i128, u128, u128 } +uniform_int_impl! { isize, usize, usize } +uniform_int_impl! { u8, u8, u32 } +uniform_int_impl! { u16, u16, u32 } +uniform_int_impl! { u32, u32, u32 } +uniform_int_impl! { u64, u64, u64 } +uniform_int_impl! { usize, usize, usize } +uniform_int_impl! { u128, u128, u128 } + +#[cfg(feature = "simd_support")] +macro_rules! uniform_simd_int_impl { + ($ty:ident, $unsigned:ident, $u_scalar:ident) => { + // The "pick the largest zone that can fit in an `u32`" optimization + // is less useful here. Multiple lanes complicate things, we don't + // know the PRNG's minimal output size, and casting to a larger vector + // is generally a bad idea for SIMD performance. The user can still + // implement it manually. + + // TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality + // perhaps `impl SampleUniform for $u_scalar`? + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new<B1, B2>(low_b: B1, high_b: B2) -> Self + where B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low.lt(high).all(), "Uniform::new called with `low >= high`"); + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self + where B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low.le(high).all(), + "Uniform::new_inclusive called with `low > high`"); + let unsigned_max = ::core::$u_scalar::MAX; + + // NOTE: these may need to be replaced with explicitly + // wrapping operations if `packed_simd` changes + let range: $unsigned = ((high - low) + 1).cast(); + // `% 0` will panic at runtime. + let not_full_range = range.gt($unsigned::splat(0)); + // replacing 0 with `unsigned_max` allows a faster `select` + // with bitwise OR + let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max)); + // wrapping addition + let ints_to_reject = (unsigned_max - range + 1) % modulo; + // When `range` is 0, `lo` of `v.wmul(range)` will always be + // zero which means only one sample is needed. + let zone = unsigned_max - ints_to_reject; + + UniformInt { + low, + // These are really $unsigned values, but store as $ty: + range: range.cast(), + z: zone.cast(), + } + } + + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + let range: $unsigned = self.range.cast(); + let zone: $unsigned = self.z.cast(); + + // This might seem very slow, generating a whole new + // SIMD vector for every sample rejection. For most uses + // though, the chance of rejection is small and provides good + // general performance. With multiple lanes, that chance is + // multiplied. To mitigate this, we replace only the lanes of + // the vector which fail, iteratively reducing the chance of + // rejection. The replacement method does however add a little + // overhead. Benchmarking or calculating probabilities might + // reveal contexts where this replacement method is slower. + let mut v: $unsigned = rng.gen(); + loop { + let (hi, lo) = v.wmul(range); + let mask = lo.le(zone); + if mask.all() { + let hi: $ty = hi.cast(); + // wrapping addition + let result = self.low + hi; + // `select` here compiles to a blend operation + // When `range.eq(0).none()` the compare and blend + // operations are avoided. + let v: $ty = v.cast(); + return range.gt($unsigned::splat(0)).select(result, v); + } + // Replace only the failing lanes + v = mask.select(v, rng.gen()); + } + } + } + }; + + // bulk implementation + ($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => { + $( + uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar); + uniform_simd_int_impl!($signed, $unsigned, $u_scalar); + )+ + }; +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { + (u64x2, i64x2), + (u64x4, i64x4), + (u64x8, i64x8), + u64 +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { + (u32x2, i32x2), + (u32x4, i32x4), + (u32x8, i32x8), + (u32x16, i32x16), + u32 +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { + (u16x2, i16x2), + (u16x4, i16x4), + (u16x8, i16x8), + (u16x16, i16x16), + (u16x32, i16x32), + u16 +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { + (u8x2, i8x2), + (u8x4, i8x4), + (u8x8, i8x8), + (u8x16, i8x16), + (u8x32, i8x32), + (u8x64, i8x64), + u8 +} + +impl SampleUniform for char { + type Sampler = UniformChar; +} + +/// The back-end implementing [`UniformSampler`] for `char`. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// This differs from integer range sampling since the range `0xD800..=0xDFFF` +/// are used for surrogate pairs in UCS and UTF-16, and consequently are not +/// valid Unicode code points. We must therefore avoid sampling values in this +/// range. +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct UniformChar { + sampler: UniformInt<u32>, +} + +/// UTF-16 surrogate range start +const CHAR_SURROGATE_START: u32 = 0xD800; +/// UTF-16 surrogate range size +const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; + +/// Convert `char` to compressed `u32` +fn char_to_comp_u32(c: char) -> u32 { + match c as u32 { + c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, + c => c, + } +} + +impl UniformSampler for UniformChar { + type X = char; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::<u32>::new(low, high); + UniformChar { sampler } + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::<u32>::new_inclusive(low, high); + UniformChar { sampler } + } + + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + let mut x = self.sampler.sample(rng); + if x >= CHAR_SURROGATE_START { + x += CHAR_SURROGATE_LEN; + } + // SAFETY: x must not be in surrogate range or greater than char::MAX. + // This relies on range constructors which accept char arguments. + // Validity of input char values is assumed. + unsafe { core::char::from_u32_unchecked(x) } + } +} + +/// The back-end implementing [`UniformSampler`] for floating-point types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// Instead of generating a float in the `[0, 1)` range using [`Standard`], the +/// `UniformFloat` implementation converts the output of an PRNG itself. This +/// way one or two steps can be optimized out. +/// +/// The floats are first converted to a value in the `[1, 2)` interval using a +/// transmute-based method, and then mapped to the expected range with a +/// multiply and addition. Values produced this way have what equals 23 bits of +/// random digits for an `f32`, and 52 for an `f64`. +/// +/// [`new`]: UniformSampler::new +/// [`new_inclusive`]: UniformSampler::new_inclusive +/// [`Standard`]: crate::distributions::Standard +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct UniformFloat<X> { + low: X, + scale: X, +} + +macro_rules! uniform_float_impl { + ($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + impl SampleUniform for $ty { + type Sampler = UniformFloat<$ty>; + } + + impl UniformSampler for UniformFloat<$ty> { + type X = $ty; + + fn new<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + debug_assert!( + low.all_finite(), + "Uniform::new called with `low` non-finite." + ); + debug_assert!( + high.all_finite(), + "Uniform::new called with `high` non-finite." + ); + assert!(low.all_lt(high), "Uniform::new called with `low >= high`"); + let max_rand = <$ty>::splat( + (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, + ); + + let mut scale = high - low; + assert!(scale.all_finite(), "Uniform::new: range overflow"); + + loop { + let mask = (scale * max_rand + low).ge_mask(high); + if mask.none() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + + fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + debug_assert!( + low.all_finite(), + "Uniform::new_inclusive called with `low` non-finite." + ); + debug_assert!( + high.all_finite(), + "Uniform::new_inclusive called with `high` non-finite." + ); + assert!( + low.all_le(high), + "Uniform::new_inclusive called with `low > high`" + ); + let max_rand = <$ty>::splat( + (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, + ); + + let mut scale = (high - low) / max_rand; + assert!(scale.all_finite(), "Uniform::new_inclusive: range overflow"); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if mask.none() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + // Generate a value in the range [1, 2) + let value1_2 = (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); + + // Get a value in the range [0, 1) in order to avoid + // overflowing into infinity when multiplying with scale + let value0_1 = value1_2 - 1.0; + + // We don't use `f64::mul_add`, because it is not available with + // `no_std`. Furthermore, it is slower for some targets (but + // faster for others). However, the order of multiplication and + // addition is important, because on some platforms (e.g. ARM) + // it will be optimized to a single (non-FMA) instruction. + value0_1 * self.scale + self.low + } + + #[inline] + fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Self::X + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + debug_assert!( + low.all_finite(), + "UniformSampler::sample_single called with `low` non-finite." + ); + debug_assert!( + high.all_finite(), + "UniformSampler::sample_single called with `high` non-finite." + ); + assert!( + low.all_lt(high), + "UniformSampler::sample_single: low >= high" + ); + let mut scale = high - low; + assert!(scale.all_finite(), "UniformSampler::sample_single: range overflow"); + + loop { + // Generate a value in the range [1, 2) + let value1_2 = + (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); + + // Get a value in the range [0, 1) in order to avoid + // overflowing into infinity when multiplying with scale + let value0_1 = value1_2 - 1.0; + + // Doing multiply before addition allows some architectures + // to use a single instruction. + let res = value0_1 * scale + low; + + debug_assert!(low.all_le(res) || !scale.all_finite()); + if res.all_lt(high) { + return res; + } + + // This handles a number of edge cases. + // * `low` or `high` is NaN. In this case `scale` and + // `res` are going to end up as NaN. + // * `low` is negative infinity and `high` is finite. + // `scale` is going to be infinite and `res` will be + // NaN. + // * `high` is positive infinity and `low` is finite. + // `scale` is going to be infinite and `res` will + // be infinite or NaN (if value0_1 is 0). + // * `low` is negative infinity and `high` is positive + // infinity. `scale` will be infinite and `res` will + // be NaN. + // * `low` and `high` are finite, but `high - low` + // overflows to infinite. `scale` will be infinite + // and `res` will be infinite or NaN (if value0_1 is 0). + // So if `high` or `low` are non-finite, we are guaranteed + // to fail the `res < high` check above and end up here. + // + // While we technically should check for non-finite `low` + // and `high` before entering the loop, by doing the checks + // here instead, we allow the common case to avoid these + // checks. But we are still guaranteed that if `low` or + // `high` are non-finite we'll end up here and can do the + // appropriate checks. + // + // Likewise `high - low` overflowing to infinity is also + // rare, so handle it here after the common case. + let mask = !scale.finite_mask(); + if mask.any() { + assert!( + low.all_finite() && high.all_finite(), + "Uniform::sample_single: low and high must be finite" + ); + scale = scale.decrease_masked(mask); + } + } + } + } + }; +} + +uniform_float_impl! { f32, u32, f32, u32, 32 - 23 } +uniform_float_impl! { f64, u64, f64, u64, 64 - 52 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { f32x2, u32x2, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { f32x4, u32x4, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { f32x8, u32x8, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { f32x16, u32x16, f32, u32, 32 - 23 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { f64x2, u64x2, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } + + +/// The back-end implementing [`UniformSampler`] for `Duration`. +/// +/// Unless you are implementing [`UniformSampler`] for your own types, this type +/// should not be used directly, use [`Uniform`] instead. +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct UniformDuration { + mode: UniformDurationMode, + offset: u32, +} + +#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +enum UniformDurationMode { + Small { + secs: u64, + nanos: Uniform<u32>, + }, + Medium { + nanos: Uniform<u64>, + }, + Large { + max_secs: u64, + max_nanos: u32, + secs: Uniform<u64>, + }, +} + +impl SampleUniform for Duration { + type Sampler = UniformDuration; +} + +impl UniformSampler for UniformDuration { + type X = Duration; + + #[inline] + fn new<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low < high, "Uniform::new called with `low >= high`"); + UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) + } + + #[inline] + fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!( + low <= high, + "Uniform::new_inclusive called with `low > high`" + ); + + let low_s = low.as_secs(); + let low_n = low.subsec_nanos(); + let mut high_s = high.as_secs(); + let mut high_n = high.subsec_nanos(); + + if high_n < low_n { + high_s -= 1; + high_n += 1_000_000_000; + } + + let mode = if low_s == high_s { + UniformDurationMode::Small { + secs: low_s, + nanos: Uniform::new_inclusive(low_n, high_n), + } + } else { + let max = high_s + .checked_mul(1_000_000_000) + .and_then(|n| n.checked_add(u64::from(high_n))); + + if let Some(higher_bound) = max { + let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); + UniformDurationMode::Medium { + nanos: Uniform::new_inclusive(lower_bound, higher_bound), + } + } else { + // An offset is applied to simplify generation of nanoseconds + let max_nanos = high_n - low_n; + UniformDurationMode::Large { + max_secs: high_s, + max_nanos, + secs: Uniform::new_inclusive(low_s, high_s), + } + } + }; + UniformDuration { + mode, + offset: low_n, + } + } + + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Duration { + match self.mode { + UniformDurationMode::Small { secs, nanos } => { + let n = nanos.sample(rng); + Duration::new(secs, n) + } + UniformDurationMode::Medium { nanos } => { + let nanos = nanos.sample(rng); + Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) + } + UniformDurationMode::Large { + max_secs, + max_nanos, + secs, + } => { + // constant folding means this is at least as fast as `Rng::sample(Range)` + let nano_range = Uniform::new(0, 1_000_000_000); + loop { + let s = secs.sample(rng); + let n = nano_range.sample(rng); + if !(s == max_secs && n > max_nanos) { + let sum = n + self.offset; + break Duration::new(s, sum); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::rngs::mock::StepRng; + + #[test] + #[cfg(feature = "serde1")] + fn test_serialization_uniform_duration() { + let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)); + let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); + assert_eq!( + distr.offset, de_distr.offset + ); + match (distr.mode, de_distr.mode) { + (UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => { + assert_eq!(a_secs, secs); + + assert_eq!(a_nanos.0.low, nanos.0.low); + assert_eq!(a_nanos.0.range, nanos.0.range); + assert_eq!(a_nanos.0.z, nanos.0.z); + } + (UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => { + assert_eq!(a_nanos.0.low, nanos.0.low); + assert_eq!(a_nanos.0.range, nanos.0.range); + assert_eq!(a_nanos.0.z, nanos.0.z); + } + (UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => { + assert_eq!(a_max_secs, max_secs); + assert_eq!(a_max_nanos, max_nanos); + + assert_eq!(a_secs.0.low, secs.0.low); + assert_eq!(a_secs.0.range, secs.0.range); + assert_eq!(a_secs.0.z, secs.0.z); + } + _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly") + } + } + + #[test] + #[cfg(feature = "serde1")] + fn test_uniform_serialization() { + let unit_box: Uniform<i32> = Uniform::new(-1, 1); + let de_unit_box: Uniform<i32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + + assert_eq!(unit_box.0.low, de_unit_box.0.low); + assert_eq!(unit_box.0.range, de_unit_box.0.range); + assert_eq!(unit_box.0.z, de_unit_box.0.z); + + let unit_box: Uniform<f32> = Uniform::new(-1., 1.); + let de_unit_box: Uniform<f32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + + assert_eq!(unit_box.0.low, de_unit_box.0.low); + assert_eq!(unit_box.0.scale, de_unit_box.0.scale); + } + + #[should_panic] + #[test] + fn test_uniform_bad_limits_equal_int() { + Uniform::new(10, 10); + } + + #[test] + fn test_uniform_good_limits_equal_int() { + let mut rng = crate::test::rng(804); + let dist = Uniform::new_inclusive(10, 10); + for _ in 0..20 { + assert_eq!(rng.sample(dist), 10); + } + } + + #[should_panic] + #[test] + fn test_uniform_bad_limits_flipped_int() { + Uniform::new(10, 5); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_integers() { + use core::{i128, u128}; + use core::{i16, i32, i64, i8, isize}; + use core::{u16, u32, u64, u8, usize}; + + let mut rng = crate::test::rng(251); + macro_rules! t { + ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ + for &(low, high) in $v.iter() { + let my_uniform = Uniform::new(low, high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(low, high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + let my_uniform = Uniform::new(&low, high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(&low, &high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng); + assert!($le(low, v) && $lt(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng); + assert!($le(low, v) && $le(v, high)); + } + } + }}; + + // scalar bulk + ($($ty:ident),*) => {{ + $(t!( + $ty, + [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], + |x, y| x <= y, + |x, y| x < y + );)* + }}; + + // simd bulk + ($($ty:ident),* => $scalar:ident) => {{ + $(t!( + $ty, + [ + ($ty::splat(0), $ty::splat(10)), + ($ty::splat(10), $ty::splat(127)), + ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), + ], + |x: $ty, y| x.le(y).all(), + |x: $ty, y| x.lt(y).all() + );)* + }}; + } + t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, i128, u128); + + #[cfg(feature = "simd_support")] + { + t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8); + t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8); + t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); + t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); + t!(u32x2, u32x4, u32x8, u32x16 => u32); + t!(i32x2, i32x4, i32x8, i32x16 => i32); + t!(u64x2, u64x4, u64x8 => u64); + t!(i64x2, i64x4, i64x8 => i64); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_char() { + let mut rng = crate::test::rng(891); + let mut max = core::char::from_u32(0).unwrap(); + for _ in 0..100 { + let c = rng.gen_range('A'..='Z'); + assert!(('A'..='Z').contains(&c)); + max = max.max(c); + } + assert_eq!(max, 'Z'); + let d = Uniform::new( + core::char::from_u32(0xD7F0).unwrap(), + core::char::from_u32(0xE010).unwrap(), + ); + for _ in 0..100 { + let c = d.sample(&mut rng); + assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_floats() { + let mut rng = crate::test::rng(252); + let mut zero_rng = StepRng::new(0, 0); + let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); + macro_rules! t { + ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + (0.0, 100.0), + (-1e35, -1e25), + (1e-35, 1e-25), + (-1e35, 1e35), + (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), + (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), + (-<$f_scalar>::from_bits(5), 0.0), + (-<$f_scalar>::from_bits(7), -0.0), + (0.1 * ::core::$f_scalar::MAX, ::core::$f_scalar::MAX), + (-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::lanes() { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + let my_uniform = Uniform::new(low, high); + let my_incl_uniform = Uniform::new_inclusive(low, high); + for _ in 0..100 { + let v = rng.sample(my_uniform).extract(lane); + assert!(low_scalar <= v && v < high_scalar); + let v = rng.sample(my_incl_uniform).extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = <$ty as SampleUniform>::Sampler + ::sample_single(low, high, &mut rng).extract(lane); + assert!(low_scalar <= v && v < high_scalar); + } + + assert_eq!( + rng.sample(Uniform::new_inclusive(low, low)).extract(lane), + low_scalar + ); + + assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); + assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); + assert_eq!(<$ty as SampleUniform>::Sampler + ::sample_single(low, high, &mut zero_rng) + .extract(lane), low_scalar); + assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); + assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); + + // Don't run this test for really tiny differences between high and low + // since for those rounding might result in selecting high for a very + // long time. + if (high_scalar - low_scalar) > 0.0001 { + let mut lowering_max_rng = StepRng::new( + 0xffff_ffff_ffff_ffff, + (-1i64 << $bits_shifted) as u64, + ); + assert!( + <$ty as SampleUniform>::Sampler + ::sample_single(low, high, &mut lowering_max_rng) + .extract(lane) < high_scalar + ); + } + } + } + + assert_eq!( + rng.sample(Uniform::new_inclusive( + ::core::$f_scalar::MAX, + ::core::$f_scalar::MAX + )), + ::core::$f_scalar::MAX + ); + assert_eq!( + rng.sample(Uniform::new_inclusive( + -::core::$f_scalar::MAX, + -::core::$f_scalar::MAX + )), + -::core::$f_scalar::MAX + ); + }}; + } + + t!(f32, f32, 32 - 23); + t!(f64, f64, 64 - 52); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32, 32 - 23); + t!(f32x4, f32, 32 - 23); + t!(f32x8, f32, 32 - 23); + t!(f32x16, f32, 32 - 23); + t!(f64x2, f64, 64 - 52); + t!(f64x4, f64, 64 - 52); + t!(f64x8, f64, 64 - 52); + } + } + + #[test] + #[should_panic] + fn test_float_overflow() { + let _ = Uniform::from(::core::f64::MIN..::core::f64::MAX); + } + + #[test] + #[should_panic] + fn test_float_overflow_single() { + let mut rng = crate::test::rng(252); + rng.gen_range(::core::f64::MIN..::core::f64::MAX); + } + + #[test] + #[cfg(all( + feature = "std", + not(target_arch = "wasm32"), + not(target_arch = "asmjs") + ))] + fn test_float_assertions() { + use super::SampleUniform; + use std::panic::catch_unwind; + fn range<T: SampleUniform>(low: T, high: T) { + let mut rng = crate::test::rng(253); + T::Sampler::sample_single(low, high, &mut rng); + } + + macro_rules! t { + ($ty:ident, $f_scalar:ident) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + (::std::$f_scalar::NAN, 0.0), + (1.0, ::std::$f_scalar::NAN), + (::std::$f_scalar::NAN, ::std::$f_scalar::NAN), + (1.0, 0.5), + (::std::$f_scalar::MAX, -::std::$f_scalar::MAX), + (::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY), + ( + ::std::$f_scalar::NEG_INFINITY, + ::std::$f_scalar::NEG_INFINITY, + ), + (::std::$f_scalar::NEG_INFINITY, 5.0), + (5.0, ::std::$f_scalar::INFINITY), + (::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY), + (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN), + (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::lanes() { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + assert!(catch_unwind(|| range(low, high)).is_err()); + assert!(catch_unwind(|| Uniform::new(low, high)).is_err()); + assert!(catch_unwind(|| Uniform::new_inclusive(low, high)).is_err()); + assert!(catch_unwind(|| range(low, low)).is_err()); + assert!(catch_unwind(|| Uniform::new(low, low)).is_err()); + } + } + }}; + } + + t!(f32, f32); + t!(f64, f64); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32); + t!(f32x4, f32); + t!(f32x8, f32); + t!(f32x16, f32); + t!(f64x2, f64); + t!(f64x4, f64); + t!(f64x8, f64); + } + } + + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_durations() { + let mut rng = crate::test::rng(253); + + let v = &[ + (Duration::new(10, 50000), Duration::new(100, 1234)), + (Duration::new(0, 100), Duration::new(1, 50)), + ( + Duration::new(0, 0), + Duration::new(u64::max_value(), 999_999_999), + ), + ]; + for &(low, high) in v.iter() { + let my_uniform = Uniform::new(low, high); + for _ in 0..1000 { + let v = rng.sample(my_uniform); + assert!(low <= v && v < high); + } + } + } + + #[test] + fn test_custom_uniform() { + use crate::distributions::uniform::{ + SampleBorrow, SampleUniform, UniformFloat, UniformSampler, + }; + #[derive(Clone, Copy, PartialEq, PartialOrd)] + struct MyF32 { + x: f32, + } + #[derive(Clone, Copy, Debug)] + struct UniformMyF32(UniformFloat<f32>); + impl UniformSampler for UniformMyF32 { + type X = MyF32; + + fn new<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + UniformMyF32(UniformFloat::<f32>::new(low.borrow().x, high.borrow().x)) + } + + fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + UniformSampler::new(low, high) + } + + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + MyF32 { + x: self.0.sample(rng), + } + } + } + impl SampleUniform for MyF32 { + type Sampler = UniformMyF32; + } + + let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); + let uniform = Uniform::new(low, high); + let mut rng = crate::test::rng(804); + for _ in 0..100 { + let x: MyF32 = rng.sample(uniform); + assert!(low <= x && x < high); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::from(2u32..7); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + let r = Uniform::from(2.0f64..7.0); + assert_eq!(r.0.low, 2.0); + assert_eq!(r.0.scale, 5.0); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::from(2u32..=6); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + let r = Uniform::from(2.0f64..=7.0); + assert_eq!(r.0.low, 2.0); + assert!(r.0.scale > 5.0); + assert!(r.0.scale < 5.0 + 1e-14); + } + + #[test] + fn value_stability() { + fn test_samples<T: SampleUniform + Copy + core::fmt::Debug + PartialEq>( + lb: T, ub: T, expected_single: &[T], expected_multiple: &[T], + ) where Uniform<T>: Distribution<T> { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 3]; + + for x in &mut buf { + *x = T::Sampler::sample_single(lb, ub, &mut rng); + } + assert_eq!(&buf, expected_single); + + let distr = Uniform::new(lb, ub); + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected_multiple); + } + + // We test on a sub-set of types; possibly we should do more. + // TODO: SIMD types + + test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]); + test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]); + + test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[ + 0.008194133, + 0.00398172, + 0.007428536, + ]); + test_samples( + -1e10f64, + 1e10f64, + &[-4673848682.871551, 6388267422.932352, 4857075081.198343], + &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], + ); + + test_samples( + Duration::new(2, 0), + Duration::new(4, 0), + &[ + Duration::new(2, 532615131), + Duration::new(3, 638826742), + Duration::new(3, 485707508), + ], + &[ + Duration::new(3, 117337521), + Duration::new(3, 191764285), + Duration::new(3, 236507617), + ], + ); + } + + #[test] + fn uniform_distributions_can_be_compared() { + assert_eq!(Uniform::new(1.0, 2.0), Uniform::new(1.0, 2.0)); + + // To cover UniformInt + assert_eq!(Uniform::new(1 as u32, 2 as u32), Uniform::new(1 as u32, 2 as u32)); + } +} diff --git a/third_party/rust/rand/src/distributions/utils.rs b/third_party/rust/rand/src/distributions/utils.rs new file mode 100644 index 0000000000..89da5fd7aa --- /dev/null +++ b/third_party/rust/rand/src/distributions/utils.rs @@ -0,0 +1,429 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Math helper functions + +#[cfg(feature = "simd_support")] use packed_simd::*; + + +pub(crate) trait WideningMultiply<RHS = Self> { + type Output; + + fn wmul(self, x: RHS) -> Self::Output; +} + +macro_rules! wmul_impl { + ($ty:ty, $wide:ty, $shift:expr) => { + impl WideningMultiply for $ty { + type Output = ($ty, $ty); + + #[inline(always)] + fn wmul(self, x: $ty) -> Self::Output { + let tmp = (self as $wide) * (x as $wide); + ((tmp >> $shift) as $ty, tmp as $ty) + } + } + }; + + // simd bulk implementation + ($(($ty:ident, $wide:ident),)+, $shift:expr) => { + $( + impl WideningMultiply for $ty { + type Output = ($ty, $ty); + + #[inline(always)] + fn wmul(self, x: $ty) -> Self::Output { + // For supported vectors, this should compile to a couple + // supported multiply & swizzle instructions (no actual + // casting). + // TODO: optimize + let y: $wide = self.cast(); + let x: $wide = x.cast(); + let tmp = y * x; + let hi: $ty = (tmp >> $shift).cast(); + let lo: $ty = tmp.cast(); + (hi, lo) + } + } + )+ + }; +} +wmul_impl! { u8, u16, 8 } +wmul_impl! { u16, u32, 16 } +wmul_impl! { u32, u64, 32 } +wmul_impl! { u64, u128, 64 } + +// This code is a translation of the __mulddi3 function in LLVM's +// compiler-rt. It is an optimised variant of the common method +// `(a + b) * (c + d) = ac + ad + bc + bd`. +// +// For some reason LLVM can optimise the C version very well, but +// keeps shuffling registers in this Rust translation. +macro_rules! wmul_impl_large { + ($ty:ty, $half:expr) => { + impl WideningMultiply for $ty { + type Output = ($ty, $ty); + + #[inline(always)] + fn wmul(self, b: $ty) -> Self::Output { + const LOWER_MASK: $ty = !0 >> $half; + let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK); + let mut t = low >> $half; + low &= LOWER_MASK; + t += (self >> $half).wrapping_mul(b & LOWER_MASK); + low += (t & LOWER_MASK) << $half; + let mut high = t >> $half; + t = low >> $half; + low &= LOWER_MASK; + t += (b >> $half).wrapping_mul(self & LOWER_MASK); + low += (t & LOWER_MASK) << $half; + high += t >> $half; + high += (self >> $half).wrapping_mul(b >> $half); + + (high, low) + } + } + }; + + // simd bulk implementation + (($($ty:ty,)+) $scalar:ty, $half:expr) => { + $( + impl WideningMultiply for $ty { + type Output = ($ty, $ty); + + #[inline(always)] + fn wmul(self, b: $ty) -> Self::Output { + // needs wrapping multiplication + const LOWER_MASK: $scalar = !0 >> $half; + let mut low = (self & LOWER_MASK) * (b & LOWER_MASK); + let mut t = low >> $half; + low &= LOWER_MASK; + t += (self >> $half) * (b & LOWER_MASK); + low += (t & LOWER_MASK) << $half; + let mut high = t >> $half; + t = low >> $half; + low &= LOWER_MASK; + t += (b >> $half) * (self & LOWER_MASK); + low += (t & LOWER_MASK) << $half; + high += t >> $half; + high += (self >> $half) * (b >> $half); + + (high, low) + } + } + )+ + }; +} +wmul_impl_large! { u128, 64 } + +macro_rules! wmul_impl_usize { + ($ty:ty) => { + impl WideningMultiply for usize { + type Output = (usize, usize); + + #[inline(always)] + fn wmul(self, x: usize) -> Self::Output { + let (high, low) = (self as $ty).wmul(x as $ty); + (high as usize, low as usize) + } + } + }; +} +#[cfg(target_pointer_width = "16")] +wmul_impl_usize! { u16 } +#[cfg(target_pointer_width = "32")] +wmul_impl_usize! { u32 } +#[cfg(target_pointer_width = "64")] +wmul_impl_usize! { u64 } + +#[cfg(feature = "simd_support")] +mod simd_wmul { + use super::*; + #[cfg(target_arch = "x86")] use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; + + wmul_impl! { + (u8x2, u16x2), + (u8x4, u16x4), + (u8x8, u16x8), + (u8x16, u16x16), + (u8x32, u16x32),, + 8 + } + + wmul_impl! { (u16x2, u32x2),, 16 } + wmul_impl! { (u16x4, u32x4),, 16 } + #[cfg(not(target_feature = "sse2"))] + wmul_impl! { (u16x8, u32x8),, 16 } + #[cfg(not(target_feature = "avx2"))] + wmul_impl! { (u16x16, u32x16),, 16 } + + // 16-bit lane widths allow use of the x86 `mulhi` instructions, which + // means `wmul` can be implemented with only two instructions. + #[allow(unused_macros)] + macro_rules! wmul_impl_16 { + ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => { + impl WideningMultiply for $ty { + type Output = ($ty, $ty); + + #[inline(always)] + fn wmul(self, x: $ty) -> Self::Output { + let b = $intrinsic::from_bits(x); + let a = $intrinsic::from_bits(self); + let hi = $ty::from_bits(unsafe { $mulhi(a, b) }); + let lo = $ty::from_bits(unsafe { $mullo(a, b) }); + (hi, lo) + } + } + }; + } + + #[cfg(target_feature = "sse2")] + wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 } + #[cfg(target_feature = "avx2")] + wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 } + // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>` + // cannot use the same implementation. + + wmul_impl! { + (u32x2, u64x2), + (u32x4, u64x4), + (u32x8, u64x8),, + 32 + } + + // TODO: optimize, this seems to seriously slow things down + wmul_impl_large! { (u8x64,) u8, 4 } + wmul_impl_large! { (u16x32,) u16, 8 } + wmul_impl_large! { (u32x16,) u32, 16 } + wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 } +} + +/// Helper trait when dealing with scalar and SIMD floating point types. +pub(crate) trait FloatSIMDUtils { + // `PartialOrd` for vectors compares lexicographically. We want to compare all + // the individual SIMD lanes instead, and get the combined result over all + // lanes. This is possible using something like `a.lt(b).all()`, but we + // implement it as a trait so we can write the same code for `f32` and `f64`. + // Only the comparison functions we need are implemented. + fn all_lt(self, other: Self) -> bool; + fn all_le(self, other: Self) -> bool; + fn all_finite(self) -> bool; + + type Mask; + fn finite_mask(self) -> Self::Mask; + fn gt_mask(self, other: Self) -> Self::Mask; + fn ge_mask(self, other: Self) -> Self::Mask; + + // Decrease all lanes where the mask is `true` to the next lower value + // representable by the floating-point type. At least one of the lanes + // must be set. + fn decrease_masked(self, mask: Self::Mask) -> Self; + + // Convert from int value. Conversion is done while retaining the numerical + // value, not by retaining the binary representation. + type UInt; + fn cast_from_int(i: Self::UInt) -> Self; +} + +/// Implement functions available in std builds but missing from core primitives +#[cfg(not(std))] +// False positive: We are following `std` here. +#[allow(clippy::wrong_self_convention)] +pub(crate) trait Float: Sized { + fn is_nan(self) -> bool; + fn is_infinite(self) -> bool; + fn is_finite(self) -> bool; +} + +/// Implement functions on f32/f64 to give them APIs similar to SIMD types +pub(crate) trait FloatAsSIMD: Sized { + #[inline(always)] + fn lanes() -> usize { + 1 + } + #[inline(always)] + fn splat(scalar: Self) -> Self { + scalar + } + #[inline(always)] + fn extract(self, index: usize) -> Self { + debug_assert_eq!(index, 0); + self + } + #[inline(always)] + fn replace(self, index: usize, new_value: Self) -> Self { + debug_assert_eq!(index, 0); + new_value + } +} + +pub(crate) trait BoolAsSIMD: Sized { + fn any(self) -> bool; + fn all(self) -> bool; + fn none(self) -> bool; +} + +impl BoolAsSIMD for bool { + #[inline(always)] + fn any(self) -> bool { + self + } + + #[inline(always)] + fn all(self) -> bool { + self + } + + #[inline(always)] + fn none(self) -> bool { + !self + } +} + +macro_rules! scalar_float_impl { + ($ty:ident, $uty:ident) => { + #[cfg(not(std))] + impl Float for $ty { + #[inline] + fn is_nan(self) -> bool { + self != self + } + + #[inline] + fn is_infinite(self) -> bool { + self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY + } + + #[inline] + fn is_finite(self) -> bool { + !(self.is_nan() || self.is_infinite()) + } + } + + impl FloatSIMDUtils for $ty { + type Mask = bool; + type UInt = $uty; + + #[inline(always)] + fn all_lt(self, other: Self) -> bool { + self < other + } + + #[inline(always)] + fn all_le(self, other: Self) -> bool { + self <= other + } + + #[inline(always)] + fn all_finite(self) -> bool { + self.is_finite() + } + + #[inline(always)] + fn finite_mask(self) -> Self::Mask { + self.is_finite() + } + + #[inline(always)] + fn gt_mask(self, other: Self) -> Self::Mask { + self > other + } + + #[inline(always)] + fn ge_mask(self, other: Self) -> Self::Mask { + self >= other + } + + #[inline(always)] + fn decrease_masked(self, mask: Self::Mask) -> Self { + debug_assert!(mask, "At least one lane must be set"); + <$ty>::from_bits(self.to_bits() - 1) + } + + #[inline] + fn cast_from_int(i: Self::UInt) -> Self { + i as $ty + } + } + + impl FloatAsSIMD for $ty {} + }; +} + +scalar_float_impl!(f32, u32); +scalar_float_impl!(f64, u64); + + +#[cfg(feature = "simd_support")] +macro_rules! simd_impl { + ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => { + impl FloatSIMDUtils for $ty { + type Mask = $mty; + type UInt = $uty; + + #[inline(always)] + fn all_lt(self, other: Self) -> bool { + self.lt(other).all() + } + + #[inline(always)] + fn all_le(self, other: Self) -> bool { + self.le(other).all() + } + + #[inline(always)] + fn all_finite(self) -> bool { + self.finite_mask().all() + } + + #[inline(always)] + fn finite_mask(self) -> Self::Mask { + // This can possibly be done faster by checking bit patterns + let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY); + let pos_inf = $ty::splat(::core::$f_scalar::INFINITY); + self.gt(neg_inf) & self.lt(pos_inf) + } + + #[inline(always)] + fn gt_mask(self, other: Self) -> Self::Mask { + self.gt(other) + } + + #[inline(always)] + fn ge_mask(self, other: Self) -> Self::Mask { + self.ge(other) + } + + #[inline(always)] + fn decrease_masked(self, mask: Self::Mask) -> Self { + // Casting a mask into ints will produce all bits set for + // true, and 0 for false. Adding that to the binary + // representation of a float means subtracting one from + // the binary representation, resulting in the next lower + // value representable by $ty. This works even when the + // current value is infinity. + debug_assert!(mask.any(), "At least one lane must be set"); + <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask)) + } + + #[inline] + fn cast_from_int(i: Self::UInt) -> Self { + i.cast() + } + } + }; +} + +#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 } +#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 } +#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 } +#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 } +#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 } +#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 } +#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 } diff --git a/third_party/rust/rand/src/distributions/weighted.rs b/third_party/rust/rand/src/distributions/weighted.rs new file mode 100644 index 0000000000..846b9df9c2 --- /dev/null +++ b/third_party/rust/rand/src/distributions/weighted.rs @@ -0,0 +1,47 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted index sampling +//! +//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and +//! [`crate::distributions::WeightedError`] instead. + +pub use super::{WeightedIndex, WeightedError}; + +#[allow(missing_docs)] +#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] +pub mod alias_method { + // This module exists to provide a deprecation warning which minimises + // compile errors, but still fails to compile if ever used. + use core::marker::PhantomData; + use alloc::vec::Vec; + use super::WeightedError; + + #[derive(Debug)] + pub struct WeightedIndex<W: Weight> { + _phantom: PhantomData<W>, + } + impl<W: Weight> WeightedIndex<W> { + pub fn new(_weights: Vec<W>) -> Result<Self, WeightedError> { + Err(WeightedError::NoItem) + } + } + + pub trait Weight {} + macro_rules! impl_weight { + () => {}; + ($T:ident, $($more:ident,)*) => { + impl Weight for $T {} + impl_weight!($($more,)*); + }; + } + impl_weight!(f64, f32,); + impl_weight!(u8, u16, u32, u64, usize,); + impl_weight!(i8, i16, i32, i64, isize,); + impl_weight!(u128, i128,); +} diff --git a/third_party/rust/rand/src/distributions/weighted_index.rs b/third_party/rust/rand/src/distributions/weighted_index.rs new file mode 100644 index 0000000000..8252b172f7 --- /dev/null +++ b/third_party/rust/rand/src/distributions/weighted_index.rs @@ -0,0 +1,458 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted index sampling + +use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distributions::Distribution; +use crate::Rng; +use core::cmp::PartialOrd; +use core::fmt; + +// Note that this whole module is only imported if feature="alloc" is enabled. +use alloc::vec::Vec; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; + +/// A distribution using weighted sampling of discrete items +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// value of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform<X>`] exists. +/// +/// # Performance +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. As an alternative, +/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) +/// supports `O(1)` sampling, but with much higher initialisation cost. +/// +/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains, this might cause additional allocations, though for primitive +/// types, [`Uniform<X>`] doesn't allocate any memory. +/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementation of `Uniform<X>::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform<X>`]: crate::distributions::Uniform +/// [`RngCore`]: crate::RngCore +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub struct WeightedIndex<X: SampleUniform + PartialOrd> { + cumulative_weights: Vec<X>, + total_weight: X, + weight_distribution: X::Sampler, +} + +impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform<X>`] exists. + /// + /// Returns an error if the iterator is empty, if any weight is `< 0`, or + /// if its total value is 0. + /// + /// [`Uniform<X>`]: crate::distributions::uniform::Uniform + pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> + where + I: IntoIterator, + I::Item: SampleBorrow<X>, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + + let zero = <X as Default>::default(); + if !(total_weight >= zero) { + return Err(WeightedError::InvalidWeight); + } + + let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); + for w in iter { + // Note that `!(w >= x)` is not equivalent to `w < x` for partially + // ordered types due to NaNs which are equal to nothing. + if !(w.borrow() >= &zero) { + return Err(WeightedError::InvalidWeight); + } + weights.push(total_weight.clone()); + total_weight += w.borrow(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + let distr = X::Sampler::new(zero, total_weight.clone()); + + Ok(WeightedIndex { + cumulative_weights: weights, + total_weight, + weight_distribution: distr, + }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = <X as Default>::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } + if !(*w >= zero) { + return Err(WeightedError::InvalidWeight); + } + if i > self.cumulative_weights.len() { + return Err(WeightedError::TooMany); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight <= zero { + return Err(WeightedError::AllWeightsZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + } + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + + Ok(()) + } +} + +impl<X> Distribution<usize> for WeightedIndex<X> +where X: SampleUniform + PartialOrd +{ + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { + use ::core::cmp::Ordering; + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights + .binary_search_by(|w| { + if *w <= chosen_weight { + Ordering::Less + } else { + Ordering::Greater + } + }) + .unwrap_err() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(feature = "serde1")] + #[test] + fn test_weightedindex_serde1() { + let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); + + let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); + let de_weighted_index: WeightedIndex<i32> = + bincode::deserialize(&ser_weighted_index).unwrap(); + + assert_eq!( + de_weighted_index.cumulative_weights, + weighted_index.cumulative_weights + ); + assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); + } + + #[test] + fn test_accepting_nan(){ + assert_eq!( + WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), + WeightedError::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, + ); + + assert_eq!( + WeightedIndex::new(&[0.5, 7.0]) + .unwrap() + .update_weights(&[(0, &core::f32::NAN)]) + .unwrap_err(), + WeightedError::InvalidWeight, + ) + } + + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weightedindex() { + let mut r = crate::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::<u32>() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); + assert_eq!( + WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) + .unwrap() + .sample(&mut r), + 4 + ); + } + + assert_eq!( + WeightedIndex::new(&[10][0..0]).unwrap_err(), + WeightedError::NoItem + ); + assert_eq!( + WeightedIndex::new(&[0]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(&[-10]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + fn test_update_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for (weights, update, expected_weights) in data.iter() { + let total_weight = weights.iter().sum::<u32>(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::<u32>(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } + } + + #[test] + fn value_stability() { + fn test_samples<X: SampleUniform + PartialOrd, I>( + weights: I, buf: &mut [usize], expected: &[usize], + ) where + I: IntoIterator, + I::Item: SampleBorrow<X>, + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + { + assert_eq!(buf.len(), expected.len()); + let distr = WeightedIndex::new(weights).unwrap(); + let mut rng = crate::test::rng(701); + for r in buf.iter_mut() { + *r = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + let mut buf = [0; 10]; + test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ + 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, + ]); + test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ + 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, + ]); + test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ + 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, + ]); + } + + #[test] + fn weighted_index_distributions_can_be_compared() { + assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); + } +} + +/// Error type returned from `WeightedIndex::new`. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { + /// The provided weight collection contains no items. + NoItem, + + /// A weight is either less than zero, greater than the supported maximum, + /// NaN, or otherwise invalid. + InvalidWeight, + + /// All items in the provided weight collection are zero. + AllWeightsZero, + + /// Too many weights are provided (length greater than `u32::MAX`) + TooMany, +} + +#[cfg(feature = "std")] +impl std::error::Error for WeightedError {} + +impl fmt::Display for WeightedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + WeightedError::NoItem => "No weights provided in distribution", + WeightedError::InvalidWeight => "A weight is invalid in distribution", + WeightedError::AllWeightsZero => "All weights are zero in distribution", + WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", + }) + } +} diff --git a/third_party/rust/rand/src/lib.rs b/third_party/rust/rand/src/lib.rs new file mode 100644 index 0000000000..6d84718011 --- /dev/null +++ b/third_party/rust/rand/src/lib.rs @@ -0,0 +1,214 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Utilities for random number generation +//! +//! Rand provides utilities to generate random numbers, to convert them to +//! useful types and distributions, and some randomness-related algorithms. +//! +//! # Quick Start +//! +//! To get you started quickly, the easiest and highest-level way to get +//! a random value is to use [`random()`]; alternatively you can use +//! [`thread_rng()`]. The [`Rng`] trait provides a useful API on all RNGs, while +//! the [`distributions`] and [`seq`] modules provide further +//! functionality on top of RNGs. +//! +//! ``` +//! use rand::prelude::*; +//! +//! if rand::random() { // generates a boolean +//! // Try printing a random unicode code point (probably a bad idea)! +//! println!("char: {}", rand::random::<char>()); +//! } +//! +//! let mut rng = rand::thread_rng(); +//! let y: f64 = rng.gen(); // generates a float between 0 and 1 +//! +//! let mut nums: Vec<i32> = (1..100).collect(); +//! nums.shuffle(&mut rng); +//! ``` +//! +//! # The Book +//! +//! For the user guide and further documentation, please read +//! [The Rust Rand Book](https://rust-random.github.io/book). + +#![doc( + html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", + html_favicon_url = "https://www.rust-lang.org/favicon.ico", + html_root_url = "https://rust-random.github.io/rand/" +)] +#![deny(missing_docs)] +#![deny(missing_debug_implementations)] +#![doc(test(attr(allow(unused_variables), deny(warnings))))] +#![no_std] +#![cfg_attr(feature = "simd_support", feature(stdsimd))] +#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![allow( + clippy::float_cmp, + clippy::neg_cmp_op_on_partial_ord, +)] + +#[cfg(feature = "std")] extern crate std; +#[cfg(feature = "alloc")] extern crate alloc; + +#[allow(unused)] +macro_rules! trace { ($($x:tt)*) => ( + #[cfg(feature = "log")] { + log::trace!($($x)*) + } +) } +#[allow(unused)] +macro_rules! debug { ($($x:tt)*) => ( + #[cfg(feature = "log")] { + log::debug!($($x)*) + } +) } +#[allow(unused)] +macro_rules! info { ($($x:tt)*) => ( + #[cfg(feature = "log")] { + log::info!($($x)*) + } +) } +#[allow(unused)] +macro_rules! warn { ($($x:tt)*) => ( + #[cfg(feature = "log")] { + log::warn!($($x)*) + } +) } +#[allow(unused)] +macro_rules! error { ($($x:tt)*) => ( + #[cfg(feature = "log")] { + log::error!($($x)*) + } +) } + +// Re-exports from rand_core +pub use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; + +// Public modules +pub mod distributions; +pub mod prelude; +mod rng; +pub mod rngs; +pub mod seq; + +// Public exports +#[cfg(all(feature = "std", feature = "std_rng"))] +pub use crate::rngs::thread::thread_rng; +pub use rng::{Fill, Rng}; + +#[cfg(all(feature = "std", feature = "std_rng"))] +use crate::distributions::{Distribution, Standard}; + +/// Generates a random value using the thread-local random number generator. +/// +/// This is simply a shortcut for `thread_rng().gen()`. See [`thread_rng`] for +/// documentation of the entropy source and [`Standard`] for documentation of +/// distributions and type-specific generation. +/// +/// # Provided implementations +/// +/// The following types have provided implementations that +/// generate values with the following ranges and distributions: +/// +/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed +/// over all values of the type. +/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all +/// code points in the range `0...0x10_FFFF`, except for the range +/// `0xD800...0xDFFF` (the surrogate code points). This includes +/// unassigned/reserved code points. +/// * `bool`: Generates `false` or `true`, each with probability 0.5. +/// * Floating point types (`f32` and `f64`): Uniformly distributed in the +/// half-open range `[0, 1)`. See notes below. +/// * Wrapping integers (`Wrapping<T>`), besides the type identical to their +/// normal integer variants. +/// +/// Also supported is the generation of the following +/// compound types where all component types are supported: +/// +/// * Tuples (up to 12 elements): each element is generated sequentially. +/// * Arrays (up to 32 elements): each element is generated sequentially; +/// see also [`Rng::fill`] which supports arbitrary array length for integer +/// types and tends to be faster for `u32` and smaller types. +/// * `Option<T>` first generates a `bool`, and if true generates and returns +/// `Some(value)` where `value: T`, otherwise returning `None`. +/// +/// # Examples +/// +/// ``` +/// let x = rand::random::<u8>(); +/// println!("{}", x); +/// +/// let y = rand::random::<f64>(); +/// println!("{}", y); +/// +/// if rand::random() { // generates a boolean +/// println!("Better lucky than good!"); +/// } +/// ``` +/// +/// If you're calling `random()` in a loop, caching the generator as in the +/// following example can increase performance. +/// +/// ``` +/// use rand::Rng; +/// +/// let mut v = vec![1, 2, 3]; +/// +/// for x in v.iter_mut() { +/// *x = rand::random() +/// } +/// +/// // can be made faster by caching thread_rng +/// +/// let mut rng = rand::thread_rng(); +/// +/// for x in v.iter_mut() { +/// *x = rng.gen(); +/// } +/// ``` +/// +/// [`Standard`]: distributions::Standard +#[cfg(all(feature = "std", feature = "std_rng"))] +#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] +#[inline] +pub fn random<T>() -> T +where Standard: Distribution<T> { + thread_rng().gen() +} + +#[cfg(test)] +mod test { + use super::*; + + /// Construct a deterministic RNG with the given seed + pub fn rng(seed: u64) -> impl RngCore { + // For tests, we want a statistically good, fast, reproducible RNG. + // PCG32 will do fine, and will be easy to embed if we ever need to. + const INC: u64 = 11634580027462260723; + rand_pcg::Pcg32::new(seed, INC) + } + + #[test] + #[cfg(all(feature = "std", feature = "std_rng"))] + fn test_random() { + let _n: usize = random(); + let _f: f32 = random(); + let _o: Option<Option<i8>> = random(); + #[allow(clippy::type_complexity)] + let _many: ( + (), + (usize, isize, Option<(u32, (bool,))>), + (u8, i8, u16, i16, u32, i32, u64, i64), + (f32, (f64, (f64,))), + ) = random(); + } +} diff --git a/third_party/rust/rand/src/prelude.rs b/third_party/rust/rand/src/prelude.rs new file mode 100644 index 0000000000..51c457e3f9 --- /dev/null +++ b/third_party/rust/rand/src/prelude.rs @@ -0,0 +1,34 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Convenience re-export of common members +//! +//! Like the standard library's prelude, this module simplifies importing of +//! common items. Unlike the standard prelude, the contents of this module must +//! be imported manually: +//! +//! ``` +//! use rand::prelude::*; +//! # let mut r = StdRng::from_rng(thread_rng()).unwrap(); +//! # let _: f32 = r.gen(); +//! ``` + +#[doc(no_inline)] pub use crate::distributions::Distribution; +#[cfg(feature = "small_rng")] +#[doc(no_inline)] +pub use crate::rngs::SmallRng; +#[cfg(feature = "std_rng")] +#[doc(no_inline)] pub use crate::rngs::StdRng; +#[doc(no_inline)] +#[cfg(all(feature = "std", feature = "std_rng"))] +pub use crate::rngs::ThreadRng; +#[doc(no_inline)] pub use crate::seq::{IteratorRandom, SliceRandom}; +#[doc(no_inline)] +#[cfg(all(feature = "std", feature = "std_rng"))] +pub use crate::{random, thread_rng}; +#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; diff --git a/third_party/rust/rand/src/rng.rs b/third_party/rust/rand/src/rng.rs new file mode 100644 index 0000000000..79a9fbff46 --- /dev/null +++ b/third_party/rust/rand/src/rng.rs @@ -0,0 +1,600 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! [`Rng`] trait + +use rand_core::{Error, RngCore}; +use crate::distributions::uniform::{SampleRange, SampleUniform}; +use crate::distributions::{self, Distribution, Standard}; +use core::num::Wrapping; +use core::{mem, slice}; + +/// An automatically-implemented extension trait on [`RngCore`] providing high-level +/// generic methods for sampling values and other convenience methods. +/// +/// This is the primary trait to use when generating random values. +/// +/// # Generic usage +/// +/// The basic pattern is `fn foo<R: Rng + ?Sized>(rng: &mut R)`. Some +/// things are worth noting here: +/// +/// - Since `Rng: RngCore` and every `RngCore` implements `Rng`, it makes no +/// difference whether we use `R: Rng` or `R: RngCore`. +/// - The `+ ?Sized` un-bounding allows functions to be called directly on +/// type-erased references; i.e. `foo(r)` where `r: &mut dyn RngCore`. Without +/// this it would be necessary to write `foo(&mut r)`. +/// +/// An alternative pattern is possible: `fn foo<R: Rng>(rng: R)`. This has some +/// trade-offs. It allows the argument to be consumed directly without a `&mut` +/// (which is how `from_rng(thread_rng())` works); also it still works directly +/// on references (including type-erased references). Unfortunately within the +/// function `foo` it is not known whether `rng` is a reference type or not, +/// hence many uses of `rng` require an extra reference, either explicitly +/// (`distr.sample(&mut rng)`) or implicitly (`rng.gen()`); one may hope the +/// optimiser can remove redundant references later. +/// +/// Example: +/// +/// ``` +/// # use rand::thread_rng; +/// use rand::Rng; +/// +/// fn foo<R: Rng + ?Sized>(rng: &mut R) -> f32 { +/// rng.gen() +/// } +/// +/// # let v = foo(&mut thread_rng()); +/// ``` +pub trait Rng: RngCore { + /// Return a random value supporting the [`Standard`] distribution. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// let x: u32 = rng.gen(); + /// println!("{}", x); + /// println!("{:?}", rng.gen::<(f64, bool)>()); + /// ``` + /// + /// # Arrays and tuples + /// + /// The `rng.gen()` method is able to generate arrays (up to 32 elements) + /// and tuples (up to 12 elements), so long as all element types can be + /// generated. + /// When using `rustc` โฅ 1.51, enable the `min_const_gen` feature to support + /// arrays larger than 32 elements. + /// + /// For arrays of integers, especially for those with small element types + /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`]. + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// let tuple: (u8, i32, char) = rng.gen(); // arbitrary tuple support + /// + /// let arr1: [f32; 32] = rng.gen(); // array construction + /// let mut arr2 = [0u8; 128]; + /// rng.fill(&mut arr2); // array fill + /// ``` + /// + /// [`Standard`]: distributions::Standard + #[inline] + fn gen<T>(&mut self) -> T + where Standard: Distribution<T> { + Standard.sample(self) + } + + /// Generate a random value in the given range. + /// + /// This function is optimised for the case that only a single sample is + /// made from the given range. See also the [`Uniform`] distribution + /// type which may be faster if sampling from the same range repeatedly. + /// + /// Only `gen_range(low..high)` and `gen_range(low..=high)` are supported. + /// + /// # Panics + /// + /// Panics if the range is empty. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// + /// // Exclusive range + /// let n: u32 = rng.gen_range(0..10); + /// println!("{}", n); + /// let m: f64 = rng.gen_range(-40.0..1.3e5); + /// println!("{}", m); + /// + /// // Inclusive range + /// let n: u32 = rng.gen_range(0..=10); + /// println!("{}", n); + /// ``` + /// + /// [`Uniform`]: distributions::uniform::Uniform + fn gen_range<T, R>(&mut self, range: R) -> T + where + T: SampleUniform, + R: SampleRange<T> + { + assert!(!range.is_empty(), "cannot sample empty range"); + range.sample_single(self) + } + + /// Sample a new value, using the given distribution. + /// + /// ### Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// use rand::distributions::Uniform; + /// + /// let mut rng = thread_rng(); + /// let x = rng.sample(Uniform::new(10u32, 15)); + /// // Type annotation requires two types, the type and distribution; the + /// // distribution can be inferred. + /// let y = rng.sample::<u16, _>(Uniform::new(10, 15)); + /// ``` + fn sample<T, D: Distribution<T>>(&mut self, distr: D) -> T { + distr.sample(self) + } + + /// Create an iterator that generates values using the given distribution. + /// + /// Note that this function takes its arguments by value. This works since + /// `(&mut R): Rng where R: Rng` and + /// `(&D): Distribution where D: Distribution`, + /// however borrowing is not automatic hence `rng.sample_iter(...)` may + /// need to be replaced with `(&mut rng).sample_iter(...)`. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// use rand::distributions::{Alphanumeric, Uniform, Standard}; + /// + /// let mut rng = thread_rng(); + /// + /// // Vec of 16 x f32: + /// let v: Vec<f32> = (&mut rng).sample_iter(Standard).take(16).collect(); + /// + /// // String: + /// let s: String = (&mut rng).sample_iter(Alphanumeric) + /// .take(7) + /// .map(char::from) + /// .collect(); + /// + /// // Combined values + /// println!("{:?}", (&mut rng).sample_iter(Standard).take(5) + /// .collect::<Vec<(f64, bool)>>()); + /// + /// // Dice-rolling: + /// let die_range = Uniform::new_inclusive(1, 6); + /// let mut roll_die = (&mut rng).sample_iter(die_range); + /// while roll_die.next().unwrap() != 6 { + /// println!("Not a 6; rolling again!"); + /// } + /// ``` + fn sample_iter<T, D>(self, distr: D) -> distributions::DistIter<D, Self, T> + where + D: Distribution<T>, + Self: Sized, + { + distr.sample_iter(self) + } + + /// Fill any type implementing [`Fill`] with random data + /// + /// The distribution is expected to be uniform with portable results, but + /// this cannot be guaranteed for third-party implementations. + /// + /// This is identical to [`try_fill`] except that it panics on error. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut arr = [0i8; 20]; + /// thread_rng().fill(&mut arr[..]); + /// ``` + /// + /// [`fill_bytes`]: RngCore::fill_bytes + /// [`try_fill`]: Rng::try_fill + fn fill<T: Fill + ?Sized>(&mut self, dest: &mut T) { + dest.try_fill(self).unwrap_or_else(|_| panic!("Rng::fill failed")) + } + + /// Fill any type implementing [`Fill`] with random data + /// + /// The distribution is expected to be uniform with portable results, but + /// this cannot be guaranteed for third-party implementations. + /// + /// This is identical to [`fill`] except that it forwards errors. + /// + /// # Example + /// + /// ``` + /// # use rand::Error; + /// use rand::{thread_rng, Rng}; + /// + /// # fn try_inner() -> Result<(), Error> { + /// let mut arr = [0u64; 4]; + /// thread_rng().try_fill(&mut arr[..])?; + /// # Ok(()) + /// # } + /// + /// # try_inner().unwrap() + /// ``` + /// + /// [`try_fill_bytes`]: RngCore::try_fill_bytes + /// [`fill`]: Rng::fill + fn try_fill<T: Fill + ?Sized>(&mut self, dest: &mut T) -> Result<(), Error> { + dest.try_fill(self) + } + + /// Return a bool with a probability `p` of being true. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same probability repeatedly. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// println!("{}", rng.gen_bool(1.0 / 3.0)); + /// ``` + /// + /// # Panics + /// + /// If `p < 0` or `p > 1`. + /// + /// [`Bernoulli`]: distributions::Bernoulli + #[inline] + fn gen_bool(&mut self, p: f64) -> bool { + let d = distributions::Bernoulli::new(p).unwrap(); + self.sample(d) + } + + /// Return a bool with a probability of `numerator/denominator` of being + /// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of + /// returning true. If `numerator == denominator`, then the returned value + /// is guaranteed to be `true`. If `numerator == 0`, then the returned + /// value is guaranteed to be `false`. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same `numerator` and `denominator` repeatedly. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// println!("{}", rng.gen_ratio(2, 3)); + /// ``` + /// + /// [`Bernoulli`]: distributions::Bernoulli + #[inline] + fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool { + let d = distributions::Bernoulli::from_ratio(numerator, denominator).unwrap(); + self.sample(d) + } +} + +impl<R: RngCore + ?Sized> Rng for R {} + +/// Types which may be filled with random data +/// +/// This trait allows arrays to be efficiently filled with random data. +/// +/// Implementations are expected to be portable across machines unless +/// clearly documented otherwise (see the +/// [Chapter on Portability](https://rust-random.github.io/book/portability.html)). +pub trait Fill { + /// Fill self with random data + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error>; +} + +macro_rules! impl_fill_each { + () => {}; + ($t:ty) => { + impl Fill for [$t] { + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + for elt in self.iter_mut() { + *elt = rng.gen(); + } + Ok(()) + } + } + }; + ($t:ty, $($tt:ty,)*) => { + impl_fill_each!($t); + impl_fill_each!($($tt,)*); + }; +} + +impl_fill_each!(bool, char, f32, f64,); + +impl Fill for [u8] { + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + rng.try_fill_bytes(self) + } +} + +macro_rules! impl_fill { + () => {}; + ($t:ty) => { + impl Fill for [$t] { + #[inline(never)] // in micro benchmarks, this improves performance + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + if self.len() > 0 { + rng.try_fill_bytes(unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + self.len() * mem::size_of::<$t>() + ) + })?; + for x in self { + *x = x.to_le(); + } + } + Ok(()) + } + } + + impl Fill for [Wrapping<$t>] { + #[inline(never)] + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + if self.len() > 0 { + rng.try_fill_bytes(unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + self.len() * mem::size_of::<$t>() + ) + })?; + for x in self { + *x = Wrapping(x.0.to_le()); + } + } + Ok(()) + } + } + }; + ($t:ty, $($tt:ty,)*) => { + impl_fill!($t); + // TODO: this could replace above impl once Rust #32463 is fixed + // impl_fill!(Wrapping<$t>); + impl_fill!($($tt,)*); + } +} + +impl_fill!(u16, u32, u64, usize, u128,); +impl_fill!(i8, i16, i32, i64, isize, i128,); + +#[cfg_attr(doc_cfg, doc(cfg(feature = "min_const_gen")))] +#[cfg(feature = "min_const_gen")] +impl<T, const N: usize> Fill for [T; N] +where [T]: Fill +{ + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + self[..].try_fill(rng) + } +} + +#[cfg(not(feature = "min_const_gen"))] +macro_rules! impl_fill_arrays { + ($n:expr,) => {}; + ($n:expr, $N:ident) => { + impl<T> Fill for [T; $n] where [T]: Fill { + fn try_fill<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Result<(), Error> { + self[..].try_fill(rng) + } + } + }; + ($n:expr, $N:ident, $($NN:ident,)*) => { + impl_fill_arrays!($n, $N); + impl_fill_arrays!($n - 1, $($NN,)*); + }; + (!div $n:expr,) => {}; + (!div $n:expr, $N:ident, $($NN:ident,)*) => { + impl_fill_arrays!($n, $N); + impl_fill_arrays!(!div $n / 2, $($NN,)*); + }; +} +#[cfg(not(feature = "min_const_gen"))] +#[rustfmt::skip] +impl_fill_arrays!(32, N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,); +#[cfg(not(feature = "min_const_gen"))] +impl_fill_arrays!(!div 4096, N,N,N,N,N,N,N,); + +#[cfg(test)] +mod test { + use super::*; + use crate::test::rng; + use crate::rngs::mock::StepRng; + #[cfg(feature = "alloc")] use alloc::boxed::Box; + + #[test] + fn test_fill_bytes_default() { + let mut r = StepRng::new(0x11_22_33_44_55_66_77_88, 0); + + // check every remainder mod 8, both in small and big vectors. + let lengths = [0, 1, 2, 3, 4, 5, 6, 7, 80, 81, 82, 83, 84, 85, 86, 87]; + for &n in lengths.iter() { + let mut buffer = [0u8; 87]; + let v = &mut buffer[0..n]; + r.fill_bytes(v); + + // use this to get nicer error messages. + for (i, &byte) in v.iter().enumerate() { + if byte == 0 { + panic!("byte {} of {} is zero", i, n) + } + } + } + } + + #[test] + fn test_fill() { + let x = 9041086907909331047; // a random u64 + let mut rng = StepRng::new(x, 0); + + // Convert to byte sequence and back to u64; byte-swap twice if BE. + let mut array = [0u64; 2]; + rng.fill(&mut array[..]); + assert_eq!(array, [x, x]); + assert_eq!(rng.next_u64(), x); + + // Convert to bytes then u32 in LE order + let mut array = [0u32; 2]; + rng.fill(&mut array[..]); + assert_eq!(array, [x as u32, (x >> 32) as u32]); + assert_eq!(rng.next_u32(), x as u32); + + // Check equivalence using wrapped arrays + let mut warray = [Wrapping(0u32); 2]; + rng.fill(&mut warray[..]); + assert_eq!(array[0], warray[0].0); + assert_eq!(array[1], warray[1].0); + + // Check equivalence for generated floats + let mut array = [0f32; 2]; + rng.fill(&mut array); + let gen: [f32; 2] = rng.gen(); + assert_eq!(array, gen); + } + + #[test] + fn test_fill_empty() { + let mut array = [0u32; 0]; + let mut rng = StepRng::new(0, 1); + rng.fill(&mut array); + rng.fill(&mut array[..]); + } + + #[test] + fn test_gen_range_int() { + let mut r = rng(101); + for _ in 0..1000 { + let a = r.gen_range(-4711..17); + assert!((-4711..17).contains(&a)); + let a: i8 = r.gen_range(-3..42); + assert!((-3..42).contains(&a)); + let a: u16 = r.gen_range(10..99); + assert!((10..99).contains(&a)); + let a: i32 = r.gen_range(-100..2000); + assert!((-100..2000).contains(&a)); + let a: u32 = r.gen_range(12..=24); + assert!((12..=24).contains(&a)); + + assert_eq!(r.gen_range(0u32..1), 0u32); + assert_eq!(r.gen_range(-12i64..-11), -12i64); + assert_eq!(r.gen_range(3_000_000..3_000_001), 3_000_000); + } + } + + #[test] + fn test_gen_range_float() { + let mut r = rng(101); + for _ in 0..1000 { + let a = r.gen_range(-4.5..1.7); + assert!((-4.5..1.7).contains(&a)); + let a = r.gen_range(-1.1..=-0.3); + assert!((-1.1..=-0.3).contains(&a)); + + assert_eq!(r.gen_range(0.0f32..=0.0), 0.); + assert_eq!(r.gen_range(-11.0..=-11.0), -11.); + assert_eq!(r.gen_range(3_000_000.0..=3_000_000.0), 3_000_000.); + } + } + + #[test] + #[should_panic] + fn test_gen_range_panic_int() { + #![allow(clippy::reversed_empty_ranges)] + let mut r = rng(102); + r.gen_range(5..-2); + } + + #[test] + #[should_panic] + fn test_gen_range_panic_usize() { + #![allow(clippy::reversed_empty_ranges)] + let mut r = rng(103); + r.gen_range(5..2); + } + + #[test] + fn test_gen_bool() { + #![allow(clippy::bool_assert_comparison)] + + let mut r = rng(105); + for _ in 0..5 { + assert_eq!(r.gen_bool(0.0), false); + assert_eq!(r.gen_bool(1.0), true); + } + } + + #[test] + fn test_rng_trait_object() { + use crate::distributions::{Distribution, Standard}; + let mut rng = rng(109); + let mut r = &mut rng as &mut dyn RngCore; + r.next_u32(); + r.gen::<i32>(); + assert_eq!(r.gen_range(0..1), 0); + let _c: u8 = Standard.sample(&mut r); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_rng_boxed_trait() { + use crate::distributions::{Distribution, Standard}; + let rng = rng(110); + let mut r = Box::new(rng) as Box<dyn RngCore>; + r.next_u32(); + r.gen::<i32>(); + assert_eq!(r.gen_range(0..1), 0); + let _c: u8 = Standard.sample(&mut r); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_gen_ratio_average() { + const NUM: u32 = 3; + const DENOM: u32 = 10; + const N: u32 = 100_000; + + let mut sum: u32 = 0; + let mut rng = rng(111); + for _ in 0..N { + if rng.gen_ratio(NUM, DENOM) { + sum += 1; + } + } + // Have Binomial(N, NUM/DENOM) distribution + let expected = (NUM * N) / DENOM; // exact integer + assert!(((sum - expected) as i32).abs() < 500); + } +} diff --git a/third_party/rust/rand/src/rngs/adapter/mod.rs b/third_party/rust/rand/src/rngs/adapter/mod.rs new file mode 100644 index 0000000000..bd1d294323 --- /dev/null +++ b/third_party/rust/rand/src/rngs/adapter/mod.rs @@ -0,0 +1,16 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Wrappers / adapters forming RNGs + +mod read; +mod reseeding; + +#[allow(deprecated)] +pub use self::read::{ReadError, ReadRng}; +pub use self::reseeding::ReseedingRng; diff --git a/third_party/rust/rand/src/rngs/adapter/read.rs b/third_party/rust/rand/src/rngs/adapter/read.rs new file mode 100644 index 0000000000..25a9ca7fca --- /dev/null +++ b/third_party/rust/rand/src/rngs/adapter/read.rs @@ -0,0 +1,150 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A wrapper around any Read to treat it as an RNG. + +#![allow(deprecated)] + +use std::fmt; +use std::io::Read; + +use rand_core::{impls, Error, RngCore}; + + +/// An RNG that reads random bytes straight from any type supporting +/// [`std::io::Read`], for example files. +/// +/// This will work best with an infinite reader, but that is not required. +/// +/// This can be used with `/dev/urandom` on Unix but it is recommended to use +/// [`OsRng`] instead. +/// +/// # Panics +/// +/// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. +/// All other errors from the underlying reader, including when it does not +/// have enough data, will only be reported through [`try_fill_bytes`]. +/// The other [`RngCore`] methods will panic in case of an error. +/// +/// [`OsRng`]: crate::rngs::OsRng +/// [`try_fill_bytes`]: RngCore::try_fill_bytes +#[derive(Debug)] +#[deprecated(since="0.8.4", note="removal due to lack of usage")] +pub struct ReadRng<R> { + reader: R, +} + +impl<R: Read> ReadRng<R> { + /// Create a new `ReadRng` from a `Read`. + pub fn new(r: R) -> ReadRng<R> { + ReadRng { reader: r } + } +} + +impl<R: Read> RngCore for ReadRng<R> { + fn next_u32(&mut self) -> u32 { + impls::next_u32_via_fill(self) + } + + fn next_u64(&mut self) -> u64 { + impls::next_u64_via_fill(self) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.try_fill_bytes(dest).unwrap_or_else(|err| { + panic!( + "reading random bytes from Read implementation failed; error: {}", + err + ) + }); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + if dest.is_empty() { + return Ok(()); + } + // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. + self.reader + .read_exact(dest) + .map_err(|e| Error::new(ReadError(e))) + } +} + +/// `ReadRng` error type +#[derive(Debug)] +#[deprecated(since="0.8.4")] +pub struct ReadError(std::io::Error); + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ReadError: {}", self.0) + } +} + +impl std::error::Error for ReadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.0) + } +} + + +#[cfg(test)] +mod test { + use std::println; + + use super::ReadRng; + use crate::RngCore; + + #[test] + fn test_reader_rng_u64() { + // transmute from the target to avoid endianness concerns. + #[rustfmt::skip] + let v = [0u8, 0, 0, 0, 0, 0, 0, 1, + 0, 4, 0, 0, 3, 0, 0, 2, + 5, 0, 0, 0, 0, 0, 0, 0]; + let mut rng = ReadRng::new(&v[..]); + + assert_eq!(rng.next_u64(), 1 << 56); + assert_eq!(rng.next_u64(), (2 << 56) + (3 << 32) + (4 << 8)); + assert_eq!(rng.next_u64(), 5); + } + + #[test] + fn test_reader_rng_u32() { + let v = [0u8, 0, 0, 1, 0, 0, 2, 0, 3, 0, 0, 0]; + let mut rng = ReadRng::new(&v[..]); + + assert_eq!(rng.next_u32(), 1 << 24); + assert_eq!(rng.next_u32(), 2 << 16); + assert_eq!(rng.next_u32(), 3); + } + + #[test] + fn test_reader_rng_fill_bytes() { + let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; + let mut w = [0u8; 8]; + + let mut rng = ReadRng::new(&v[..]); + rng.fill_bytes(&mut w); + + assert!(v == w); + } + + #[test] + fn test_reader_rng_insufficient_bytes() { + let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; + let mut w = [0u8; 9]; + + let mut rng = ReadRng::new(&v[..]); + + let result = rng.try_fill_bytes(&mut w); + assert!(result.is_err()); + println!("Error: {}", result.unwrap_err()); + } +} diff --git a/third_party/rust/rand/src/rngs/adapter/reseeding.rs b/third_party/rust/rand/src/rngs/adapter/reseeding.rs new file mode 100644 index 0000000000..ae3fcbb2fc --- /dev/null +++ b/third_party/rust/rand/src/rngs/adapter/reseeding.rs @@ -0,0 +1,386 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A wrapper around another PRNG that reseeds it after it +//! generates a certain number of random bytes. + +use core::mem::size_of; + +use rand_core::block::{BlockRng, BlockRngCore}; +use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; + +/// A wrapper around any PRNG that implements [`BlockRngCore`], that adds the +/// ability to reseed it. +/// +/// `ReseedingRng` reseeds the underlying PRNG in the following cases: +/// +/// - On a manual call to [`reseed()`]. +/// - After `clone()`, the clone will be reseeded on first use. +/// - When a process is forked on UNIX, the RNGs in both the parent and child +/// processes will be reseeded just before the next call to +/// [`BlockRngCore::generate`], i.e. "soon". For ChaCha and Hc128 this is a +/// maximum of fifteen `u32` values before reseeding. +/// - After the PRNG has generated a configurable number of random bytes. +/// +/// # When should reseeding after a fixed number of generated bytes be used? +/// +/// Reseeding after a fixed number of generated bytes is never strictly +/// *necessary*. Cryptographic PRNGs don't have a limited number of bytes they +/// can output, or at least not a limit reachable in any practical way. There is +/// no such thing as 'running out of entropy'. +/// +/// Occasionally reseeding can be seen as some form of 'security in depth'. Even +/// if in the future a cryptographic weakness is found in the CSPRNG being used, +/// or a flaw in the implementation, occasionally reseeding should make +/// exploiting it much more difficult or even impossible. +/// +/// Use [`ReseedingRng::new`] with a `threshold` of `0` to disable reseeding +/// after a fixed number of generated bytes. +/// +/// # Limitations +/// +/// It is recommended that a `ReseedingRng` (including `ThreadRng`) not be used +/// from a fork handler. +/// Use `OsRng` or `getrandom`, or defer your use of the RNG until later. +/// +/// # Error handling +/// +/// Although unlikely, reseeding the wrapped PRNG can fail. `ReseedingRng` will +/// never panic but try to handle the error intelligently through some +/// combination of retrying and delaying reseeding until later. +/// If handling the source error fails `ReseedingRng` will continue generating +/// data from the wrapped PRNG without reseeding. +/// +/// Manually calling [`reseed()`] will not have this retry or delay logic, but +/// reports the error. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand_chacha::ChaCha20Core; // Internal part of ChaChaRng that +/// // implements BlockRngCore +/// use rand::rngs::OsRng; +/// use rand::rngs::adapter::ReseedingRng; +/// +/// let prng = ChaCha20Core::from_entropy(); +/// let mut reseeding_rng = ReseedingRng::new(prng, 0, OsRng); +/// +/// println!("{}", reseeding_rng.gen::<u64>()); +/// +/// let mut cloned_rng = reseeding_rng.clone(); +/// assert!(reseeding_rng.gen::<u64>() != cloned_rng.gen::<u64>()); +/// ``` +/// +/// [`BlockRngCore`]: rand_core::block::BlockRngCore +/// [`ReseedingRng::new`]: ReseedingRng::new +/// [`reseed()`]: ReseedingRng::reseed +#[derive(Debug)] +pub struct ReseedingRng<R, Rsdr>(BlockRng<ReseedingCore<R, Rsdr>>) +where + R: BlockRngCore + SeedableRng, + Rsdr: RngCore; + +impl<R, Rsdr> ReseedingRng<R, Rsdr> +where + R: BlockRngCore + SeedableRng, + Rsdr: RngCore, +{ + /// Create a new `ReseedingRng` from an existing PRNG, combined with a RNG + /// to use as reseeder. + /// + /// `threshold` sets the number of generated bytes after which to reseed the + /// PRNG. Set it to zero to never reseed based on the number of generated + /// values. + pub fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { + ReseedingRng(BlockRng::new(ReseedingCore::new(rng, threshold, reseeder))) + } + + /// Reseed the internal PRNG. + pub fn reseed(&mut self) -> Result<(), Error> { + self.0.core.reseed() + } +} + +// TODO: this should be implemented for any type where the inner type +// implements RngCore, but we can't specify that because ReseedingCore is private +impl<R, Rsdr: RngCore> RngCore for ReseedingRng<R, Rsdr> +where + R: BlockRngCore<Item = u32> + SeedableRng, + <R as BlockRngCore>::Results: AsRef<[u32]> + AsMut<[u32]>, +{ + #[inline(always)] + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + + #[inline(always)] + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.0.try_fill_bytes(dest) + } +} + +impl<R, Rsdr> Clone for ReseedingRng<R, Rsdr> +where + R: BlockRngCore + SeedableRng + Clone, + Rsdr: RngCore + Clone, +{ + fn clone(&self) -> ReseedingRng<R, Rsdr> { + // Recreating `BlockRng` seems easier than cloning it and resetting + // the index. + ReseedingRng(BlockRng::new(self.0.core.clone())) + } +} + +impl<R, Rsdr> CryptoRng for ReseedingRng<R, Rsdr> +where + R: BlockRngCore + SeedableRng + CryptoRng, + Rsdr: RngCore + CryptoRng, +{ +} + +#[derive(Debug)] +struct ReseedingCore<R, Rsdr> { + inner: R, + reseeder: Rsdr, + threshold: i64, + bytes_until_reseed: i64, + fork_counter: usize, +} + +impl<R, Rsdr> BlockRngCore for ReseedingCore<R, Rsdr> +where + R: BlockRngCore + SeedableRng, + Rsdr: RngCore, +{ + type Item = <R as BlockRngCore>::Item; + type Results = <R as BlockRngCore>::Results; + + fn generate(&mut self, results: &mut Self::Results) { + let global_fork_counter = fork::get_fork_counter(); + if self.bytes_until_reseed <= 0 || self.is_forked(global_fork_counter) { + // We get better performance by not calling only `reseed` here + // and continuing with the rest of the function, but by directly + // returning from a non-inlined function. + return self.reseed_and_generate(results, global_fork_counter); + } + let num_bytes = results.as_ref().len() * size_of::<Self::Item>(); + self.bytes_until_reseed -= num_bytes as i64; + self.inner.generate(results); + } +} + +impl<R, Rsdr> ReseedingCore<R, Rsdr> +where + R: BlockRngCore + SeedableRng, + Rsdr: RngCore, +{ + /// Create a new `ReseedingCore`. + fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { + use ::core::i64::MAX; + fork::register_fork_handler(); + + // Because generating more values than `i64::MAX` takes centuries on + // current hardware, we just clamp to that value. + // Also we set a threshold of 0, which indicates no limit, to that + // value. + let threshold = if threshold == 0 { + MAX + } else if threshold <= MAX as u64 { + threshold as i64 + } else { + MAX + }; + + ReseedingCore { + inner: rng, + reseeder, + threshold: threshold as i64, + bytes_until_reseed: threshold as i64, + fork_counter: 0, + } + } + + /// Reseed the internal PRNG. + fn reseed(&mut self) -> Result<(), Error> { + R::from_rng(&mut self.reseeder).map(|result| { + self.bytes_until_reseed = self.threshold; + self.inner = result + }) + } + + fn is_forked(&self, global_fork_counter: usize) -> bool { + // In theory, on 32-bit platforms, it is possible for + // `global_fork_counter` to wrap around after ~4e9 forks. + // + // This check will detect a fork in the normal case where + // `fork_counter < global_fork_counter`, and also when the difference + // between both is greater than `isize::MAX` (wrapped around). + // + // It will still fail to detect a fork if there have been more than + // `isize::MAX` forks, without any reseed in between. Seems unlikely + // enough. + (self.fork_counter.wrapping_sub(global_fork_counter) as isize) < 0 + } + + #[inline(never)] + fn reseed_and_generate( + &mut self, results: &mut <Self as BlockRngCore>::Results, global_fork_counter: usize, + ) { + #![allow(clippy::if_same_then_else)] // false positive + if self.is_forked(global_fork_counter) { + info!("Fork detected, reseeding RNG"); + } else { + trace!("Reseeding RNG (periodic reseed)"); + } + + let num_bytes = results.as_ref().len() * size_of::<<R as BlockRngCore>::Item>(); + + if let Err(e) = self.reseed() { + warn!("Reseeding RNG failed: {}", e); + let _ = e; + } + self.fork_counter = global_fork_counter; + + self.bytes_until_reseed = self.threshold - num_bytes as i64; + self.inner.generate(results); + } +} + +impl<R, Rsdr> Clone for ReseedingCore<R, Rsdr> +where + R: BlockRngCore + SeedableRng + Clone, + Rsdr: RngCore + Clone, +{ + fn clone(&self) -> ReseedingCore<R, Rsdr> { + ReseedingCore { + inner: self.inner.clone(), + reseeder: self.reseeder.clone(), + threshold: self.threshold, + bytes_until_reseed: 0, // reseed clone on first use + fork_counter: self.fork_counter, + } + } +} + +impl<R, Rsdr> CryptoRng for ReseedingCore<R, Rsdr> +where + R: BlockRngCore + SeedableRng + CryptoRng, + Rsdr: RngCore + CryptoRng, +{ +} + + +#[cfg(all(unix, not(target_os = "emscripten")))] +mod fork { + use core::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Once; + + // Fork protection + // + // We implement fork protection on Unix using `pthread_atfork`. + // When the process is forked, we increment `RESEEDING_RNG_FORK_COUNTER`. + // Every `ReseedingRng` stores the last known value of the static in + // `fork_counter`. If the cached `fork_counter` is less than + // `RESEEDING_RNG_FORK_COUNTER`, it is time to reseed this RNG. + // + // If reseeding fails, we don't deal with this by setting a delay, but just + // don't update `fork_counter`, so a reseed is attempted as soon as + // possible. + + static RESEEDING_RNG_FORK_COUNTER: AtomicUsize = AtomicUsize::new(0); + + pub fn get_fork_counter() -> usize { + RESEEDING_RNG_FORK_COUNTER.load(Ordering::Relaxed) + } + + extern "C" fn fork_handler() { + // Note: fetch_add is defined to wrap on overflow + // (which is what we want). + RESEEDING_RNG_FORK_COUNTER.fetch_add(1, Ordering::Relaxed); + } + + pub fn register_fork_handler() { + static REGISTER: Once = Once::new(); + REGISTER.call_once(|| { + // Bump the counter before and after forking (see #1169): + let ret = unsafe { libc::pthread_atfork( + Some(fork_handler), + Some(fork_handler), + Some(fork_handler), + ) }; + if ret != 0 { + panic!("libc::pthread_atfork failed with code {}", ret); + } + }); + } +} + +#[cfg(not(all(unix, not(target_os = "emscripten"))))] +mod fork { + pub fn get_fork_counter() -> usize { + 0 + } + pub fn register_fork_handler() {} +} + + +#[cfg(feature = "std_rng")] +#[cfg(test)] +mod test { + use super::ReseedingRng; + use crate::rngs::mock::StepRng; + use crate::rngs::std::Core; + use crate::{Rng, SeedableRng}; + + #[test] + fn test_reseeding() { + let mut zero = StepRng::new(0, 0); + let rng = Core::from_rng(&mut zero).unwrap(); + let thresh = 1; // reseed every time the buffer is exhausted + let mut reseeding = ReseedingRng::new(rng, thresh, zero); + + // RNG buffer size is [u32; 64] + // Debug is only implemented up to length 32 so use two arrays + let mut buf = ([0u32; 32], [0u32; 32]); + reseeding.fill(&mut buf.0); + reseeding.fill(&mut buf.1); + let seq = buf; + for _ in 0..10 { + reseeding.fill(&mut buf.0); + reseeding.fill(&mut buf.1); + assert_eq!(buf, seq); + } + } + + #[test] + fn test_clone_reseeding() { + #![allow(clippy::redundant_clone)] + + let mut zero = StepRng::new(0, 0); + let rng = Core::from_rng(&mut zero).unwrap(); + let mut rng1 = ReseedingRng::new(rng, 32 * 4, zero); + + let first: u32 = rng1.gen(); + for _ in 0..10 { + let _ = rng1.gen::<u32>(); + } + + let mut rng2 = rng1.clone(); + assert_eq!(first, rng2.gen::<u32>()); + } +} diff --git a/third_party/rust/rand/src/rngs/mock.rs b/third_party/rust/rand/src/rngs/mock.rs new file mode 100644 index 0000000000..a1745a490d --- /dev/null +++ b/third_party/rust/rand/src/rngs/mock.rs @@ -0,0 +1,87 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Mock random number generator + +use rand_core::{impls, Error, RngCore}; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; + +/// A simple implementation of `RngCore` for testing purposes. +/// +/// This generates an arithmetic sequence (i.e. adds a constant each step) +/// over a `u64` number, using wrapping arithmetic. If the increment is 0 +/// the generator yields a constant. +/// +/// ``` +/// use rand::Rng; +/// use rand::rngs::mock::StepRng; +/// +/// let mut my_rng = StepRng::new(2, 1); +/// let sample: [u64; 3] = my_rng.gen(); +/// assert_eq!(sample, [2, 3, 4]); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct StepRng { + v: u64, + a: u64, +} + +impl StepRng { + /// Create a `StepRng`, yielding an arithmetic sequence starting with + /// `initial` and incremented by `increment` each time. + pub fn new(initial: u64, increment: u64) -> Self { + StepRng { + v: initial, + a: increment, + } + } +} + +impl RngCore for StepRng { + #[inline] + fn next_u32(&mut self) -> u32 { + self.next_u64() as u32 + } + + #[inline] + fn next_u64(&mut self) -> u64 { + let result = self.v; + self.v = self.v.wrapping_add(self.a); + result + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + impls::fill_bytes_via_next(self, dest); + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + #[test] + #[cfg(feature = "serde1")] + fn test_serialization_step_rng() { + use super::StepRng; + + let some_rng = StepRng::new(42, 7); + let de_some_rng: StepRng = + bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); + assert_eq!(some_rng.v, de_some_rng.v); + assert_eq!(some_rng.a, de_some_rng.a); + + } +} diff --git a/third_party/rust/rand/src/rngs/mod.rs b/third_party/rust/rand/src/rngs/mod.rs new file mode 100644 index 0000000000..ac3c2c595d --- /dev/null +++ b/third_party/rust/rand/src/rngs/mod.rs @@ -0,0 +1,119 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Random number generators and adapters +//! +//! ## Background: Random number generators (RNGs) +//! +//! Computers cannot produce random numbers from nowhere. We classify +//! random number generators as follows: +//! +//! - "True" random number generators (TRNGs) use hard-to-predict data sources +//! (e.g. the high-resolution parts of event timings and sensor jitter) to +//! harvest random bit-sequences, apply algorithms to remove bias and +//! estimate available entropy, then combine these bits into a byte-sequence +//! or an entropy pool. This job is usually done by the operating system or +//! a hardware generator (HRNG). +//! - "Pseudo"-random number generators (PRNGs) use algorithms to transform a +//! seed into a sequence of pseudo-random numbers. These generators can be +//! fast and produce well-distributed unpredictable random numbers (or not). +//! They are usually deterministic: given algorithm and seed, the output +//! sequence can be reproduced. They have finite period and eventually loop; +//! with many algorithms this period is fixed and can be proven sufficiently +//! long, while others are chaotic and the period depends on the seed. +//! - "Cryptographically secure" pseudo-random number generators (CSPRNGs) +//! are the sub-set of PRNGs which are secure. Security of the generator +//! relies both on hiding the internal state and using a strong algorithm. +//! +//! ## Traits and functionality +//! +//! All RNGs implement the [`RngCore`] trait, as a consequence of which the +//! [`Rng`] extension trait is automatically implemented. Secure RNGs may +//! additionally implement the [`CryptoRng`] trait. +//! +//! All PRNGs require a seed to produce their random number sequence. The +//! [`SeedableRng`] trait provides three ways of constructing PRNGs: +//! +//! - `from_seed` accepts a type specific to the PRNG +//! - `from_rng` allows a PRNG to be seeded from any other RNG +//! - `seed_from_u64` allows any PRNG to be seeded from a `u64` insecurely +//! - `from_entropy` securely seeds a PRNG from fresh entropy +//! +//! Use the [`rand_core`] crate when implementing your own RNGs. +//! +//! ## Our generators +//! +//! This crate provides several random number generators: +//! +//! - [`OsRng`] is an interface to the operating system's random number +//! source. Typically the operating system uses a CSPRNG with entropy +//! provided by a TRNG and some type of on-going re-seeding. +//! - [`ThreadRng`], provided by the [`thread_rng`] function, is a handle to a +//! thread-local CSPRNG with periodic seeding from [`OsRng`]. Because this +//! is local, it is typically much faster than [`OsRng`]. It should be +//! secure, though the paranoid may prefer [`OsRng`]. +//! - [`StdRng`] is a CSPRNG chosen for good performance and trust of security +//! (based on reviews, maturity and usage). The current algorithm is ChaCha12, +//! which is well established and rigorously analysed. +//! [`StdRng`] provides the algorithm used by [`ThreadRng`] but without +//! periodic reseeding. +//! - [`SmallRng`] is an **insecure** PRNG designed to be fast, simple, require +//! little memory, and have good output quality. +//! +//! The algorithms selected for [`StdRng`] and [`SmallRng`] may change in any +//! release and may be platform-dependent, therefore they should be considered +//! **not reproducible**. +//! +//! ## Additional generators +//! +//! **TRNGs**: The [`rdrand`] crate provides an interface to the RDRAND and +//! RDSEED instructions available in modern Intel and AMD CPUs. +//! The [`rand_jitter`] crate provides a user-space implementation of +//! entropy harvesting from CPU timer jitter, but is very slow and has +//! [security issues](https://github.com/rust-random/rand/issues/699). +//! +//! **PRNGs**: Several companion crates are available, providing individual or +//! families of PRNG algorithms. These provide the implementations behind +//! [`StdRng`] and [`SmallRng`] but can also be used directly, indeed *should* +//! be used directly when **reproducibility** matters. +//! Some suggestions are: [`rand_chacha`], [`rand_pcg`], [`rand_xoshiro`]. +//! A full list can be found by searching for crates with the [`rng` tag]. +//! +//! [`Rng`]: crate::Rng +//! [`RngCore`]: crate::RngCore +//! [`CryptoRng`]: crate::CryptoRng +//! [`SeedableRng`]: crate::SeedableRng +//! [`thread_rng`]: crate::thread_rng +//! [`rdrand`]: https://crates.io/crates/rdrand +//! [`rand_jitter`]: https://crates.io/crates/rand_jitter +//! [`rand_chacha`]: https://crates.io/crates/rand_chacha +//! [`rand_pcg`]: https://crates.io/crates/rand_pcg +//! [`rand_xoshiro`]: https://crates.io/crates/rand_xoshiro +//! [`rng` tag]: https://crates.io/keywords/rng + +#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] +#[cfg(feature = "std")] pub mod adapter; + +pub mod mock; // Public so we don't export `StepRng` directly, making it a bit + // more clear it is intended for testing. + +#[cfg(all(feature = "small_rng", target_pointer_width = "64"))] +mod xoshiro256plusplus; +#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] +mod xoshiro128plusplus; +#[cfg(feature = "small_rng")] mod small; + +#[cfg(feature = "std_rng")] mod std; +#[cfg(all(feature = "std", feature = "std_rng"))] pub(crate) mod thread; + +#[cfg(feature = "small_rng")] pub use self::small::SmallRng; +#[cfg(feature = "std_rng")] pub use self::std::StdRng; +#[cfg(all(feature = "std", feature = "std_rng"))] pub use self::thread::ThreadRng; + +#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] +#[cfg(feature = "getrandom")] pub use rand_core::OsRng; diff --git a/third_party/rust/rand/src/rngs/small.rs b/third_party/rust/rand/src/rngs/small.rs new file mode 100644 index 0000000000..fb0e0d119b --- /dev/null +++ b/third_party/rust/rand/src/rngs/small.rs @@ -0,0 +1,117 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A small fast RNG + +use rand_core::{Error, RngCore, SeedableRng}; + +#[cfg(target_pointer_width = "64")] +type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus; +#[cfg(not(target_pointer_width = "64"))] +type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; + +/// A small-state, fast non-crypto PRNG +/// +/// `SmallRng` may be a good choice when a PRNG with small state, cheap +/// initialization, good statistical quality and good performance are required. +/// Note that depending on the application, [`StdRng`] may be faster on many +/// modern platforms while providing higher-quality randomness. Furthermore, +/// `SmallRng` is **not** a good choice when: +/// - Security against prediction is important. Use [`StdRng`] instead. +/// - Seeds with many zeros are provided. In such cases, it takes `SmallRng` +/// about 10 samples to produce 0 and 1 bits with equal probability. Either +/// provide seeds with an approximately equal number of 0 and 1 (for example +/// by using [`SeedableRng::from_entropy`] or [`SeedableRng::seed_from_u64`]), +/// or use [`StdRng`] instead. +/// +/// The algorithm is deterministic but should not be considered reproducible +/// due to dependence on platform and possible replacement in future +/// library versions. For a reproducible generator, use a named PRNG from an +/// external crate, e.g. [rand_xoshiro] or [rand_chacha]. +/// Refer also to [The Book](https://rust-random.github.io/book/guide-rngs.html). +/// +/// The PRNG algorithm in `SmallRng` is chosen to be efficient on the current +/// platform, without consideration for cryptography or security. The size of +/// its state is much smaller than [`StdRng`]. The current algorithm is +/// `Xoshiro256PlusPlus` on 64-bit platforms and `Xoshiro128PlusPlus` on 32-bit +/// platforms. Both are also implemented by the [rand_xoshiro] crate. +/// +/// # Examples +/// +/// Initializing `SmallRng` with a random seed can be done using [`SeedableRng::from_entropy`]: +/// +/// ``` +/// use rand::{Rng, SeedableRng}; +/// use rand::rngs::SmallRng; +/// +/// // Create small, cheap to initialize and fast RNG with a random seed. +/// // The randomness is supplied by the operating system. +/// let mut small_rng = SmallRng::from_entropy(); +/// # let v: u32 = small_rng.gen(); +/// ``` +/// +/// When initializing a lot of `SmallRng`'s, using [`thread_rng`] can be more +/// efficient: +/// +/// ``` +/// use rand::{SeedableRng, thread_rng}; +/// use rand::rngs::SmallRng; +/// +/// // Create a big, expensive to initialize and slower, but unpredictable RNG. +/// // This is cached and done only once per thread. +/// let mut thread_rng = thread_rng(); +/// // Create small, cheap to initialize and fast RNGs with random seeds. +/// // One can generally assume this won't fail. +/// let rngs: Vec<SmallRng> = (0..10) +/// .map(|_| SmallRng::from_rng(&mut thread_rng).unwrap()) +/// .collect(); +/// ``` +/// +/// [`StdRng`]: crate::rngs::StdRng +/// [`thread_rng`]: crate::thread_rng +/// [rand_chacha]: https://crates.io/crates/rand_chacha +/// [rand_xoshiro]: https://crates.io/crates/rand_xoshiro +#[cfg_attr(doc_cfg, doc(cfg(feature = "small_rng")))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SmallRng(Rng); + +impl RngCore for SmallRng { + #[inline(always)] + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + + #[inline(always)] + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + + #[inline(always)] + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest); + } + + #[inline(always)] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.0.try_fill_bytes(dest) + } +} + +impl SeedableRng for SmallRng { + type Seed = <Rng as SeedableRng>::Seed; + + #[inline(always)] + fn from_seed(seed: Self::Seed) -> Self { + SmallRng(Rng::from_seed(seed)) + } + + #[inline(always)] + fn from_rng<R: RngCore>(rng: R) -> Result<Self, Error> { + Rng::from_rng(rng).map(SmallRng) + } +} diff --git a/third_party/rust/rand/src/rngs/std.rs b/third_party/rust/rand/src/rngs/std.rs new file mode 100644 index 0000000000..cdae8fab01 --- /dev/null +++ b/third_party/rust/rand/src/rngs/std.rs @@ -0,0 +1,98 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The standard RNG + +use crate::{CryptoRng, Error, RngCore, SeedableRng}; + +pub(crate) use rand_chacha::ChaCha12Core as Core; + +use rand_chacha::ChaCha12Rng as Rng; + +/// The standard RNG. The PRNG algorithm in `StdRng` is chosen to be efficient +/// on the current platform, to be statistically strong and unpredictable +/// (meaning a cryptographically secure PRNG). +/// +/// The current algorithm used is the ChaCha block cipher with 12 rounds. Please +/// see this relevant [rand issue] for the discussion. This may change as new +/// evidence of cipher security and performance becomes available. +/// +/// The algorithm is deterministic but should not be considered reproducible +/// due to dependence on configuration and possible replacement in future +/// library versions. For a secure reproducible generator, we recommend use of +/// the [rand_chacha] crate directly. +/// +/// [rand_chacha]: https://crates.io/crates/rand_chacha +/// [rand issue]: https://github.com/rust-random/rand/issues/932 +#[cfg_attr(doc_cfg, doc(cfg(feature = "std_rng")))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct StdRng(Rng); + +impl RngCore for StdRng { + #[inline(always)] + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + + #[inline(always)] + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + + #[inline(always)] + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest); + } + + #[inline(always)] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.0.try_fill_bytes(dest) + } +} + +impl SeedableRng for StdRng { + type Seed = <Rng as SeedableRng>::Seed; + + #[inline(always)] + fn from_seed(seed: Self::Seed) -> Self { + StdRng(Rng::from_seed(seed)) + } + + #[inline(always)] + fn from_rng<R: RngCore>(rng: R) -> Result<Self, Error> { + Rng::from_rng(rng).map(StdRng) + } +} + +impl CryptoRng for StdRng {} + + +#[cfg(test)] +mod test { + use crate::rngs::StdRng; + use crate::{RngCore, SeedableRng}; + + #[test] + fn test_stdrng_construction() { + // Test value-stability of StdRng. This is expected to break any time + // the algorithm is changed. + #[rustfmt::skip] + let seed = [1,0,0,0, 23,0,0,0, 200,1,0,0, 210,30,0,0, + 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; + + let target = [10719222850664546238, 14064965282130556830]; + + let mut rng0 = StdRng::from_seed(seed); + let x0 = rng0.next_u64(); + + let mut rng1 = StdRng::from_rng(rng0).unwrap(); + let x1 = rng1.next_u64(); + + assert_eq!([x0, x1], target); + } +} diff --git a/third_party/rust/rand/src/rngs/thread.rs b/third_party/rust/rand/src/rngs/thread.rs new file mode 100644 index 0000000000..baebb1d99c --- /dev/null +++ b/third_party/rust/rand/src/rngs/thread.rs @@ -0,0 +1,143 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Thread-local random number generator + +use core::cell::UnsafeCell; +use std::rc::Rc; +use std::thread_local; + +use super::std::Core; +use crate::rngs::adapter::ReseedingRng; +use crate::rngs::OsRng; +use crate::{CryptoRng, Error, RngCore, SeedableRng}; + +// Rationale for using `UnsafeCell` in `ThreadRng`: +// +// Previously we used a `RefCell`, with an overhead of ~15%. There will only +// ever be one mutable reference to the interior of the `UnsafeCell`, because +// we only have such a reference inside `next_u32`, `next_u64`, etc. Within a +// single thread (which is the definition of `ThreadRng`), there will only ever +// be one of these methods active at a time. +// +// A possible scenario where there could be multiple mutable references is if +// `ThreadRng` is used inside `next_u32` and co. But the implementation is +// completely under our control. We just have to ensure none of them use +// `ThreadRng` internally, which is nonsensical anyway. We should also never run +// `ThreadRng` in destructors of its implementation, which is also nonsensical. + + +// Number of generated bytes after which to reseed `ThreadRng`. +// According to benchmarks, reseeding has a noticeable impact with thresholds +// of 32 kB and less. We choose 64 kB to avoid significant overhead. +const THREAD_RNG_RESEED_THRESHOLD: u64 = 1024 * 64; + +/// A reference to the thread-local generator +/// +/// An instance can be obtained via [`thread_rng`] or via `ThreadRng::default()`. +/// This handle is safe to use everywhere (including thread-local destructors), +/// though it is recommended not to use inside a fork handler. +/// The handle cannot be passed between threads (is not `Send` or `Sync`). +/// +/// `ThreadRng` uses the same PRNG as [`StdRng`] for security and performance +/// and is automatically seeded from [`OsRng`]. +/// +/// Unlike `StdRng`, `ThreadRng` uses the [`ReseedingRng`] wrapper to reseed +/// the PRNG from fresh entropy every 64 kiB of random data as well as after a +/// fork on Unix (though not quite immediately; see documentation of +/// [`ReseedingRng`]). +/// Note that the reseeding is done as an extra precaution against side-channel +/// attacks and mis-use (e.g. if somehow weak entropy were supplied initially). +/// The PRNG algorithms used are assumed to be secure. +/// +/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng +/// [`StdRng`]: crate::rngs::StdRng +#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] +#[derive(Clone, Debug)] +pub struct ThreadRng { + // Rc is explicitly !Send and !Sync + rng: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>>, +} + +thread_local!( + // We require Rc<..> to avoid premature freeing when thread_rng is used + // within thread-local destructors. See #968. + static THREAD_RNG_KEY: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>> = { + let r = Core::from_rng(OsRng).unwrap_or_else(|err| + panic!("could not initialize thread_rng: {}", err)); + let rng = ReseedingRng::new(r, + THREAD_RNG_RESEED_THRESHOLD, + OsRng); + Rc::new(UnsafeCell::new(rng)) + } +); + +/// Retrieve the lazily-initialized thread-local random number generator, +/// seeded by the system. Intended to be used in method chaining style, +/// e.g. `thread_rng().gen::<i32>()`, or cached locally, e.g. +/// `let mut rng = thread_rng();`. Invoked by the `Default` trait, making +/// `ThreadRng::default()` equivalent. +/// +/// For more information see [`ThreadRng`]. +#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] +pub fn thread_rng() -> ThreadRng { + let rng = THREAD_RNG_KEY.with(|t| t.clone()); + ThreadRng { rng } +} + +impl Default for ThreadRng { + fn default() -> ThreadRng { + crate::prelude::thread_rng() + } +} + +impl RngCore for ThreadRng { + #[inline(always)] + fn next_u32(&mut self) -> u32 { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.next_u32() + } + + #[inline(always)] + fn next_u64(&mut self) -> u64 { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.fill_bytes(dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.try_fill_bytes(dest) + } +} + +impl CryptoRng for ThreadRng {} + + +#[cfg(test)] +mod test { + #[test] + fn test_thread_rng() { + use crate::Rng; + let mut r = crate::thread_rng(); + r.gen::<i32>(); + assert_eq!(r.gen_range(0..1), 0); + } +} diff --git a/third_party/rust/rand/src/rngs/xoshiro128plusplus.rs b/third_party/rust/rand/src/rngs/xoshiro128plusplus.rs new file mode 100644 index 0000000000..ece98fafd6 --- /dev/null +++ b/third_party/rust/rand/src/rngs/xoshiro128plusplus.rs @@ -0,0 +1,118 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; +use rand_core::impls::{next_u64_via_u32, fill_bytes_via_next}; +use rand_core::le::read_u32_into; +use rand_core::{SeedableRng, RngCore, Error}; + +/// A xoshiro128++ random number generator. +/// +/// The xoshiro128++ algorithm is not suitable for cryptographic purposes, but +/// is very fast and has excellent statistical properties. +/// +/// The algorithm used here is translated from [the `xoshiro128plusplus.c` +/// reference source code](http://xoshiro.di.unimi.it/xoshiro128plusplus.c) by +/// David Blackman and Sebastiano Vigna. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +pub struct Xoshiro128PlusPlus { + s: [u32; 4], +} + +impl SeedableRng for Xoshiro128PlusPlus { + type Seed = [u8; 16]; + + /// Create a new `Xoshiro128PlusPlus`. If `seed` is entirely 0, it will be + /// mapped to a different seed. + #[inline] + fn from_seed(seed: [u8; 16]) -> Xoshiro128PlusPlus { + if seed.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } + let mut state = [0; 4]; + read_u32_into(&seed, &mut state); + Xoshiro128PlusPlus { s: state } + } + + /// Create a new `Xoshiro128PlusPlus` from a `u64` seed. + /// + /// This uses the SplitMix64 generator internally. + fn seed_from_u64(mut state: u64) -> Self { + const PHI: u64 = 0x9e3779b97f4a7c15; + let mut seed = Self::Seed::default(); + for chunk in seed.as_mut().chunks_mut(8) { + state = state.wrapping_add(PHI); + let mut z = state; + z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); + z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); + z = z ^ (z >> 31); + chunk.copy_from_slice(&z.to_le_bytes()); + } + Self::from_seed(seed) + } +} + +impl RngCore for Xoshiro128PlusPlus { + #[inline] + fn next_u32(&mut self) -> u32 { + let result_starstar = self.s[0] + .wrapping_add(self.s[3]) + .rotate_left(7) + .wrapping_add(self.s[0]); + + let t = self.s[1] << 9; + + self.s[2] ^= self.s[0]; + self.s[3] ^= self.s[1]; + self.s[1] ^= self.s[2]; + self.s[0] ^= self.s[3]; + + self.s[2] ^= t; + + self.s[3] = self.s[3].rotate_left(11); + + result_starstar + } + + #[inline] + fn next_u64(&mut self) -> u64 { + next_u64_via_u32(self) + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + fill_bytes_via_next(self, dest); + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reference() { + let mut rng = Xoshiro128PlusPlus::from_seed( + [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); + // These values were produced with the reference implementation: + // http://xoshiro.di.unimi.it/xoshiro128plusplus.c + let expected = [ + 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, + 3867114732, 1355841295, 495546011, 621204420, + ]; + for &e in &expected { + assert_eq!(rng.next_u32(), e); + } + } +} diff --git a/third_party/rust/rand/src/rngs/xoshiro256plusplus.rs b/third_party/rust/rand/src/rngs/xoshiro256plusplus.rs new file mode 100644 index 0000000000..8ffb18b803 --- /dev/null +++ b/third_party/rust/rand/src/rngs/xoshiro256plusplus.rs @@ -0,0 +1,122 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; +use rand_core::impls::fill_bytes_via_next; +use rand_core::le::read_u64_into; +use rand_core::{SeedableRng, RngCore, Error}; + +/// A xoshiro256++ random number generator. +/// +/// The xoshiro256++ algorithm is not suitable for cryptographic purposes, but +/// is very fast and has excellent statistical properties. +/// +/// The algorithm used here is translated from [the `xoshiro256plusplus.c` +/// reference source code](http://xoshiro.di.unimi.it/xoshiro256plusplus.c) by +/// David Blackman and Sebastiano Vigna. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +pub struct Xoshiro256PlusPlus { + s: [u64; 4], +} + +impl SeedableRng for Xoshiro256PlusPlus { + type Seed = [u8; 32]; + + /// Create a new `Xoshiro256PlusPlus`. If `seed` is entirely 0, it will be + /// mapped to a different seed. + #[inline] + fn from_seed(seed: [u8; 32]) -> Xoshiro256PlusPlus { + if seed.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } + let mut state = [0; 4]; + read_u64_into(&seed, &mut state); + Xoshiro256PlusPlus { s: state } + } + + /// Create a new `Xoshiro256PlusPlus` from a `u64` seed. + /// + /// This uses the SplitMix64 generator internally. + fn seed_from_u64(mut state: u64) -> Self { + const PHI: u64 = 0x9e3779b97f4a7c15; + let mut seed = Self::Seed::default(); + for chunk in seed.as_mut().chunks_mut(8) { + state = state.wrapping_add(PHI); + let mut z = state; + z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); + z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); + z = z ^ (z >> 31); + chunk.copy_from_slice(&z.to_le_bytes()); + } + Self::from_seed(seed) + } +} + +impl RngCore for Xoshiro256PlusPlus { + #[inline] + fn next_u32(&mut self) -> u32 { + // The lowest bits have some linear dependencies, so we use the + // upper bits instead. + (self.next_u64() >> 32) as u32 + } + + #[inline] + fn next_u64(&mut self) -> u64 { + let result_plusplus = self.s[0] + .wrapping_add(self.s[3]) + .rotate_left(23) + .wrapping_add(self.s[0]); + + let t = self.s[1] << 17; + + self.s[2] ^= self.s[0]; + self.s[3] ^= self.s[1]; + self.s[1] ^= self.s[2]; + self.s[0] ^= self.s[3]; + + self.s[2] ^= t; + + self.s[3] = self.s[3].rotate_left(45); + + result_plusplus + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + fill_bytes_via_next(self, dest); + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reference() { + let mut rng = Xoshiro256PlusPlus::from_seed( + [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + // These values were produced with the reference implementation: + // http://xoshiro.di.unimi.it/xoshiro256plusplus.c + let expected = [ + 41943041, 58720359, 3588806011781223, 3591011842654386, + 9228616714210784205, 9973669472204895162, 14011001112246962877, + 12406186145184390807, 15849039046786891736, 10450023813501588000, + ]; + for &e in &expected { + assert_eq!(rng.next_u64(), e); + } + } +} diff --git a/third_party/rust/rand/src/seq/index.rs b/third_party/rust/rand/src/seq/index.rs new file mode 100644 index 0000000000..b38e4649d1 --- /dev/null +++ b/third_party/rust/rand/src/seq/index.rs @@ -0,0 +1,678 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Low-level API for sampling indices + +#[cfg(feature = "alloc")] use core::slice; + +#[cfg(feature = "alloc")] use alloc::vec::{self, Vec}; +// BTreeMap is not as fast in tests, but better than nothing. +#[cfg(all(feature = "alloc", not(feature = "std")))] +use alloc::collections::BTreeSet; +#[cfg(feature = "std")] use std::collections::HashSet; + +#[cfg(feature = "std")] +use crate::distributions::WeightedError; + +#[cfg(feature = "alloc")] +use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; + +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; + +/// A vector of indices. +/// +/// Multiple internal representations are possible. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub enum IndexVec { + #[doc(hidden)] + U32(Vec<u32>), + #[doc(hidden)] + USize(Vec<usize>), +} + +impl IndexVec { + /// Returns the number of indices + #[inline] + pub fn len(&self) -> usize { + match *self { + IndexVec::U32(ref v) => v.len(), + IndexVec::USize(ref v) => v.len(), + } + } + + /// Returns `true` if the length is 0. + #[inline] + pub fn is_empty(&self) -> bool { + match *self { + IndexVec::U32(ref v) => v.is_empty(), + IndexVec::USize(ref v) => v.is_empty(), + } + } + + /// Return the value at the given `index`. + /// + /// (Note: we cannot implement [`std::ops::Index`] because of lifetime + /// restrictions.) + #[inline] + pub fn index(&self, index: usize) -> usize { + match *self { + IndexVec::U32(ref v) => v[index] as usize, + IndexVec::USize(ref v) => v[index], + } + } + + /// Return result as a `Vec<usize>`. Conversion may or may not be trivial. + #[inline] + pub fn into_vec(self) -> Vec<usize> { + match self { + IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(), + IndexVec::USize(v) => v, + } + } + + /// Iterate over the indices as a sequence of `usize` values + #[inline] + pub fn iter(&self) -> IndexVecIter<'_> { + match *self { + IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()), + IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()), + } + } +} + +impl IntoIterator for IndexVec { + type Item = usize; + type IntoIter = IndexVecIntoIter; + + /// Convert into an iterator over the indices as a sequence of `usize` values + #[inline] + fn into_iter(self) -> IndexVecIntoIter { + match self { + IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()), + IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()), + } + } +} + +impl PartialEq for IndexVec { + fn eq(&self, other: &IndexVec) -> bool { + use self::IndexVec::*; + match (self, other) { + (&U32(ref v1), &U32(ref v2)) => v1 == v2, + (&USize(ref v1), &USize(ref v2)) => v1 == v2, + (&U32(ref v1), &USize(ref v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y)) + } + (&USize(ref v1), &U32(ref v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize)) + } + } + } +} + +impl From<Vec<u32>> for IndexVec { + #[inline] + fn from(v: Vec<u32>) -> Self { + IndexVec::U32(v) + } +} + +impl From<Vec<usize>> for IndexVec { + #[inline] + fn from(v: Vec<usize>) -> Self { + IndexVec::USize(v) + } +} + +/// Return type of `IndexVec::iter`. +#[derive(Debug)] +pub enum IndexVecIter<'a> { + #[doc(hidden)] + U32(slice::Iter<'a, u32>), + #[doc(hidden)] + USize(slice::Iter<'a, usize>), +} + +impl<'a> Iterator for IndexVecIter<'a> { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option<usize> { + use self::IndexVecIter::*; + match *self { + U32(ref mut iter) => iter.next().map(|i| *i as usize), + USize(ref mut iter) => iter.next().cloned(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + match *self { + IndexVecIter::U32(ref v) => v.size_hint(), + IndexVecIter::USize(ref v) => v.size_hint(), + } + } +} + +impl<'a> ExactSizeIterator for IndexVecIter<'a> {} + +/// Return type of `IndexVec::into_iter`. +#[derive(Clone, Debug)] +pub enum IndexVecIntoIter { + #[doc(hidden)] + U32(vec::IntoIter<u32>), + #[doc(hidden)] + USize(vec::IntoIter<usize>), +} + +impl Iterator for IndexVecIntoIter { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option<Self::Item> { + use self::IndexVecIntoIter::*; + match *self { + U32(ref mut v) => v.next().map(|i| i as usize), + USize(ref mut v) => v.next(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + use self::IndexVecIntoIter::*; + match *self { + U32(ref v) => v.size_hint(), + USize(ref v) => v.size_hint(), + } + } +} + +impl ExactSizeIterator for IndexVecIntoIter {} + + +/// Randomly sample exactly `amount` distinct indices from `0..length`, and +/// return them in random order (fully shuffled). +/// +/// This method is used internally by the slice sampling methods, but it can +/// sometimes be useful to have the indices themselves so this is provided as +/// an alternative. +/// +/// The implementation used is not specified; we automatically select the +/// fastest available algorithm for the `length` and `amount` parameters +/// (based on detailed profiling on an Intel Haswell CPU). Roughly speaking, +/// complexity is `O(amount)`, except that when `amount` is small, performance +/// is closer to `O(amount^2)`, and when `length` is close to `amount` then +/// `O(length)`. +/// +/// Note that performance is significantly better over `u32` indices than over +/// `u64` indices. Because of this we hide the underlying type behind an +/// abstraction, `IndexVec`. +/// +/// If an allocation-free `no_std` function is required, it is suggested +/// to adapt the internal `sample_floyd` implementation. +/// +/// Panics if `amount > length`. +pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec +where R: Rng + ?Sized { + if amount > length { + panic!("`amount` of samples must be less than or equal to `length`"); + } + if length > (::core::u32::MAX as usize) { + // We never want to use inplace here, but could use floyd's alg + // Lazy version: always use the cache alg. + return sample_rejection(rng, length, amount); + } + let amount = amount as u32; + let length = length as u32; + + // Choice of algorithm here depends on both length and amount. See: + // https://github.com/rust-random/rand/pull/479 + // We do some calculations with f32. Accuracy is not very important. + + if amount < 163 { + const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]]; + let j = if length < 500_000 { 0 } else { 1 }; + let amount_fp = amount as f32; + let m4 = C[0][j] * amount_fp; + // Short-cut: when amount < 12, floyd's is always faster + if amount > 11 && (length as f32) < (C[1][j] + m4) * amount_fp { + sample_inplace(rng, length, amount) + } else { + sample_floyd(rng, length, amount) + } + } else { + const C: [f32; 2] = [270.0, 330.0 / 9.0]; + let j = if length < 500_000 { 0 } else { 1 }; + if (length as f32) < C[j] * (amount as f32) { + sample_inplace(rng, length, amount) + } else { + sample_rejection(rng, length, amount) + } + } +} + +/// Randomly sample exactly `amount` distinct indices from `0..length`, and +/// return them in an arbitrary order (there is no guarantee of shuffling or +/// ordering). The weights are to be provided by the input function `weights`, +/// which will be called once for each index. +/// +/// This method is used internally by the slice sampling methods, but it can +/// sometimes be useful to have the indices themselves so this is provided as +/// an alternative. +/// +/// This implementation uses `O(length + amount)` space and `O(length)` time +/// if the "nightly" feature is enabled, or `O(length)` space and +/// `O(length + amount * log length)` time otherwise. +/// +/// Panics if `amount > length`. +#[cfg(feature = "std")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] +pub fn sample_weighted<R, F, X>( + rng: &mut R, length: usize, weight: F, amount: usize, +) -> Result<IndexVec, WeightedError> +where + R: Rng + ?Sized, + F: Fn(usize) -> X, + X: Into<f64>, +{ + if length > (core::u32::MAX as usize) { + sample_efraimidis_spirakis(rng, length, weight, amount) + } else { + assert!(amount <= core::u32::MAX as usize); + let amount = amount as u32; + let length = length as u32; + sample_efraimidis_spirakis(rng, length, weight, amount) + } +} + + +/// Randomly sample exactly `amount` distinct indices from `0..length`, and +/// return them in an arbitrary order (there is no guarantee of shuffling or +/// ordering). The weights are to be provided by the input function `weights`, +/// which will be called once for each index. +/// +/// This implementation uses the algorithm described by Efraimidis and Spirakis +/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 +/// It uses `O(length + amount)` space and `O(length)` time if the +/// "nightly" feature is enabled, or `O(length)` space and `O(length +/// + amount * log length)` time otherwise. +/// +/// Panics if `amount > length`. +#[cfg(feature = "std")] +fn sample_efraimidis_spirakis<R, F, X, N>( + rng: &mut R, length: N, weight: F, amount: N, +) -> Result<IndexVec, WeightedError> +where + R: Rng + ?Sized, + F: Fn(usize) -> X, + X: Into<f64>, + N: UInt, + IndexVec: From<Vec<N>>, +{ + if amount == N::zero() { + return Ok(IndexVec::U32(Vec::new())); + } + + if amount > length { + panic!("`amount` of samples must be less than or equal to `length`"); + } + + struct Element<N> { + index: N, + key: f64, + } + impl<N> PartialOrd for Element<N> { + fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { + self.key.partial_cmp(&other.key) + } + } + impl<N> Ord for Element<N> { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + // partial_cmp will always produce a value, + // because we check that the weights are not nan + self.partial_cmp(other).unwrap() + } + } + impl<N> PartialEq for Element<N> { + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } + } + impl<N> Eq for Element<N> {} + + #[cfg(feature = "nightly")] + { + let mut candidates = Vec::with_capacity(length.as_usize()); + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); + if !(weight >= 0.) { + return Err(WeightedError::InvalidWeight); + } + + let key = rng.gen::<f64>().powf(1.0 / weight); + candidates.push(Element { index, key }); + + index += N::one(); + } + + // Partially sort the array to find the `amount` elements with the greatest + // keys. Do this by using `select_nth_unstable` to put the elements with + // the *smallest* keys at the beginning of the list in `O(n)` time, which + // provides equivalent information about the elements with the *greatest* keys. + let (_, mid, greater) + = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); + + let mut result: Vec<N> = Vec::with_capacity(amount.as_usize()); + result.push(mid.index); + for element in greater { + result.push(element.index); + } + Ok(IndexVec::from(result)) + } + + #[cfg(not(feature = "nightly"))] + { + use alloc::collections::BinaryHeap; + + // Partially sort the array such that the `amount` elements with the largest + // keys are first using a binary max heap. + let mut candidates = BinaryHeap::with_capacity(length.as_usize()); + let mut index = N::zero(); + while index < length { + let weight = weight(index.as_usize()).into(); + if !(weight >= 0.) { + return Err(WeightedError::InvalidWeight); + } + + let key = rng.gen::<f64>().powf(1.0 / weight); + candidates.push(Element { index, key }); + + index += N::one(); + } + + let mut result: Vec<N> = Vec::with_capacity(amount.as_usize()); + while result.len() < amount.as_usize() { + result.push(candidates.pop().unwrap().index); + } + Ok(IndexVec::from(result)) + } +} + +/// Randomly sample exactly `amount` indices from `0..length`, using Floyd's +/// combination algorithm. +/// +/// The output values are fully shuffled. (Overhead is under 50%.) +/// +/// This implementation uses `O(amount)` memory and `O(amount^2)` time. +fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec +where R: Rng + ?Sized { + // For small amount we use Floyd's fully-shuffled variant. For larger + // amounts this is slow due to Vec::insert performance, so we shuffle + // afterwards. Benchmarks show little overhead from extra logic. + let floyd_shuffle = amount < 50; + + debug_assert!(amount <= length); + let mut indices = Vec::with_capacity(amount as usize); + for j in length - amount..length { + let t = rng.gen_range(0..=j); + if floyd_shuffle { + if let Some(pos) = indices.iter().position(|&x| x == t) { + indices.insert(pos, j); + continue; + } + } else if indices.contains(&t) { + indices.push(j); + continue; + } + indices.push(t); + } + if !floyd_shuffle { + // Reimplement SliceRandom::shuffle with smaller indices + for i in (1..amount).rev() { + // invariant: elements with index > i have been locked in place. + indices.swap(i as usize, rng.gen_range(0..=i) as usize); + } + } + IndexVec::from(indices) +} + +/// Randomly sample exactly `amount` indices from `0..length`, using an inplace +/// partial Fisher-Yates method. +/// Sample an amount of indices using an inplace partial fisher yates method. +/// +/// This allocates the entire `length` of indices and randomizes only the first `amount`. +/// It then truncates to `amount` and returns. +/// +/// This method is not appropriate for large `length` and potentially uses a lot +/// of memory; because of this we only implement for `u32` index (which improves +/// performance in all cases). +/// +/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. +fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec +where R: Rng + ?Sized { + debug_assert!(amount <= length); + let mut indices: Vec<u32> = Vec::with_capacity(length as usize); + indices.extend(0..length); + for i in 0..amount { + let j: u32 = rng.gen_range(i..length); + indices.swap(i as usize, j as usize); + } + indices.truncate(amount as usize); + debug_assert_eq!(indices.len(), amount as usize); + IndexVec::from(indices) +} + +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + + core::hash::Hash + core::ops::AddAssign { + fn zero() -> Self; + fn one() -> Self; + fn as_usize(self) -> usize; +} +impl UInt for u32 { + #[inline] + fn zero() -> Self { + 0 + } + + #[inline] + fn one() -> Self { + 1 + } + + #[inline] + fn as_usize(self) -> usize { + self as usize + } +} +impl UInt for usize { + #[inline] + fn zero() -> Self { + 0 + } + + #[inline] + fn one() -> Self { + 1 + } + + #[inline] + fn as_usize(self) -> usize { + self + } +} + +/// Randomly sample exactly `amount` indices from `0..length`, using rejection +/// sampling. +/// +/// Since `amount <<< length` there is a low chance of a random sample in +/// `0..length` being a duplicate. We test for duplicates and resample where +/// necessary. The algorithm is `O(amount)` time and memory. +/// +/// This function is generic over X primarily so that results are value-stable +/// over 32-bit and 64-bit platforms. +fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec +where + R: Rng + ?Sized, + IndexVec: From<Vec<X>>, +{ + debug_assert!(amount < length); + #[cfg(feature = "std")] + let mut cache = HashSet::with_capacity(amount.as_usize()); + #[cfg(not(feature = "std"))] + let mut cache = BTreeSet::new(); + let distr = Uniform::new(X::zero(), length); + let mut indices = Vec::with_capacity(amount.as_usize()); + for _ in 0..amount.as_usize() { + let mut pos = distr.sample(rng); + while !cache.insert(pos) { + pos = distr.sample(rng); + } + indices.push(pos); + } + + debug_assert_eq!(indices.len(), amount.as_usize()); + IndexVec::from(indices) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[cfg(feature = "serde1")] + fn test_serialization_index_vec() { + let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); + let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); + match (some_index_vec, de_some_index_vec) { + (IndexVec::U32(a), IndexVec::U32(b)) => { + assert_eq!(a, b); + }, + (IndexVec::USize(a), IndexVec::USize(b)) => { + assert_eq!(a, b); + }, + _ => {panic!("failed to seralize/deserialize `IndexVec`")} + } + } + + #[cfg(feature = "alloc")] use alloc::vec; + + #[test] + fn test_sample_boundaries() { + let mut r = crate::test::rng(404); + + assert_eq!(sample_inplace(&mut r, 0, 0).len(), 0); + assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0); + assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]); + + assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0); + + assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0); + assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0); + assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]); + + // These algorithms should be fast with big numbers. Test average. + let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32).into_iter().sum(); + assert!(1 << 25 < sum && sum < (1 << 25) * 25); + + let sum: usize = sample_floyd(&mut r, 1 << 25, 10).into_iter().sum(); + assert!(1 << 25 < sum && sum < (1 << 25) * 25); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_sample_alg() { + let seed_rng = crate::test::rng; + + // We can't test which algorithm is used directly, but Floyd's alg + // should produce different results from the others. (Also, `inplace` + // and `cached` currently use different sizes thus produce different results.) + + // A small length and relatively large amount should use inplace + let (length, amount): (usize, usize) = (100, 50); + let v1 = sample(&mut seed_rng(420), length, amount); + let v2 = sample_inplace(&mut seed_rng(420), length as u32, amount as u32); + assert!(v1.iter().all(|e| e < length)); + assert_eq!(v1, v2); + + // Test Floyd's alg does produce different results + let v3 = sample_floyd(&mut seed_rng(420), length as u32, amount as u32); + assert!(v1 != v3); + + // A large length and small amount should use Floyd + let (length, amount): (usize, usize) = (1 << 20, 50); + let v1 = sample(&mut seed_rng(421), length, amount); + let v2 = sample_floyd(&mut seed_rng(421), length as u32, amount as u32); + assert!(v1.iter().all(|e| e < length)); + assert_eq!(v1, v2); + + // A large length and larger amount should use cache + let (length, amount): (usize, usize) = (1 << 20, 600); + let v1 = sample(&mut seed_rng(422), length, amount); + let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32); + assert!(v1.iter().all(|e| e < length)); + assert_eq!(v1, v2); + } + + #[cfg(feature = "std")] + #[test] + fn test_sample_weighted() { + let seed_rng = crate::test::rng; + for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); + match v { + IndexVec::U32(mut indices) => { + assert_eq!(indices.len(), amount); + indices.sort_unstable(); + indices.dedup(); + assert_eq!(indices.len(), amount); + for &i in &indices { + assert!((i as usize) < len); + } + }, + IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + } + } + + #[test] + fn value_stability_sample() { + let do_test = |length, amount, values: &[u32]| { + let mut buf = [0u32; 8]; + let mut rng = crate::test::rng(410); + + let res = sample(&mut rng, length, amount); + let len = res.len().min(buf.len()); + for (x, y) in res.into_iter().zip(buf.iter_mut()) { + *y = x as u32; + } + assert_eq!( + &buf[0..len], + values, + "failed sampling {}, {}", + length, + amount + ); + }; + + do_test(10, 6, &[8, 0, 3, 5, 9, 6]); // floyd + do_test(25, 10, &[18, 15, 14, 9, 0, 13, 5, 24]); // floyd + do_test(300, 8, &[30, 283, 150, 1, 73, 13, 285, 35]); // floyd + do_test(300, 80, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace + do_test(300, 180, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace + + do_test(1_000_000, 8, &[ + 103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, + ]); // floyd + do_test(1_000_000, 180, &[ + 103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573, + ]); // rejection + } +} diff --git a/third_party/rust/rand/src/seq/mod.rs b/third_party/rust/rand/src/seq/mod.rs new file mode 100644 index 0000000000..069e9e6b19 --- /dev/null +++ b/third_party/rust/rand/src/seq/mod.rs @@ -0,0 +1,1356 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Sequence-related functionality +//! +//! This module provides: +//! +//! * [`SliceRandom`] slice sampling and mutation +//! * [`IteratorRandom`] iterator sampling +//! * [`index::sample`] low-level API to choose multiple indices from +//! `0..length` +//! +//! Also see: +//! +//! * [`crate::distributions::WeightedIndex`] distribution which provides +//! weighted index sampling. +//! +//! In order to make results reproducible across 32-64 bit architectures, all +//! `usize` indices are sampled as a `u32` where possible (also providing a +//! small performance boost in some cases). + + +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub mod index; + +#[cfg(feature = "alloc")] use core::ops::Index; + +#[cfg(feature = "alloc")] use alloc::vec::Vec; + +#[cfg(feature = "alloc")] +use crate::distributions::uniform::{SampleBorrow, SampleUniform}; +#[cfg(feature = "alloc")] use crate::distributions::WeightedError; +use crate::Rng; + +/// Extension trait on slices, providing random mutation and sampling methods. +/// +/// This trait is implemented on all `[T]` slice types, providing several +/// methods for choosing and shuffling elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let mut rng = rand::thread_rng(); +/// let mut bytes = "Hello, random!".to_string().into_bytes(); +/// bytes.shuffle(&mut rng); +/// let str = String::from_utf8(bytes).unwrap(); +/// println!("{}", str); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// l,nmroHado !le +/// ``` +pub trait SliceRandom { + /// The element type. + type Item; + + /// Returns a reference to one random element of the slice, or `None` if the + /// slice is empty. + /// + /// For slices, complexity is `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::seq::SliceRandom; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = thread_rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` + fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized; + + /// Returns a mutable reference to one random element of the slice, or + /// `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized; + + /// Chooses `amount` elements from the slice at random, without repetition, + /// and in random order. The returned iterator is appropriate both for + /// collection into a `Vec` and filling an existing buffer (see example). + /// + /// In case this API is not sufficiently flexible, use [`index::sample`]. + /// + /// For slices, complexity is the same as [`index::sample`]. + /// + /// # Example + /// ``` + /// use rand::seq::SliceRandom; + /// + /// let mut rng = &mut rand::thread_rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` + #[cfg(feature = "alloc")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] + fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> + where R: Rng + ?Sized; + + /// Similar to [`choose`], but where the likelihood of each outcome may be + /// specified. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// See also [`choose_weighted_mut`], [`distributions::weighted`]. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = thread_rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' + /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); + /// ``` + /// [`choose`]: SliceRandom::choose + /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut + /// [`distributions::weighted`]: crate::distributions::weighted + #[cfg(feature = "alloc")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] + fn choose_weighted<R, F, B, X>( + &self, rng: &mut R, weight: F, + ) -> Result<&Self::Item, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow<X>, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd<X> + + Clone + + Default; + + /// Similar to [`choose_mut`], but where the likelihood of each outcome may + /// be specified. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// See also [`choose_weighted`], [`distributions::weighted`]. + /// + /// [`choose_mut`]: SliceRandom::choose_mut + /// [`choose_weighted`]: SliceRandom::choose_weighted + /// [`distributions::weighted`]: crate::distributions::weighted + #[cfg(feature = "alloc")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] + fn choose_weighted_mut<R, F, B, X>( + &mut self, rng: &mut R, weight: F, + ) -> Result<&mut Self::Item, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow<X>, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd<X> + + Clone + + Default; + + /// Similar to [`choose_multiple`], but where the likelihood of each element's + /// inclusion in the output may be specified. The elements are returned in an + /// arbitrary, unspecified order. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// If all of the weights are equal, even if they are all zero, each element has + /// an equal likelihood of being selected. + /// + /// The complexity of this method depends on the feature `partition_at_index`. + /// If the feature is enabled, then for slices of length `n`, the complexity + /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and + /// `O(n * log amount)` time. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = thread_rng(); + /// // First Draw * Second Draw = total odds + /// // ----------------------- + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. + /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. + /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>()); + /// ``` + /// [`choose_multiple`]: SliceRandom::choose_multiple + // + // Note: this is feature-gated on std due to usage of f64::powf. + // If necessary, we may use alloc+libm as an alternative (see PR #1089). + #[cfg(feature = "std")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] + fn choose_multiple_weighted<R, F, X>( + &self, rng: &mut R, amount: usize, weight: F, + ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> X, + X: Into<f64>; + + /// Shuffle a mutable slice in place. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// + /// # Example + /// + /// ``` + /// use rand::seq::SliceRandom; + /// use rand::thread_rng; + /// + /// let mut rng = thread_rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` + fn shuffle<R>(&mut self, rng: &mut R) + where R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = values.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// For slices, complexity is `O(m)` where `m = amount`. + fn partial_shuffle<R>( + &mut self, rng: &mut R, amount: usize, + ) -> (&mut [Self::Item], &mut [Self::Item]) + where R: Rng + ?Sized; +} + +/// Extension trait on iterators, providing random sampling methods. +/// +/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` +/// and provides methods for +/// choosing one or more elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::IteratorRandom; +/// +/// let mut rng = rand::thread_rng(); +/// +/// let faces = "๐๐๐๐๐ ๐ข"; +/// println!("I am {}!", faces.chars().choose(&mut rng).unwrap()); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// I am ๐! +/// ``` +pub trait IteratorRandom: Iterator + Sized { + /// Choose one element at random from the iterator. + /// + /// Returns `None` if and only if the iterator is empty. + /// + /// This method uses [`Iterator::size_hint`] for optimisation. With an + /// accurate hint and where [`Iterator::nth`] is a constant-time operation + /// this method can offer `O(1)` performance. Where no size hint is + /// available, complexity is `O(n)` where `n` is the iterator length. + /// Partial hints (where `lower > 0`) also improve performance. + /// + /// Note that the output values and the number of RNG samples used + /// depends on size hints. In particular, `Iterator` combinators that don't + /// change the values yielded but change the size hints may result in + /// `choose` returning different elements. If you want consistent results + /// and RNG usage consider using [`IteratorRandom::choose_stable`]. + fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item> + where R: Rng + ?Sized { + let (mut lower, mut upper) = self.size_hint(); + let mut consumed = 0; + let mut result = None; + + // Handling for this condition outside the loop allows the optimizer to eliminate the loop + // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. + // seq_iter_choose_from_1000. + if upper == Some(lower) { + return if lower == 0 { + None + } else { + self.nth(gen_index(rng, lower)) + }; + } + + // Continue until the iterator is exhausted + loop { + if lower > 1 { + let ix = gen_index(rng, lower + consumed); + let skip = if ix < lower { + result = self.nth(ix); + lower - (ix + 1) + } else { + lower + }; + if upper == Some(lower) { + return result; + } + consumed += lower; + if skip > 0 { + self.nth(skip - 1); + } + } else { + let elem = self.next(); + if elem.is_none() { + return result; + } + consumed += 1; + if gen_index(rng, consumed) == 0 { + result = elem; + } + } + + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; + } + } + + /// Choose one element at random from the iterator. + /// + /// Returns `None` if and only if the iterator is empty. + /// + /// This method is very similar to [`choose`] except that the result + /// only depends on the length of the iterator and the values produced by + /// `rng`. Notably for any iterator of a given length this will make the + /// same requests to `rng` and if the same sequence of values are produced + /// the same index will be selected from `self`. This may be useful if you + /// need consistent results no matter what type of iterator you are working + /// with. If you do not need this stability prefer [`choose`]. + /// + /// Note that this method still uses [`Iterator::size_hint`] to skip + /// constructing elements where possible, however the selection and `rng` + /// calls are the same in the face of this optimization. If you want to + /// force every element to be created regardless call `.inspect(|e| ())`. + /// + /// [`choose`]: IteratorRandom::choose + fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> + where R: Rng + ?Sized { + let mut consumed = 0; + let mut result = None; + + loop { + // Currently the only way to skip elements is `nth()`. So we need to + // store what index to access next here. + // This should be replaced by `advance_by()` once it is stable: + // https://github.com/rust-lang/rust/issues/77404 + let mut next = 0; + + let (lower, _) = self.size_hint(); + if lower >= 2 { + let highest_selected = (0..lower) + .filter(|ix| gen_index(rng, consumed+ix+1) == 0) + .last(); + + consumed += lower; + next = lower; + + if let Some(ix) = highest_selected { + result = self.nth(ix); + next -= ix + 1; + debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); + } + } + + let elem = self.nth(next); + if elem.is_none() { + return result + } + + if gen_index(rng, consumed+1) == 0 { + result = elem; + } + consumed += 1; + } + } + + /// Collects values at random from the iterator into a supplied buffer + /// until that buffer is filled. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// Returns the number of elements added to the buffer. This equals the length + /// of the buffer unless the iterator contains insufficient elements, in which + /// case this equals the number of elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`SliceRandom::choose_multiple`]. + fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize + where R: Rng + ?Sized { + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = gen_index(rng, i + 1 + amount); + if let Some(slot) = buf.get_mut(k) { + *slot = elem; + } + } + len + } + + /// Collects `amount` values at random from the iterator into a vector. + /// + /// This is equivalent to `choose_multiple_fill` except for the result type. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`SliceRandom::choose_multiple`]. + #[cfg(feature = "alloc")] + #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] + fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> + where R: Rng + ?Sized { + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = gen_index(rng, i + 1 + amount); + if let Some(slot) = reservoir.get_mut(k) { + *slot = elem; + } + } + } else { + // Don't hang onto extra memory. There is a corner case where + // `amount` was much less than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir + } +} + + +impl<T> SliceRandom for [T] { + type Item = T; + + fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> + where R: Rng + ?Sized { + if self.is_empty() { + None + } else { + Some(&self[gen_index(rng, self.len())]) + } + } + + fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> + where R: Rng + ?Sized { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[gen_index(rng, len)]) + } + } + + #[cfg(feature = "alloc")] + fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> + where R: Rng + ?Sized { + let amount = ::core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample(rng, self.len(), amount).into_iter(), + } + } + + #[cfg(feature = "alloc")] + fn choose_weighted<R, F, B, X>( + &self, rng: &mut R, weight: F, + ) -> Result<&Self::Item, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow<X>, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd<X> + + Clone + + Default, + { + use crate::distributions::{Distribution, WeightedIndex}; + let distr = WeightedIndex::new(self.iter().map(weight))?; + Ok(&self[distr.sample(rng)]) + } + + #[cfg(feature = "alloc")] + fn choose_weighted_mut<R, F, B, X>( + &mut self, rng: &mut R, weight: F, + ) -> Result<&mut Self::Item, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> B, + B: SampleBorrow<X>, + X: SampleUniform + + for<'a> ::core::ops::AddAssign<&'a X> + + ::core::cmp::PartialOrd<X> + + Clone + + Default, + { + use crate::distributions::{Distribution, WeightedIndex}; + let distr = WeightedIndex::new(self.iter().map(weight))?; + Ok(&mut self[distr.sample(rng)]) + } + + #[cfg(feature = "std")] + fn choose_multiple_weighted<R, F, X>( + &self, rng: &mut R, amount: usize, weight: F, + ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> + where + R: Rng + ?Sized, + F: Fn(&Self::Item) -> X, + X: Into<f64>, + { + let amount = ::core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )? + .into_iter(), + }) + } + + fn shuffle<R>(&mut self, rng: &mut R) + where R: Rng + ?Sized { + for i in (1..self.len()).rev() { + // invariant: elements with index > i have been locked in place. + self.swap(i, gen_index(rng, i + 1)); + } + } + + fn partial_shuffle<R>( + &mut self, rng: &mut R, amount: usize, + ) -> (&mut [Self::Item], &mut [Self::Item]) + where R: Rng + ?Sized { + // This applies Durstenfeld's algorithm for the + // [FisherโYates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation, but exits early after choosing `amount` + // elements. + + let len = self.len(); + let end = if amount >= len { 0 } else { len - amount }; + + for i in (end..len).rev() { + // invariant: elements with index > i have been locked in place. + self.swap(i, gen_index(rng, i + 1)); + } + let r = self.split_at_mut(end); + (r.1, r.0) + } +} + +impl<I> IteratorRandom for I where I: Iterator + Sized {} + + +/// An iterator over multiple slice elements. +/// +/// This struct is created by +/// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple). +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: ::core::marker::PhantomData<T>, + indices: index::IndexVecIntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index<usize, Output = T> + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; + + fn next(&mut self) -> Option<Self::Item> { + // TODO: investigate using SliceIndex::get_unchecked when stable + self.indices.next().map(|i| &self.slice[i as usize]) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + (self.indices.len(), Some(self.indices.len())) + } +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index<usize, Output = T> + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + + +// Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where +// possible, primarily in order to produce the same output on 32-bit and 64-bit +// platforms. +#[inline] +fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize { + if ubound <= (core::u32::MAX as usize) { + rng.gen_range(0..ubound as u32) as usize + } else { + rng.gen_range(0..ubound) + } +} + + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "alloc")] use crate::Rng; + #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; + + #[test] + fn test_slice_choose() { + let mut r = crate::test::rng(107); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut chosen = [0i32; 14]; + // The below all use a binomial distribution with n=1000, p=1/14. + // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 + for _ in 0..1000 { + let picked = *chars.choose(&mut r).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *chosen.choose_mut(&mut r).unwrap() += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + let mut v: [isize; 0] = []; + assert_eq!(v.choose(&mut r), None); + assert_eq!(v.choose_mut(&mut r), None); + } + + #[test] + fn value_stability_slice() { + let mut r = crate::test::rng(413); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + assert_eq!(chars.choose(&mut r), Some(&'l')); + assert_eq!(nums.choose_mut(&mut r), Some(&mut 10)); + + #[cfg(feature = "alloc")] + assert_eq!( + &chars + .choose_multiple(&mut r, 8) + .cloned() + .collect::<Vec<char>>(), + &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e'] + ); + + #[cfg(feature = "alloc")] + assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f')); + #[cfg(feature = "alloc")] + assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5)); + + let mut r = crate::test::rng(414); + nums.shuffle(&mut r); + assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); + nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let res = nums.partial_shuffle(&mut r, 6); + assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); + assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); + } + + #[derive(Clone)] + struct UnhintedIterator<I: Iterator + Clone> { + iter: I, + } + impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> { + type Item = I::Item; + + fn next(&mut self) -> Option<Self::Item> { + self.iter.next() + } + } + + #[derive(Clone)] + struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> { + iter: I, + chunk_remaining: usize, + chunk_size: usize, + hint_total_size: bool, + } + impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> { + type Item = I::Item; + + fn next(&mut self) -> Option<Self::Item> { + if self.chunk_remaining == 0 { + self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len()); + } + self.chunk_remaining = self.chunk_remaining.saturating_sub(1); + + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option<usize>) { + ( + self.chunk_remaining, + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[derive(Clone)] + struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> { + iter: I, + window_size: usize, + hint_total_size: bool, + } + impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> { + type Item = I::Item; + + fn next(&mut self) -> Option<Self::Item> { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option<usize>) { + ( + ::core::cmp::min(self.iter.len(), self.window_size), + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose() { + let r = &mut crate::test::rng(109); + fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::<Vec<_>>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter(r, ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }); + test_iter(r, ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }); + test_iter(r, WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }); + test_iter(r, WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable() { + let r = &mut crate::test::rng(109); + fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::<Vec<_>>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter(r, ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }); + test_iter(r, ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }); + test_iter(r, WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }); + test_iter(r, WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable_stability() { + fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] { + let r = &mut crate::test::rng(109); + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + chosen + } + + let reference = test_iter(0..9); + assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); + + #[cfg(feature = "alloc")] + assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference); + assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); + assert_eq!(test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), reference); + assert_eq!(test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), reference); + assert_eq!(test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), reference); + assert_eq!(test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), reference); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_shuffle() { + let mut r = crate::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + fn move_last(slice: &mut [usize], pos: usize) { + // use slice[pos..].rotate_left(1); once we can use that + let last_val = slice[pos]; + for i in pos..slice.len() - 1 { + slice[i] = slice[i + 1]; + } + *slice.last_mut().unwrap() = last_val; + } + let mut counts = [0i32; 24]; + for _ in 0..10000 { + let mut arr: [usize; 4] = [0, 1, 2, 3]; + arr.shuffle(&mut r); + let mut permutation = 0usize; + let mut pos_value = counts.len(); + for i in 0..4 { + pos_value /= 4 - i; + let pos = arr.iter().position(|&x| x == i).unwrap(); + assert!(pos < (4 - i)); + permutation += pos * pos_value; + move_last(&mut arr, pos); + assert_eq!(arr[3], i); + } + for (i, &a) in arr.iter().enumerate() { + assert_eq!(a, i); + } + counts[permutation] += 1; + } + for count in counts.iter() { + // Binomial(10000, 1/24) with average 416.667 + // Octave: binocdf(n, 10000, 1/24) + // 99.9% chance samples lie within this range: + assert!(352 <= *count && *count <= 483, "count: {}", count); + } + } + + #[test] + fn test_partial_shuffle() { + let mut r = crate::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_sample_iter() { + let min_val = 1; + let max_val = 100; + + let mut r = crate::test::rng(401); + let vals = (min_val..max_val).collect::<Vec<i32>>(); + let small_sample = vals.iter().choose_multiple(&mut r, 5); + let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); + + assert_eq!(small_sample.len(), 5); + assert_eq!(large_sample.len(), vals.len()); + // no randomization happens when amount >= len + assert_eq!(large_sample, vals.iter().collect::<Vec<_>>()); + + assert!(small_sample + .iter() + .all(|e| { **e >= min_val && **e <= max_val })); + } + + #[test] + #[cfg(feature = "alloc")] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weighted() { + let mut r = crate::test::rng(406); + const N_REPS: u32 = 3000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::<u32>() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // choose_weighted + fn get_weight<T>(item: &(u32, T)) -> u32 { + item.0 + } + let mut chosen = [0i32; 14]; + let mut items = [(0u32, 0usize); 14]; // (weight, index) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], i); + } + for _ in 0..N_REPS { + let item = items.choose_weighted(&mut r, get_weight).unwrap(); + chosen[item.1] += 1; + } + verify(chosen); + + // choose_weighted_mut + let mut items = [(0u32, 0i32); 14]; // (weight, count) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], 0); + } + for _ in 0..N_REPS { + items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; + } + for (ch, item) in chosen.iter_mut().zip(items.iter()) { + *ch = item.1; + } + verify(chosen); + + // Check error cases + let empty_slice = &mut [10][0..0]; + assert_eq!( + empty_slice.choose_weighted(&mut r, |_| 1), + Err(WeightedError::NoItem) + ); + assert_eq!( + empty_slice.choose_weighted_mut(&mut r, |_| 1), + Err(WeightedError::NoItem) + ); + assert_eq!( + ['x'].choose_weighted_mut(&mut r, |_| 0), + Err(WeightedError::AllWeightsZero) + ); + assert_eq!( + [0, -1].choose_weighted_mut(&mut r, |x| *x), + Err(WeightedError::InvalidWeight) + ); + assert_eq!( + [-1, 0].choose_weighted_mut(&mut r, |x| *x), + Err(WeightedError::InvalidWeight) + ); + } + + #[test] + fn value_stability_choose() { + fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> { + let mut rng = crate::test::rng(411); + iter.choose(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(33)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(39) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(39) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(90) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(90) + ); + } + + #[test] + fn value_stability_choose_stable() { + fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> { + let mut rng = crate::test::rng(411); + iter.choose_stable(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(40)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(40) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(40) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(40) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(40) + ); + } + + #[test] + fn value_stability_choose_multiple() { + fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) { + let mut rng = crate::test::rng(412); + let mut buf = [0u32; 8]; + assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len()); + assert_eq!(&buf[0..v.len()], v); + } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); + + #[cfg(feature = "alloc")] + { + fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) { + let mut rng = crate::test::rng(412); + assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); + } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); + } + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_edge_cases() { + use super::*; + + let mut rng = crate::test::rng(413); + + // Case 1: One of the weights is 0 + let choices = [('a', 2), ('b', 1), ('c', 0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::<Vec<_>>(); + + assert_eq!(result.len(), 2); + assert!(!result.iter().any(|val| val.0 == 'c')); + } + + // Case 2: All of the weights are 0 + let choices = [('a', 0), ('b', 0), ('c', 0)]; + + assert_eq!(choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap().count(), 2); + + // Case 3: Negative weights + let choices = [('a', -1), ('b', 1), ('c', 1)]; + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); + + // Case 4: Empty list + let choices = []; + assert_eq!(choices + .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) + .unwrap().count(), 0); + + // Case 5: NaN weights + let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); + + // Case 6: +infinity weights + let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::<Vec<_>>(); + assert_eq!(result.len(), 2); + assert!(result.iter().any(|val| val.0 == 'a')); + } + + // Case 7: -infinity weights + let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap_err(), + WeightedError::InvalidWeight + ); + + // Case 8: -0 weights + let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; + assert!(choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .is_ok()); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_distributions() { + use super::*; + + // The theoretical probabilities of the different outcomes are: + // AB: 0.5 * 0.5 = 0.250 + // AC: 0.5 * 0.5 = 0.250 + // BA: 0.25 * 0.67 = 0.167 + // BC: 0.25 * 0.33 = 0.082 + // CA: 0.25 * 0.67 = 0.167 + // CB: 0.25 * 0.33 = 0.082 + let choices = [('a', 2), ('b', 1), ('c', 1)]; + let mut rng = crate::test::rng(414); + + let mut results = [0i32; 3]; + let expected_results = [4167, 4167, 1666]; + for _ in 0..10000 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::<Vec<_>>(); + + assert_eq!(result.len(), 2); + + match (result[0].0, result[1].0) { + ('a', 'b') | ('b', 'a') => { + results[0] += 1; + } + ('a', 'c') | ('c', 'a') => { + results[1] += 1; + } + ('b', 'c') | ('c', 'b') => { + results[2] += 1; + } + (_, _) => panic!("unexpected result"), + } + } + + let mut diffs = results + .iter() + .zip(&expected_results) + .map(|(a, b)| (a - b).abs()); + assert!(!diffs.any(|deviation| deviation > 100)); + } +} |