summaryrefslogtreecommitdiffstats
path: root/third_party/rust/rand/src/distributions/weighted_index.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
commit43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/rust/rand/src/distributions/weighted_index.rs
parentInitial commit. (diff)
downloadfirefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz
firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/rand/src/distributions/weighted_index.rs')
-rw-r--r--third_party/rust/rand/src/distributions/weighted_index.rs458
1 files changed, 458 insertions, 0 deletions
diff --git a/third_party/rust/rand/src/distributions/weighted_index.rs b/third_party/rust/rand/src/distributions/weighted_index.rs
new file mode 100644
index 0000000000..8252b172f7
--- /dev/null
+++ b/third_party/rust/rand/src/distributions/weighted_index.rs
@@ -0,0 +1,458 @@
+// Copyright 2018 Developers of the Rand project.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! Weighted index sampling
+
+use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
+use crate::distributions::Distribution;
+use crate::Rng;
+use core::cmp::PartialOrd;
+use core::fmt;
+
+// Note that this whole module is only imported if feature="alloc" is enabled.
+use alloc::vec::Vec;
+
+#[cfg(feature = "serde1")]
+use serde::{Serialize, Deserialize};
+
+/// A distribution using weighted sampling of discrete items
+///
+/// Sampling a `WeightedIndex` distribution returns the index of a randomly
+/// selected element from the iterator used when the `WeightedIndex` was
+/// created. The chance of a given element being picked is proportional to the
+/// value of the element. The weights can use any type `X` for which an
+/// implementation of [`Uniform<X>`] exists.
+///
+/// # Performance
+///
+/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
+/// `N` is the number of weights. As an alternative,
+/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html)
+/// supports `O(1)` sampling, but with much higher initialisation cost.
+///
+/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
+/// size is the sum of the size of those objects, possibly plus some alignment.
+///
+/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
+/// weights of type `X`, where `N` is the number of weights. However, since
+/// `Vec` doesn't guarantee a particular growth strategy, additional memory
+/// might be allocated but not used. Since the `WeightedIndex` object also
+/// contains, this might cause additional allocations, though for primitive
+/// types, [`Uniform<X>`] doesn't allocate any memory.
+///
+/// Sampling from `WeightedIndex` will result in a single call to
+/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
+/// will request a single value from the underlying [`RngCore`], though the
+/// exact number depends on the implementation of `Uniform<X>::sample`.
+///
+/// # Example
+///
+/// ```
+/// use rand::prelude::*;
+/// use rand::distributions::WeightedIndex;
+///
+/// let choices = ['a', 'b', 'c'];
+/// let weights = [2, 1, 1];
+/// let dist = WeightedIndex::new(&weights).unwrap();
+/// let mut rng = thread_rng();
+/// for _ in 0..100 {
+/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
+/// println!("{}", choices[dist.sample(&mut rng)]);
+/// }
+///
+/// let items = [('a', 0), ('b', 3), ('c', 7)];
+/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
+/// for _ in 0..100 {
+/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
+/// println!("{}", items[dist2.sample(&mut rng)].0);
+/// }
+/// ```
+///
+/// [`Uniform<X>`]: crate::distributions::Uniform
+/// [`RngCore`]: crate::RngCore
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
+#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
+pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
+ cumulative_weights: Vec<X>,
+ total_weight: X,
+ weight_distribution: X::Sampler,
+}
+
+impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
+ /// Creates a new a `WeightedIndex` [`Distribution`] using the values
+ /// in `weights`. The weights can use any type `X` for which an
+ /// implementation of [`Uniform<X>`] exists.
+ ///
+ /// Returns an error if the iterator is empty, if any weight is `< 0`, or
+ /// if its total value is 0.
+ ///
+ /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
+ pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
+ where
+ I: IntoIterator,
+ I::Item: SampleBorrow<X>,
+ X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
+ {
+ let mut iter = weights.into_iter();
+ let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
+
+ let zero = <X as Default>::default();
+ if !(total_weight >= zero) {
+ return Err(WeightedError::InvalidWeight);
+ }
+
+ let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
+ for w in iter {
+ // Note that `!(w >= x)` is not equivalent to `w < x` for partially
+ // ordered types due to NaNs which are equal to nothing.
+ if !(w.borrow() >= &zero) {
+ return Err(WeightedError::InvalidWeight);
+ }
+ weights.push(total_weight.clone());
+ total_weight += w.borrow();
+ }
+
+ if total_weight == zero {
+ return Err(WeightedError::AllWeightsZero);
+ }
+ let distr = X::Sampler::new(zero, total_weight.clone());
+
+ Ok(WeightedIndex {
+ cumulative_weights: weights,
+ total_weight,
+ weight_distribution: distr,
+ })
+ }
+
+ /// Update a subset of weights, without changing the number of weights.
+ ///
+ /// `new_weights` must be sorted by the index.
+ ///
+ /// Using this method instead of `new` might be more efficient if only a small number of
+ /// weights is modified. No allocations are performed, unless the weight type `X` uses
+ /// allocation internally.
+ ///
+ /// In case of error, `self` is not modified.
+ pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
+ where X: for<'a> ::core::ops::AddAssign<&'a X>
+ + for<'a> ::core::ops::SubAssign<&'a X>
+ + Clone
+ + Default {
+ if new_weights.is_empty() {
+ return Ok(());
+ }
+
+ let zero = <X as Default>::default();
+
+ let mut total_weight = self.total_weight.clone();
+
+ // Check for errors first, so we don't modify `self` in case something
+ // goes wrong.
+ let mut prev_i = None;
+ for &(i, w) in new_weights {
+ if let Some(old_i) = prev_i {
+ if old_i >= i {
+ return Err(WeightedError::InvalidWeight);
+ }
+ }
+ if !(*w >= zero) {
+ return Err(WeightedError::InvalidWeight);
+ }
+ if i > self.cumulative_weights.len() {
+ return Err(WeightedError::TooMany);
+ }
+
+ let mut old_w = if i < self.cumulative_weights.len() {
+ self.cumulative_weights[i].clone()
+ } else {
+ self.total_weight.clone()
+ };
+ if i > 0 {
+ old_w -= &self.cumulative_weights[i - 1];
+ }
+
+ total_weight -= &old_w;
+ total_weight += w;
+ prev_i = Some(i);
+ }
+ if total_weight <= zero {
+ return Err(WeightedError::AllWeightsZero);
+ }
+
+ // Update the weights. Because we checked all the preconditions in the
+ // previous loop, this should never panic.
+ let mut iter = new_weights.iter();
+
+ let mut prev_weight = zero.clone();
+ let mut next_new_weight = iter.next();
+ let &(first_new_index, _) = next_new_weight.unwrap();
+ let mut cumulative_weight = if first_new_index > 0 {
+ self.cumulative_weights[first_new_index - 1].clone()
+ } else {
+ zero.clone()
+ };
+ for i in first_new_index..self.cumulative_weights.len() {
+ match next_new_weight {
+ Some(&(j, w)) if i == j => {
+ cumulative_weight += w;
+ next_new_weight = iter.next();
+ }
+ _ => {
+ let mut tmp = self.cumulative_weights[i].clone();
+ tmp -= &prev_weight; // We know this is positive.
+ cumulative_weight += &tmp;
+ }
+ }
+ prev_weight = cumulative_weight.clone();
+ core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
+ }
+
+ self.total_weight = total_weight;
+ self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
+
+ Ok(())
+ }
+}
+
+impl<X> Distribution<usize> for WeightedIndex<X>
+where X: SampleUniform + PartialOrd
+{
+ fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+ use ::core::cmp::Ordering;
+ let chosen_weight = self.weight_distribution.sample(rng);
+ // Find the first item which has a weight *higher* than the chosen weight.
+ self.cumulative_weights
+ .binary_search_by(|w| {
+ if *w <= chosen_weight {
+ Ordering::Less
+ } else {
+ Ordering::Greater
+ }
+ })
+ .unwrap_err()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[cfg(feature = "serde1")]
+ #[test]
+ fn test_weightedindex_serde1() {
+ let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();
+
+ let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
+ let de_weighted_index: WeightedIndex<i32> =
+ bincode::deserialize(&ser_weighted_index).unwrap();
+
+ assert_eq!(
+ de_weighted_index.cumulative_weights,
+ weighted_index.cumulative_weights
+ );
+ assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
+ }
+
+ #[test]
+ fn test_accepting_nan(){
+ assert_eq!(
+ WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
+ WeightedError::InvalidWeight,
+ );
+ assert_eq!(
+ WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
+ WeightedError::InvalidWeight,
+ );
+ assert_eq!(
+ WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
+ WeightedError::InvalidWeight,
+ );
+
+ assert_eq!(
+ WeightedIndex::new(&[0.5, 7.0])
+ .unwrap()
+ .update_weights(&[(0, &core::f32::NAN)])
+ .unwrap_err(),
+ WeightedError::InvalidWeight,
+ )
+ }
+
+
+ #[test]
+ #[cfg_attr(miri, ignore)] // Miri is too slow
+ fn test_weightedindex() {
+ let mut r = crate::test::rng(700);
+ const N_REPS: u32 = 5000;
+ let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
+ let total_weight = weights.iter().sum::<u32>() as f32;
+
+ let verify = |result: [i32; 14]| {
+ for (i, count) in result.iter().enumerate() {
+ let exp = (weights[i] * N_REPS) as f32 / total_weight;
+ let mut err = (*count as f32 - exp).abs();
+ if err != 0.0 {
+ err /= exp;
+ }
+ assert!(err <= 0.25);
+ }
+ };
+
+ // WeightedIndex from vec
+ let mut chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.to_vec()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from slice
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(&weights[..]).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from iterator
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.iter()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ for _ in 0..5 {
+ assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
+ assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
+ assert_eq!(
+ WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
+ .unwrap()
+ .sample(&mut r),
+ 4
+ );
+ }
+
+ assert_eq!(
+ WeightedIndex::new(&[10][0..0]).unwrap_err(),
+ WeightedError::NoItem
+ );
+ assert_eq!(
+ WeightedIndex::new(&[0]).unwrap_err(),
+ WeightedError::AllWeightsZero
+ );
+ assert_eq!(
+ WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(&[-10]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ }
+
+ #[test]
+ fn test_update_weights() {
+ let data = [
+ (
+ &[10u32, 2, 3, 4][..],
+ &[(1, &100), (2, &4)][..], // positive change
+ &[10, 100, 4, 4][..],
+ ),
+ (
+ &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
+ &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
+ &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
+ ),
+ ];
+
+ for (weights, update, expected_weights) in data.iter() {
+ let total_weight = weights.iter().sum::<u32>();
+ let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
+ assert_eq!(distr.total_weight, total_weight);
+
+ distr.update_weights(update).unwrap();
+ let expected_total_weight = expected_weights.iter().sum::<u32>();
+ let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
+ assert_eq!(distr.total_weight, expected_total_weight);
+ assert_eq!(distr.total_weight, expected_distr.total_weight);
+ assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
+ }
+ }
+
+ #[test]
+ fn value_stability() {
+ fn test_samples<X: SampleUniform + PartialOrd, I>(
+ weights: I, buf: &mut [usize], expected: &[usize],
+ ) where
+ I: IntoIterator,
+ I::Item: SampleBorrow<X>,
+ X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
+ {
+ assert_eq!(buf.len(), expected.len());
+ let distr = WeightedIndex::new(weights).unwrap();
+ let mut rng = crate::test::rng(701);
+ for r in buf.iter_mut() {
+ *r = rng.sample(&distr);
+ }
+ assert_eq!(buf, expected);
+ }
+
+ let mut buf = [0; 10];
+ test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
+ 0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
+ ]);
+ test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
+ 0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
+ ]);
+ test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
+ 2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
+ ]);
+ }
+
+ #[test]
+ fn weighted_index_distributions_can_be_compared() {
+ assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
+ }
+}
+
+/// Error type returned from `WeightedIndex::new`.
+#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum WeightedError {
+ /// The provided weight collection contains no items.
+ NoItem,
+
+ /// A weight is either less than zero, greater than the supported maximum,
+ /// NaN, or otherwise invalid.
+ InvalidWeight,
+
+ /// All items in the provided weight collection are zero.
+ AllWeightsZero,
+
+ /// Too many weights are provided (length greater than `u32::MAX`)
+ TooMany,
+}
+
+#[cfg(feature = "std")]
+impl std::error::Error for WeightedError {}
+
+impl fmt::Display for WeightedError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str(match *self {
+ WeightedError::NoItem => "No weights provided in distribution",
+ WeightedError::InvalidWeight => "A weight is invalid in distribution",
+ WeightedError::AllWeightsZero => "All weights are zero in distribution",
+ WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
+ })
+ }
+}