summaryrefslogtreecommitdiffstats
path: root/third_party/rust/rust_decimal/src/rand.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/rust_decimal/src/rand.rs')
-rw-r--r--third_party/rust/rust_decimal/src/rand.rs176
1 files changed, 176 insertions, 0 deletions
diff --git a/third_party/rust/rust_decimal/src/rand.rs b/third_party/rust/rust_decimal/src/rand.rs
new file mode 100644
index 0000000000..afd8a64361
--- /dev/null
+++ b/third_party/rust/rust_decimal/src/rand.rs
@@ -0,0 +1,176 @@
+use crate::Decimal;
+use rand::{
+ distributions::{
+ uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler},
+ Distribution, Standard,
+ },
+ Rng,
+};
+
+impl Distribution<Decimal> for Standard {
+ fn sample<R>(&self, rng: &mut R) -> Decimal
+ where
+ R: Rng + ?Sized,
+ {
+ Decimal::from_parts(
+ rng.next_u32(),
+ rng.next_u32(),
+ rng.next_u32(),
+ rng.gen(),
+ rng.next_u32(),
+ )
+ }
+}
+
+impl SampleUniform for Decimal {
+ type Sampler = DecimalSampler;
+}
+
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub struct DecimalSampler {
+ mantissa_sampler: UniformInt<i128>,
+ scale: u32,
+}
+
+impl UniformSampler for DecimalSampler {
+ type X = Decimal;
+
+ /// Creates a new sampler that will yield random decimal objects between `low` and `high`.
+ ///
+ /// The sampler will always provide decimals at the same scale as the inputs; if the inputs
+ /// have different scales, the higher scale is used.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # use rand::Rng;
+ /// # use rust_decimal_macros::dec;
+ /// let mut rng = rand::rngs::OsRng;
+ /// let random = rng.gen_range(dec!(1.00)..dec!(2.00));
+ /// assert!(random >= dec!(1.00));
+ /// assert!(random < dec!(2.00));
+ /// assert_eq!(random.scale(), 2);
+ /// ```
+ #[inline]
+ fn new<B1, B2>(low: B1, high: B2) -> Self
+ where
+ B1: SampleBorrow<Self::X> + Sized,
+ B2: SampleBorrow<Self::X> + Sized,
+ {
+ let (low, high) = sync_scales(*low.borrow(), *high.borrow());
+ let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale());
+ UniformSampler::new_inclusive(low, high)
+ }
+
+ /// Creates a new sampler that will yield random decimal objects between `low` and `high`.
+ ///
+ /// The sampler will always provide decimals at the same scale as the inputs; if the inputs
+ /// have different scales, the higher scale is used.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # use rand::Rng;
+ /// # use rust_decimal_macros::dec;
+ /// let mut rng = rand::rngs::OsRng;
+ /// let random = rng.gen_range(dec!(1.00)..=dec!(2.00));
+ /// assert!(random >= dec!(1.00));
+ /// assert!(random <= dec!(2.00));
+ /// assert_eq!(random.scale(), 2);
+ /// ```
+ #[inline]
+ fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
+ where
+ B1: SampleBorrow<Self::X> + Sized,
+ B2: SampleBorrow<Self::X> + Sized,
+ {
+ let (low, high) = sync_scales(*low.borrow(), *high.borrow());
+
+ // Return our sampler, which contains an underlying i128 sampler so we
+ // outsource the actual randomness implementation.
+ Self {
+ mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa()),
+ scale: low.scale(),
+ }
+ }
+
+ #[inline]
+ fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
+ let mantissa = self.mantissa_sampler.sample(rng);
+ Decimal::from_i128_with_scale(mantissa, self.scale)
+ }
+}
+
+/// Return equivalent Decimal objects with the same scale as one another.
+#[inline]
+fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) {
+ if a.scale() == b.scale() {
+ return (a, b);
+ }
+
+ // Set scales to match one another, because we are relying on mantissas'
+ // being comparable in order outsource the actual sampling implementation.
+ a.rescale(a.scale().max(b.scale()));
+ b.rescale(a.scale().max(b.scale()));
+
+ // Edge case: If the values have _wildly_ different scales, the values may not have rescaled far enough to match one another.
+ //
+ // In this case, we accept some precision loss because the randomization approach we are using assumes that the scales will necessarily match.
+ if a.scale() != b.scale() {
+ a.rescale(a.scale().min(b.scale()));
+ b.rescale(a.scale().min(b.scale()));
+ }
+
+ (a, b)
+}
+
+#[cfg(test)]
+mod tests {
+ use std::collections::HashSet;
+
+ use super::*;
+
+ macro_rules! dec {
+ ($e:expr) => {
+ Decimal::from_str_exact(stringify!($e)).unwrap()
+ };
+ }
+
+ #[test]
+ fn has_random_decimal_instances() {
+ let mut rng = rand::rngs::OsRng;
+ let random: [Decimal; 32] = rng.gen();
+ assert!(random.windows(2).any(|slice| { slice[0] != slice[1] }));
+ }
+
+ #[test]
+ fn generates_within_range() {
+ let mut rng = rand::rngs::OsRng;
+ for _ in 0..128 {
+ let random = rng.gen_range(dec!(1.00)..dec!(1.05));
+ assert!(random < dec!(1.05));
+ assert!(random >= dec!(1.00));
+ }
+ }
+
+ #[test]
+ fn generates_within_inclusive_range() {
+ let mut rng = rand::rngs::OsRng;
+ let mut values: HashSet<Decimal> = HashSet::new();
+ for _ in 0..256 {
+ let random = rng.gen_range(dec!(1.00)..=dec!(1.01));
+ // The scale is 2, so 1.00 and 1.01 are the only two valid choices.
+ assert!(random == dec!(1.00) || random == dec!(1.01));
+ values.insert(random);
+ }
+ // Somewhat flaky, will fail 1 out of every 2^255 times this is run.
+ // Probably acceptable in the real world.
+ assert_eq!(values.len(), 2);
+ }
+
+ #[test]
+ fn test_edge_case_scales_match() {
+ let (low, high) = sync_scales(dec!(1.000_000_000_000_000_000_01), dec!(100_000_000_000_000_000_001));
+ assert_eq!(low.scale(), high.scale());
+ }
+}