diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 17:32:43 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 17:32:43 +0000 |
commit | 6bf0a5cb5034a7e684dcc3500e841785237ce2dd (patch) | |
tree | a68f146d7fa01f0134297619fbe7e33db084e0aa /third_party/rust/rust_decimal/src | |
parent | Initial commit. (diff) | |
download | thunderbird-6bf0a5cb5034a7e684dcc3500e841785237ce2dd.tar.xz thunderbird-6bf0a5cb5034a7e684dcc3500e841785237ce2dd.zip |
Adding upstream version 1:115.7.0.upstream/1%115.7.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/rust_decimal/src')
25 files changed, 10418 insertions, 0 deletions
diff --git a/third_party/rust/rust_decimal/src/arithmetic_impls.rs b/third_party/rust/rust_decimal/src/arithmetic_impls.rs new file mode 100644 index 0000000000..81cfa305d2 --- /dev/null +++ b/third_party/rust/rust_decimal/src/arithmetic_impls.rs @@ -0,0 +1,329 @@ +// #[rustfmt::skip] is being used because `rustfmt` poorly formats `#[doc = concat!(..)]`. See +// https://github.com/rust-lang/rustfmt/issues/5062 for more information. + +use crate::{decimal::CalculationResult, ops, Decimal}; +use core::ops::{Add, Div, Mul, Rem, Sub}; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedRem, CheckedSub, Inv}; + +// Macros and `Decimal` implementations + +#[rustfmt::skip] +macro_rules! impl_checked { + ($long:literal, $short:literal, $fun:ident, $impl:ident) => { + #[doc = concat!( + "Checked ", + $long, + ". Computes `self ", + $short, + " other`, returning `None` if overflow occurred." + )] + #[inline(always)] + #[must_use] + pub fn $fun(self, other: Decimal) -> Option<Decimal> { + match ops::$impl(&self, &other) { + CalculationResult::Ok(result) => Some(result), + _ => None, + } + } + }; +} + +#[rustfmt::skip] +macro_rules! impl_saturating { + ($long:literal, $short:literal, $fun:ident, $impl:ident, $cmp:ident) => { + #[doc = concat!( + "Saturating ", + $long, + ". Computes `self ", + $short, + " other`, saturating at the relevant upper or lower boundary.", + )] + #[inline(always)] + #[must_use] + pub fn $fun(self, other: Decimal) -> Decimal { + if let Some(elem) = self.$impl(other) { + elem + } else { + $cmp(&self, &other) + } + } + }; +} + +macro_rules! impl_checked_and_saturating { + ( + $op_long:literal, + $op_short:literal, + $checked_fun:ident, + $checked_impl:ident, + + $saturating_fun:ident, + $saturating_cmp:ident + ) => { + impl_checked!($op_long, $op_short, $checked_fun, $checked_impl); + impl_saturating!( + $op_long, + $op_short, + $saturating_fun, + $checked_fun, + $saturating_cmp + ); + }; +} + +impl Decimal { + impl_checked_and_saturating!( + "addition", + "+", + checked_add, + add_impl, + saturating_add, + if_a_is_positive_then_max + ); + impl_checked_and_saturating!( + "multiplication", + "*", + checked_mul, + mul_impl, + saturating_mul, + if_xnor_then_max + ); + impl_checked_and_saturating!( + "subtraction", + "-", + checked_sub, + sub_impl, + saturating_sub, + if_a_is_positive_then_max + ); + + impl_checked!("division", "/", checked_div, div_impl); + impl_checked!("remainder", "%", checked_rem, rem_impl); +} + +// Macros and trait implementations + +macro_rules! forward_all_binop { + (impl $imp:ident for $res:ty, $method:ident) => { + forward_val_val_binop!(impl $imp for $res, $method); + forward_ref_val_binop!(impl $imp for $res, $method); + forward_val_ref_binop!(impl $imp for $res, $method); + }; +} + +macro_rules! forward_ref_val_binop { + (impl $imp:ident for $res:ty, $method:ident) => { + impl<'a> $imp<$res> for &'a $res { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + self.$method(&other) + } + } + }; +} + +macro_rules! forward_val_ref_binop { + (impl $imp:ident for $res:ty, $method:ident) => { + impl<'a> $imp<&'a $res> for $res { + type Output = $res; + + #[inline] + fn $method(self, other: &$res) -> $res { + (&self).$method(other) + } + } + }; +} + +macro_rules! forward_val_val_binop { + (impl $imp:ident for $res:ty, $method:ident) => { + impl $imp<$res> for $res { + type Output = $res; + + #[inline] + fn $method(self, other: $res) -> $res { + (&self).$method(&other) + } + } + }; +} + +forward_all_binop!(impl Add for Decimal, add); +impl<'a, 'b> Add<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + #[inline(always)] + fn add(self, other: &Decimal) -> Decimal { + match ops::add_impl(self, other) { + CalculationResult::Ok(sum) => sum, + _ => panic!("Addition overflowed"), + } + } +} + +impl CheckedAdd for Decimal { + #[inline] + fn checked_add(&self, v: &Decimal) -> Option<Decimal> { + Decimal::checked_add(*self, *v) + } +} + +impl CheckedSub for Decimal { + #[inline] + fn checked_sub(&self, v: &Decimal) -> Option<Decimal> { + Decimal::checked_sub(*self, *v) + } +} + +impl CheckedMul for Decimal { + #[inline] + fn checked_mul(&self, v: &Decimal) -> Option<Decimal> { + Decimal::checked_mul(*self, *v) + } +} + +impl CheckedDiv for Decimal { + #[inline] + fn checked_div(&self, v: &Decimal) -> Option<Decimal> { + Decimal::checked_div(*self, *v) + } +} + +impl CheckedRem for Decimal { + #[inline] + fn checked_rem(&self, v: &Decimal) -> Option<Decimal> { + Decimal::checked_rem(*self, *v) + } +} + +impl Inv for Decimal { + type Output = Self; + + #[inline] + fn inv(self) -> Self { + Decimal::ONE / self + } +} + +forward_all_binop!(impl Div for Decimal, div); +impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + #[inline] + fn div(self, other: &Decimal) -> Decimal { + match ops::div_impl(self, other) { + CalculationResult::Ok(quot) => quot, + CalculationResult::Overflow => panic!("Division overflowed"), + CalculationResult::DivByZero => panic!("Division by zero"), + } + } +} + +forward_all_binop!(impl Mul for Decimal, mul); +impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + #[inline] + fn mul(self, other: &Decimal) -> Decimal { + match ops::mul_impl(self, other) { + CalculationResult::Ok(prod) => prod, + _ => panic!("Multiplication overflowed"), + } + } +} + +forward_all_binop!(impl Rem for Decimal, rem); +impl<'a, 'b> Rem<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + #[inline] + fn rem(self, other: &Decimal) -> Decimal { + match ops::rem_impl(self, other) { + CalculationResult::Ok(rem) => rem, + CalculationResult::Overflow => panic!("Division overflowed"), + CalculationResult::DivByZero => panic!("Division by zero"), + } + } +} + +forward_all_binop!(impl Sub for Decimal, sub); +impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + #[inline(always)] + fn sub(self, other: &Decimal) -> Decimal { + match ops::sub_impl(self, other) { + CalculationResult::Ok(sum) => sum, + _ => panic!("Subtraction overflowed"), + } + } +} + +// This function signature is expected by `impl_saturating`, thus the reason of `_b`. +#[inline(always)] +const fn if_a_is_positive_then_max(a: &Decimal, _b: &Decimal) -> Decimal { + if a.is_sign_positive() { + Decimal::MAX + } else { + Decimal::MIN + } +} + +// Used by saturating multiplications. +// +// If the `a` and `b` combination represents a XNOR bit operation, returns MAX. Otherwise, +// returns MIN. +#[inline(always)] +const fn if_xnor_then_max(a: &Decimal, b: &Decimal) -> Decimal { + match (a.is_sign_positive(), b.is_sign_positive()) { + (true, true) => Decimal::MAX, + (true, false) => Decimal::MIN, + (false, true) => Decimal::MIN, + (false, false) => Decimal::MAX, + } +} + +#[cfg(test)] +mod tests { + use crate::Decimal; + + #[test] + fn checked_methods_have_correct_output() { + assert_eq!(Decimal::MAX.checked_add(Decimal::MAX), None); + assert_eq!(Decimal::MAX.checked_add(Decimal::MIN), Some(Decimal::ZERO)); + assert_eq!(Decimal::MAX.checked_div(Decimal::ZERO), None); + assert_eq!(Decimal::MAX.checked_mul(Decimal::MAX), None); + assert_eq!(Decimal::MAX.checked_mul(Decimal::MIN), None); + assert_eq!(Decimal::MAX.checked_rem(Decimal::ZERO), None); + assert_eq!(Decimal::MAX.checked_sub(Decimal::MAX), Some(Decimal::ZERO)); + assert_eq!(Decimal::MAX.checked_sub(Decimal::MIN), None); + + assert_eq!(Decimal::MIN.checked_add(Decimal::MAX), Some(Decimal::ZERO)); + assert_eq!(Decimal::MIN.checked_add(Decimal::MIN), None); + assert_eq!(Decimal::MIN.checked_div(Decimal::ZERO), None); + assert_eq!(Decimal::MIN.checked_mul(Decimal::MAX), None); + assert_eq!(Decimal::MIN.checked_mul(Decimal::MIN), None); + assert_eq!(Decimal::MIN.checked_rem(Decimal::ZERO), None); + assert_eq!(Decimal::MIN.checked_sub(Decimal::MAX), None); + assert_eq!(Decimal::MIN.checked_sub(Decimal::MIN), Some(Decimal::ZERO)); + } + + #[test] + fn saturated_methods_have_correct_output() { + assert_eq!(Decimal::MAX.saturating_add(Decimal::MAX), Decimal::MAX); + assert_eq!(Decimal::MAX.saturating_add(Decimal::MIN), Decimal::ZERO); + assert_eq!(Decimal::MAX.saturating_mul(Decimal::MAX), Decimal::MAX); + assert_eq!(Decimal::MAX.saturating_mul(Decimal::MIN), Decimal::MIN); + assert_eq!(Decimal::MAX.saturating_sub(Decimal::MAX), Decimal::ZERO); + assert_eq!(Decimal::MAX.saturating_sub(Decimal::MIN), Decimal::MAX); + + assert_eq!(Decimal::MIN.saturating_add(Decimal::MAX), Decimal::ZERO); + assert_eq!(Decimal::MIN.saturating_add(Decimal::MIN), Decimal::MIN); + assert_eq!(Decimal::MIN.saturating_mul(Decimal::MAX), Decimal::MIN); + assert_eq!(Decimal::MIN.saturating_mul(Decimal::MIN), Decimal::MAX); + assert_eq!(Decimal::MIN.saturating_sub(Decimal::MAX), Decimal::MIN); + assert_eq!(Decimal::MIN.saturating_sub(Decimal::MIN), Decimal::ZERO); + } +} diff --git a/third_party/rust/rust_decimal/src/constants.rs b/third_party/rust/rust_decimal/src/constants.rs new file mode 100644 index 0000000000..59f3366587 --- /dev/null +++ b/third_party/rust/rust_decimal/src/constants.rs @@ -0,0 +1,72 @@ +// Sign mask for the flags field. A value of zero in this bit indicates a +// positive Decimal value, and a value of one in this bit indicates a +// negative Decimal value. +pub const SIGN_MASK: u32 = 0x8000_0000; +pub const UNSIGN_MASK: u32 = 0x4FFF_FFFF; + +// Scale mask for the flags field. This byte in the flags field contains +// the power of 10 to divide the Decimal value by. The scale byte must +// contain a value between 0 and 28 inclusive. +pub const SCALE_MASK: u32 = 0x00FF_0000; +pub const U8_MASK: u32 = 0x0000_00FF; +pub const U32_MASK: u64 = u32::MAX as _; + +// Number of bits scale is shifted by. +pub const SCALE_SHIFT: u32 = 16; +// Number of bits sign is shifted by. +pub const SIGN_SHIFT: u32 = 31; + +// The maximum string buffer size used for serialization purposes. 31 is optimal, however we align +// to the byte boundary for simplicity. +pub const MAX_STR_BUFFER_SIZE: usize = 32; + +// The maximum supported precision +pub const MAX_PRECISION: u8 = 28; +#[cfg(not(feature = "legacy-ops"))] +// u8 to i32 is infallible, therefore, this cast will never overflow +pub const MAX_PRECISION_I32: i32 = MAX_PRECISION as _; +// u8 to u32 is infallible, therefore, this cast will never overflow +pub const MAX_PRECISION_U32: u32 = MAX_PRECISION as _; +// 79,228,162,514,264,337,593,543,950,335 +pub const MAX_I128_REPR: i128 = 0x0000_0000_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF; + +// Fast access for 10^n where n is 0-9 +pub const POWERS_10: [u32; 10] = [ + 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000, +]; +// Fast access for 10^n where n is 1-19 +pub const BIG_POWERS_10: [u64; 19] = [ + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, +]; + +#[cfg(not(feature = "legacy-ops"))] +// The maximum power of 10 that a 32 bit integer can store +pub const MAX_I32_SCALE: i32 = 9; +#[cfg(not(feature = "legacy-ops"))] +// The maximum power of 10 that a 64 bit integer can store +pub const MAX_I64_SCALE: u32 = 19; +#[cfg(not(feature = "legacy-ops"))] +pub const U32_MAX: u64 = u32::MAX as u64; + +// Determines potential overflow for 128 bit operations +pub const OVERFLOW_U96: u128 = 1u128 << 96; +pub const WILL_OVERFLOW_U64: u64 = u64::MAX / 10 - u8::MAX as u64; +pub const BYTES_TO_OVERFLOW_U64: usize = 18; // We can probably get away with less diff --git a/third_party/rust/rust_decimal/src/decimal.rs b/third_party/rust/rust_decimal/src/decimal.rs new file mode 100644 index 0000000000..e98f7b4e4e --- /dev/null +++ b/third_party/rust/rust_decimal/src/decimal.rs @@ -0,0 +1,2578 @@ +use crate::constants::{ + MAX_I128_REPR, MAX_PRECISION_U32, POWERS_10, SCALE_MASK, SCALE_SHIFT, SIGN_MASK, SIGN_SHIFT, U32_MASK, U8_MASK, + UNSIGN_MASK, +}; +use crate::ops; +use crate::Error; +use core::{ + cmp::{Ordering::Equal, *}, + fmt, + hash::{Hash, Hasher}, + iter::{Product, Sum}, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign}, + str::FromStr, +}; + +// Diesel configuration +#[cfg(feature = "diesel2")] +use diesel::deserialize::FromSqlRow; +#[cfg(feature = "diesel2")] +use diesel::expression::AsExpression; +#[cfg(any(feature = "diesel1", feature = "diesel2"))] +use diesel::sql_types::Numeric; + +#[allow(unused_imports)] // It's not actually dead code below, but the compiler thinks it is. +#[cfg(not(feature = "std"))] +use num_traits::float::FloatCore; +use num_traits::{FromPrimitive, Num, One, Signed, ToPrimitive, Zero}; +#[cfg(feature = "rkyv")] +use rkyv::{Archive, Deserialize, Serialize}; + +/// The smallest value that can be represented by this decimal type. +const MIN: Decimal = Decimal { + flags: 2_147_483_648, + lo: 4_294_967_295, + mid: 4_294_967_295, + hi: 4_294_967_295, +}; + +/// The largest value that can be represented by this decimal type. +const MAX: Decimal = Decimal { + flags: 0, + lo: 4_294_967_295, + mid: 4_294_967_295, + hi: 4_294_967_295, +}; + +const ZERO: Decimal = Decimal { + flags: 0, + lo: 0, + mid: 0, + hi: 0, +}; +const ONE: Decimal = Decimal { + flags: 0, + lo: 1, + mid: 0, + hi: 0, +}; +const TWO: Decimal = Decimal { + flags: 0, + lo: 2, + mid: 0, + hi: 0, +}; +const TEN: Decimal = Decimal { + flags: 0, + lo: 10, + mid: 0, + hi: 0, +}; +const ONE_HUNDRED: Decimal = Decimal { + flags: 0, + lo: 100, + mid: 0, + hi: 0, +}; +const ONE_THOUSAND: Decimal = Decimal { + flags: 0, + lo: 1000, + mid: 0, + hi: 0, +}; +const NEGATIVE_ONE: Decimal = Decimal { + flags: 2147483648, + lo: 1, + mid: 0, + hi: 0, +}; + +/// `UnpackedDecimal` contains unpacked representation of `Decimal` where each component +/// of decimal-format stored in it's own field +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct UnpackedDecimal { + pub negative: bool, + pub scale: u32, + pub hi: u32, + pub mid: u32, + pub lo: u32, +} + +/// `Decimal` represents a 128 bit representation of a fixed-precision decimal number. +/// The finite set of values of type `Decimal` are of the form m / 10<sup>e</sup>, +/// where m is an integer such that -2<sup>96</sup> < m < 2<sup>96</sup>, and e is an integer +/// between 0 and 28 inclusive. +#[derive(Clone, Copy)] +#[cfg_attr( + all(feature = "diesel1", not(feature = "diesel2")), + derive(FromSqlRow, AsExpression), + sql_type = "Numeric" +)] +#[cfg_attr(feature = "diesel2", derive(FromSqlRow, AsExpression), diesel(sql_type = Numeric))] +#[cfg_attr(feature = "c-repr", repr(C))] +#[cfg_attr( + feature = "borsh", + derive(borsh::BorshDeserialize, borsh::BorshSerialize, borsh::BorshSchema) +)] +#[cfg_attr( + feature = "rkyv", + derive(Archive, Deserialize, Serialize), + archive(compare(PartialEq)), + archive_attr(derive(Clone, Copy, Debug)) +)] +#[cfg_attr(feature = "rkyv-safe", archive_attr(derive(bytecheck::CheckBytes)))] +pub struct Decimal { + // Bits 0-15: unused + // Bits 16-23: Contains "e", a value between 0-28 that indicates the scale + // Bits 24-30: unused + // Bit 31: the sign of the Decimal value, 0 meaning positive and 1 meaning negative. + flags: u32, + // The lo, mid, hi, and flags fields contain the representation of the + // Decimal value as a 96-bit integer. + hi: u32, + lo: u32, + mid: u32, +} + +/// `RoundingStrategy` represents the different rounding strategies that can be used by +/// `round_dp_with_strategy`. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum RoundingStrategy { + /// When a number is halfway between two others, it is rounded toward the nearest even number. + /// Also known as "Bankers Rounding". + /// e.g. + /// 6.5 -> 6, 7.5 -> 8 + MidpointNearestEven, + /// When a number is halfway between two others, it is rounded toward the nearest number that + /// is away from zero. e.g. 6.4 -> 6, 6.5 -> 7, -6.5 -> -7 + MidpointAwayFromZero, + /// When a number is halfway between two others, it is rounded toward the nearest number that + /// is toward zero. e.g. 6.4 -> 6, 6.5 -> 6, -6.5 -> -6 + MidpointTowardZero, + /// The number is always rounded toward zero. e.g. -6.8 -> -6, 6.8 -> 6 + ToZero, + /// The number is always rounded away from zero. e.g. -6.8 -> -7, 6.8 -> 7 + AwayFromZero, + /// The number is always rounded towards negative infinity. e.g. 6.8 -> 6, -6.8 -> -7 + ToNegativeInfinity, + /// The number is always rounded towards positive infinity. e.g. 6.8 -> 7, -6.8 -> -6 + ToPositiveInfinity, + + /// When a number is halfway between two others, it is rounded toward the nearest even number. + /// e.g. + /// 6.5 -> 6, 7.5 -> 8 + #[deprecated(since = "1.11.0", note = "Please use RoundingStrategy::MidpointNearestEven instead")] + BankersRounding, + /// Rounds up if the value >= 5, otherwise rounds down, e.g. 6.5 -> 7 + #[deprecated(since = "1.11.0", note = "Please use RoundingStrategy::MidpointAwayFromZero instead")] + RoundHalfUp, + /// Rounds down if the value =< 5, otherwise rounds up, e.g. 6.5 -> 6, 6.51 -> 7 1.4999999 -> 1 + #[deprecated(since = "1.11.0", note = "Please use RoundingStrategy::MidpointTowardZero instead")] + RoundHalfDown, + /// Always round down. + #[deprecated(since = "1.11.0", note = "Please use RoundingStrategy::ToZero instead")] + RoundDown, + /// Always round up. + #[deprecated(since = "1.11.0", note = "Please use RoundingStrategy::AwayFromZero instead")] + RoundUp, +} + +#[allow(dead_code)] +impl Decimal { + /// The smallest value that can be represented by this decimal type. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::MIN, dec!(-79_228_162_514_264_337_593_543_950_335)); + /// ``` + pub const MIN: Decimal = MIN; + /// The largest value that can be represented by this decimal type. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::MAX, dec!(79_228_162_514_264_337_593_543_950_335)); + /// ``` + pub const MAX: Decimal = MAX; + /// A constant representing 0. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::ZERO, dec!(0)); + /// ``` + pub const ZERO: Decimal = ZERO; + /// A constant representing 1. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::ONE, dec!(1)); + /// ``` + pub const ONE: Decimal = ONE; + /// A constant representing -1. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::NEGATIVE_ONE, dec!(-1)); + /// ``` + pub const NEGATIVE_ONE: Decimal = NEGATIVE_ONE; + /// A constant representing 2. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::TWO, dec!(2)); + /// ``` + pub const TWO: Decimal = TWO; + /// A constant representing 10. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::TEN, dec!(10)); + /// ``` + pub const TEN: Decimal = TEN; + /// A constant representing 100. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::ONE_HUNDRED, dec!(100)); + /// ``` + pub const ONE_HUNDRED: Decimal = ONE_HUNDRED; + /// A constant representing 1000. + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::ONE_THOUSAND, dec!(1000)); + /// ``` + pub const ONE_THOUSAND: Decimal = ONE_THOUSAND; + + /// A constant representing π as 3.1415926535897932384626433833 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::PI, dec!(3.1415926535897932384626433833)); + /// ``` + #[cfg(feature = "maths")] + pub const PI: Decimal = Decimal { + flags: 1835008, + lo: 1102470953, + mid: 185874565, + hi: 1703060790, + }; + /// A constant representing π/2 as 1.5707963267948966192313216916 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::HALF_PI, dec!(1.5707963267948966192313216916)); + /// ``` + #[cfg(feature = "maths")] + pub const HALF_PI: Decimal = Decimal { + flags: 1835008, + lo: 2698719124, + mid: 92937282, + hi: 851530395, + }; + /// A constant representing π/4 as 0.7853981633974483096156608458 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::QUARTER_PI, dec!(0.7853981633974483096156608458)); + /// ``` + #[cfg(feature = "maths")] + pub const QUARTER_PI: Decimal = Decimal { + flags: 1835008, + lo: 1349359562, + mid: 2193952289, + hi: 425765197, + }; + /// A constant representing 2π as 6.2831853071795864769252867666 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::TWO_PI, dec!(6.2831853071795864769252867666)); + /// ``` + #[cfg(feature = "maths")] + pub const TWO_PI: Decimal = Decimal { + flags: 1835008, + lo: 2204941906, + mid: 371749130, + hi: 3406121580, + }; + /// A constant representing Euler's number (e) as 2.7182818284590452353602874714 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::E, dec!(2.7182818284590452353602874714)); + /// ``` + #[cfg(feature = "maths")] + pub const E: Decimal = Decimal { + flags: 1835008, + lo: 2239425882, + mid: 3958169141, + hi: 1473583531, + }; + /// A constant representing the inverse of Euler's number (1/e) as 0.3678794411714423215955237702 + /// + /// # Examples + /// + /// Basic usage: + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// assert_eq!(Decimal::E_INVERSE, dec!(0.3678794411714423215955237702)); + /// ``` + #[cfg(feature = "maths")] + pub const E_INVERSE: Decimal = Decimal { + flags: 1835008, + lo: 2384059206, + mid: 2857938002, + hi: 199427844, + }; + + /// Returns a `Decimal` with a 64 bit `m` representation and corresponding `e` scale. + /// + /// # Arguments + /// + /// * `num` - An i64 that represents the `m` portion of the decimal number + /// * `scale` - A u32 representing the `e` portion of the decimal number. + /// + /// # Panics + /// + /// This function panics if `scale` is > 28. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let pi = Decimal::new(3141, 3); + /// assert_eq!(pi.to_string(), "3.141"); + /// ``` + #[must_use] + pub fn new(num: i64, scale: u32) -> Decimal { + match Self::try_new(num, scale) { + Err(e) => panic!("{}", e), + Ok(d) => d, + } + } + + /// Checked version of `Decimal::new`. Will return `Err` instead of panicking at run-time. + /// + /// # Example + /// + /// ```rust + /// # use rust_decimal::Decimal; + /// # + /// let max = Decimal::try_new(i64::MAX, u32::MAX); + /// assert!(max.is_err()); + /// ``` + pub const fn try_new(num: i64, scale: u32) -> crate::Result<Decimal> { + if scale > MAX_PRECISION_U32 { + return Err(Error::ScaleExceedsMaximumPrecision(scale)); + } + let flags: u32 = scale << SCALE_SHIFT; + if num < 0 { + let pos_num = num.wrapping_neg() as u64; + return Ok(Decimal { + flags: flags | SIGN_MASK, + hi: 0, + lo: (pos_num & U32_MASK) as u32, + mid: ((pos_num >> 32) & U32_MASK) as u32, + }); + } + Ok(Decimal { + flags, + hi: 0, + lo: (num as u64 & U32_MASK) as u32, + mid: ((num as u64 >> 32) & U32_MASK) as u32, + }) + } + + /// Creates a `Decimal` using a 128 bit signed `m` representation and corresponding `e` scale. + /// + /// # Arguments + /// + /// * `num` - An i128 that represents the `m` portion of the decimal number + /// * `scale` - A u32 representing the `e` portion of the decimal number. + /// + /// # Panics + /// + /// This function panics if `scale` is > 28 or if `num` exceeds the maximum supported 96 bits. + /// + /// # Example + /// + /// ```rust + /// # use rust_decimal::Decimal; + /// # + /// let pi = Decimal::from_i128_with_scale(3141i128, 3); + /// assert_eq!(pi.to_string(), "3.141"); + /// ``` + #[must_use] + pub fn from_i128_with_scale(num: i128, scale: u32) -> Decimal { + match Self::try_from_i128_with_scale(num, scale) { + Ok(d) => d, + Err(e) => panic!("{}", e), + } + } + + /// Checked version of `Decimal::from_i128_with_scale`. Will return `Err` instead + /// of panicking at run-time. + /// + /// # Example + /// + /// ```rust + /// # use rust_decimal::Decimal; + /// # + /// let max = Decimal::try_from_i128_with_scale(i128::MAX, u32::MAX); + /// assert!(max.is_err()); + /// ``` + pub const fn try_from_i128_with_scale(num: i128, scale: u32) -> crate::Result<Decimal> { + if scale > MAX_PRECISION_U32 { + return Err(Error::ScaleExceedsMaximumPrecision(scale)); + } + let mut neg = false; + let mut wrapped = num; + if num > MAX_I128_REPR { + return Err(Error::ExceedsMaximumPossibleValue); + } else if num < -MAX_I128_REPR { + return Err(Error::LessThanMinimumPossibleValue); + } else if num < 0 { + neg = true; + wrapped = -num; + } + let flags: u32 = flags(neg, scale); + Ok(Decimal { + flags, + lo: (wrapped as u64 & U32_MASK) as u32, + mid: ((wrapped as u64 >> 32) & U32_MASK) as u32, + hi: ((wrapped as u128 >> 64) as u64 & U32_MASK) as u32, + }) + } + + /// Returns a `Decimal` using the instances constituent parts. + /// + /// # Arguments + /// + /// * `lo` - The low 32 bits of a 96-bit integer. + /// * `mid` - The middle 32 bits of a 96-bit integer. + /// * `hi` - The high 32 bits of a 96-bit integer. + /// * `negative` - `true` to indicate a negative number. + /// * `scale` - A power of 10 ranging from 0 to 28. + /// + /// # Caution: Undefined behavior + /// + /// While a scale greater than 28 can be passed in, it will be automatically capped by this + /// function at the maximum precision. The library opts towards this functionality as opposed + /// to a panic to ensure that the function can be treated as constant. This may lead to + /// undefined behavior in downstream applications and should be treated with caution. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let pi = Decimal::from_parts(1102470952, 185874565, 1703060790, false, 28); + /// assert_eq!(pi.to_string(), "3.1415926535897932384626433832"); + /// ``` + #[must_use] + pub const fn from_parts(lo: u32, mid: u32, hi: u32, negative: bool, scale: u32) -> Decimal { + Decimal { + lo, + mid, + hi, + flags: flags( + if lo == 0 && mid == 0 && hi == 0 { + false + } else { + negative + }, + scale % (MAX_PRECISION_U32 + 1), + ), + } + } + + #[must_use] + pub(crate) const fn from_parts_raw(lo: u32, mid: u32, hi: u32, flags: u32) -> Decimal { + if lo == 0 && mid == 0 && hi == 0 { + Decimal { + lo, + mid, + hi, + flags: flags & SCALE_MASK, + } + } else { + Decimal { flags, hi, lo, mid } + } + } + + /// Returns a `Result` which if successful contains the `Decimal` constitution of + /// the scientific notation provided by `value`. + /// + /// # Arguments + /// + /// * `value` - The scientific notation of the `Decimal`. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// # fn main() -> Result<(), rust_decimal::Error> { + /// let value = Decimal::from_scientific("9.7e-7")?; + /// assert_eq!(value.to_string(), "0.00000097"); + /// # Ok(()) + /// # } + /// ``` + pub fn from_scientific(value: &str) -> Result<Decimal, Error> { + const ERROR_MESSAGE: &str = "Failed to parse"; + + let mut split = value.splitn(2, |c| c == 'e' || c == 'E'); + + let base = split.next().ok_or_else(|| Error::from(ERROR_MESSAGE))?; + let exp = split.next().ok_or_else(|| Error::from(ERROR_MESSAGE))?; + + let mut ret = Decimal::from_str(base)?; + let current_scale = ret.scale(); + + if let Some(stripped) = exp.strip_prefix('-') { + let exp: u32 = stripped.parse().map_err(|_| Error::from(ERROR_MESSAGE))?; + ret.set_scale(current_scale + exp)?; + } else { + let exp: u32 = exp.parse().map_err(|_| Error::from(ERROR_MESSAGE))?; + if exp <= current_scale { + ret.set_scale(current_scale - exp)?; + } else if exp > 0 { + use crate::constants::BIG_POWERS_10; + + // This is a case whereby the mantissa needs to be larger to be correctly + // represented within the decimal type. A good example is 1.2E10. At this point, + // we've parsed 1.2 as the base and 10 as the exponent. To represent this within a + // Decimal type we effectively store the mantissa as 12,000,000,000 and scale as + // zero. + if exp > MAX_PRECISION_U32 { + return Err(Error::ScaleExceedsMaximumPrecision(exp)); + } + let mut exp = exp as usize; + // Max two iterations. If exp is 1 then it needs to index position 0 of the array. + while exp > 0 { + let pow; + if exp >= BIG_POWERS_10.len() { + pow = BIG_POWERS_10[BIG_POWERS_10.len() - 1]; + exp -= BIG_POWERS_10.len(); + } else { + pow = BIG_POWERS_10[exp - 1]; + exp = 0; + } + + let pow = Decimal { + flags: 0, + lo: pow as u32, + mid: (pow >> 32) as u32, + hi: 0, + }; + match ret.checked_mul(pow) { + Some(r) => ret = r, + None => return Err(Error::ExceedsMaximumPossibleValue), + }; + } + ret.normalize_assign(); + } + } + Ok(ret) + } + + /// Converts a string slice in a given base to a decimal. + /// + /// The string is expected to be an optional + sign followed by digits. + /// Digits are a subset of these characters, depending on radix, and will return an error if outside + /// the expected range: + /// + /// * 0-9 + /// * a-z + /// * A-Z + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// # fn main() -> Result<(), rust_decimal::Error> { + /// assert_eq!(Decimal::from_str_radix("A", 16)?.to_string(), "10"); + /// # Ok(()) + /// # } + /// ``` + pub fn from_str_radix(str: &str, radix: u32) -> Result<Self, crate::Error> { + if radix == 10 { + crate::str::parse_str_radix_10(str) + } else { + crate::str::parse_str_radix_n(str, radix) + } + } + + /// Parses a string slice into a decimal. If the value underflows and cannot be represented with the + /// given scale then this will return an error. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # use rust_decimal::Error; + /// # + /// # fn main() -> Result<(), rust_decimal::Error> { + /// assert_eq!(Decimal::from_str_exact("0.001")?.to_string(), "0.001"); + /// assert_eq!(Decimal::from_str_exact("0.00000_00000_00000_00000_00000_001")?.to_string(), "0.0000000000000000000000000001"); + /// assert_eq!(Decimal::from_str_exact("0.00000_00000_00000_00000_00000_0001"), Err(Error::Underflow)); + /// # Ok(()) + /// # } + /// ``` + pub fn from_str_exact(str: &str) -> Result<Self, crate::Error> { + crate::str::parse_str_radix_10_exact(str) + } + + /// Returns the scale of the decimal number, otherwise known as `e`. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let num = Decimal::new(1234, 3); + /// assert_eq!(num.scale(), 3u32); + /// ``` + #[inline] + #[must_use] + pub const fn scale(&self) -> u32 { + ((self.flags & SCALE_MASK) >> SCALE_SHIFT) as u32 + } + + /// Returns the mantissa of the decimal number. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// use rust_decimal_macros::dec; + /// + /// let num = dec!(-1.2345678); + /// assert_eq!(num.mantissa(), -12345678i128); + /// assert_eq!(num.scale(), 7); + /// ``` + #[must_use] + pub const fn mantissa(&self) -> i128 { + let raw = (self.lo as i128) | ((self.mid as i128) << 32) | ((self.hi as i128) << 64); + if self.is_sign_negative() { + -raw + } else { + raw + } + } + + /// Returns true if this Decimal number is equivalent to zero. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// let num = Decimal::ZERO; + /// assert!(num.is_zero()); + /// ``` + #[must_use] + pub const fn is_zero(&self) -> bool { + self.lo == 0 && self.mid == 0 && self.hi == 0 + } + + /// An optimized method for changing the sign of a decimal number. + /// + /// # Arguments + /// + /// * `positive`: true if the resulting decimal should be positive. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let mut one = Decimal::ONE; + /// one.set_sign(false); + /// assert_eq!(one.to_string(), "-1"); + /// ``` + #[deprecated(since = "1.4.0", note = "please use `set_sign_positive` instead")] + pub fn set_sign(&mut self, positive: bool) { + self.set_sign_positive(positive); + } + + /// An optimized method for changing the sign of a decimal number. + /// + /// # Arguments + /// + /// * `positive`: true if the resulting decimal should be positive. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let mut one = Decimal::ONE; + /// one.set_sign_positive(false); + /// assert_eq!(one.to_string(), "-1"); + /// ``` + #[inline(always)] + pub fn set_sign_positive(&mut self, positive: bool) { + if positive { + self.flags &= UNSIGN_MASK; + } else { + self.flags |= SIGN_MASK; + } + } + + /// An optimized method for changing the sign of a decimal number. + /// + /// # Arguments + /// + /// * `negative`: true if the resulting decimal should be negative. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let mut one = Decimal::ONE; + /// one.set_sign_negative(true); + /// assert_eq!(one.to_string(), "-1"); + /// ``` + #[inline(always)] + pub fn set_sign_negative(&mut self, negative: bool) { + self.set_sign_positive(!negative); + } + + /// An optimized method for changing the scale of a decimal number. + /// + /// # Arguments + /// + /// * `scale`: the new scale of the number + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// # fn main() -> Result<(), rust_decimal::Error> { + /// let mut one = Decimal::ONE; + /// one.set_scale(5)?; + /// assert_eq!(one.to_string(), "0.00001"); + /// # Ok(()) + /// # } + /// ``` + pub fn set_scale(&mut self, scale: u32) -> Result<(), Error> { + if scale > MAX_PRECISION_U32 { + return Err(Error::ScaleExceedsMaximumPrecision(scale)); + } + self.flags = (scale << SCALE_SHIFT) | (self.flags & SIGN_MASK); + Ok(()) + } + + /// Modifies the `Decimal` towards the desired scale, attempting to do so without changing the + /// underlying number itself. + /// + /// Setting the scale to something less then the current `Decimal`s scale will + /// cause the newly created `Decimal` to perform rounding using the `MidpointAwayFromZero` strategy. + /// + /// Scales greater than the maximum precision that can be represented by `Decimal` will be + /// automatically rounded to either `Decimal::MAX_PRECISION` or the maximum precision that can + /// be represented with the given mantissa. + /// + /// # Arguments + /// * `scale`: The desired scale to use for the new `Decimal` number. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// use rust_decimal_macros::dec; + /// + /// // Rescaling to a higher scale preserves the value + /// let mut number = dec!(1.123); + /// assert_eq!(number.scale(), 3); + /// number.rescale(6); + /// assert_eq!(number.to_string(), "1.123000"); + /// assert_eq!(number.scale(), 6); + /// + /// // Rescaling to a lower scale forces the number to be rounded + /// let mut number = dec!(1.45); + /// assert_eq!(number.scale(), 2); + /// number.rescale(1); + /// assert_eq!(number.to_string(), "1.5"); + /// assert_eq!(number.scale(), 1); + /// + /// // This function never fails. Consequently, if a scale is provided that is unable to be + /// // represented using the given mantissa, then the maximum possible scale is used. + /// let mut number = dec!(11.76470588235294); + /// assert_eq!(number.scale(), 14); + /// number.rescale(28); + /// // A scale of 28 cannot be represented given this mantissa, however it was able to represent + /// // a number with a scale of 27 + /// assert_eq!(number.to_string(), "11.764705882352940000000000000"); + /// assert_eq!(number.scale(), 27); + /// ``` + pub fn rescale(&mut self, scale: u32) { + let mut array = [self.lo, self.mid, self.hi]; + let mut value_scale = self.scale(); + ops::array::rescale_internal(&mut array, &mut value_scale, scale); + self.lo = array[0]; + self.mid = array[1]; + self.hi = array[2]; + self.flags = flags(self.is_sign_negative(), value_scale); + } + + /// Returns a serialized version of the decimal number. + /// The resulting byte array will have the following representation: + /// + /// * Bytes 1-4: flags + /// * Bytes 5-8: lo portion of `m` + /// * Bytes 9-12: mid portion of `m` + /// * Bytes 13-16: high portion of `m` + #[must_use] + pub const fn serialize(&self) -> [u8; 16] { + [ + (self.flags & U8_MASK) as u8, + ((self.flags >> 8) & U8_MASK) as u8, + ((self.flags >> 16) & U8_MASK) as u8, + ((self.flags >> 24) & U8_MASK) as u8, + (self.lo & U8_MASK) as u8, + ((self.lo >> 8) & U8_MASK) as u8, + ((self.lo >> 16) & U8_MASK) as u8, + ((self.lo >> 24) & U8_MASK) as u8, + (self.mid & U8_MASK) as u8, + ((self.mid >> 8) & U8_MASK) as u8, + ((self.mid >> 16) & U8_MASK) as u8, + ((self.mid >> 24) & U8_MASK) as u8, + (self.hi & U8_MASK) as u8, + ((self.hi >> 8) & U8_MASK) as u8, + ((self.hi >> 16) & U8_MASK) as u8, + ((self.hi >> 24) & U8_MASK) as u8, + ] + } + + /// Deserializes the given bytes into a decimal number. + /// The deserialized byte representation must be 16 bytes and adhere to the following convention: + /// + /// * Bytes 1-4: flags + /// * Bytes 5-8: lo portion of `m` + /// * Bytes 9-12: mid portion of `m` + /// * Bytes 13-16: high portion of `m` + #[must_use] + pub fn deserialize(bytes: [u8; 16]) -> Decimal { + // We can bound flags by a bitwise mask to correspond to: + // Bits 0-15: unused + // Bits 16-23: Contains "e", a value between 0-28 that indicates the scale + // Bits 24-30: unused + // Bit 31: the sign of the Decimal value, 0 meaning positive and 1 meaning negative. + let mut raw = Decimal { + flags: ((bytes[0] as u32) | (bytes[1] as u32) << 8 | (bytes[2] as u32) << 16 | (bytes[3] as u32) << 24) + & 0x801F_0000, + lo: (bytes[4] as u32) | (bytes[5] as u32) << 8 | (bytes[6] as u32) << 16 | (bytes[7] as u32) << 24, + mid: (bytes[8] as u32) | (bytes[9] as u32) << 8 | (bytes[10] as u32) << 16 | (bytes[11] as u32) << 24, + hi: (bytes[12] as u32) | (bytes[13] as u32) << 8 | (bytes[14] as u32) << 16 | (bytes[15] as u32) << 24, + }; + // Scale must be bound to maximum precision. Only two values can be greater than this + if raw.scale() > MAX_PRECISION_U32 { + let mut bits = raw.mantissa_array3(); + let remainder = match raw.scale() { + 29 => crate::ops::array::div_by_1x(&mut bits, 1), + 30 => crate::ops::array::div_by_1x(&mut bits, 2), + 31 => crate::ops::array::div_by_1x(&mut bits, 3), + _ => 0, + }; + if remainder >= 5 { + ops::array::add_one_internal(&mut bits); + } + raw.lo = bits[0]; + raw.mid = bits[1]; + raw.hi = bits[2]; + raw.flags = flags(raw.is_sign_negative(), MAX_PRECISION_U32); + } + raw + } + + /// Returns `true` if the decimal is negative. + #[deprecated(since = "0.6.3", note = "please use `is_sign_negative` instead")] + #[must_use] + pub fn is_negative(&self) -> bool { + self.is_sign_negative() + } + + /// Returns `true` if the decimal is positive. + #[deprecated(since = "0.6.3", note = "please use `is_sign_positive` instead")] + #[must_use] + pub fn is_positive(&self) -> bool { + self.is_sign_positive() + } + + /// Returns `true` if the sign bit of the decimal is negative. + /// + /// # Example + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// assert_eq!(true, Decimal::new(-1, 0).is_sign_negative()); + /// assert_eq!(false, Decimal::new(1, 0).is_sign_negative()); + /// ``` + #[inline(always)] + #[must_use] + pub const fn is_sign_negative(&self) -> bool { + self.flags & SIGN_MASK > 0 + } + + /// Returns `true` if the sign bit of the decimal is positive. + /// + /// # Example + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// assert_eq!(false, Decimal::new(-1, 0).is_sign_positive()); + /// assert_eq!(true, Decimal::new(1, 0).is_sign_positive()); + /// ``` + #[inline(always)] + #[must_use] + pub const fn is_sign_positive(&self) -> bool { + self.flags & SIGN_MASK == 0 + } + + /// Returns the minimum possible number that `Decimal` can represent. + #[deprecated(since = "1.12.0", note = "Use the associated constant Decimal::MIN")] + #[must_use] + pub const fn min_value() -> Decimal { + MIN + } + + /// Returns the maximum possible number that `Decimal` can represent. + #[deprecated(since = "1.12.0", note = "Use the associated constant Decimal::MAX")] + #[must_use] + pub const fn max_value() -> Decimal { + MAX + } + + /// Returns a new `Decimal` integral with no fractional portion. + /// This is a true truncation whereby no rounding is performed. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let pi = Decimal::new(3141, 3); + /// let trunc = Decimal::new(3, 0); + /// // note that it returns a decimal + /// assert_eq!(pi.trunc(), trunc); + /// ``` + #[must_use] + pub fn trunc(&self) -> Decimal { + let mut scale = self.scale(); + if scale == 0 { + // Nothing to do + return *self; + } + let mut working = [self.lo, self.mid, self.hi]; + while scale > 0 { + // We're removing precision, so we don't care about overflow + if scale < 10 { + ops::array::div_by_u32(&mut working, POWERS_10[scale as usize]); + break; + } else { + ops::array::div_by_u32(&mut working, POWERS_10[9]); + // Only 9 as this array starts with 1 + scale -= 9; + } + } + Decimal { + lo: working[0], + mid: working[1], + hi: working[2], + flags: flags(self.is_sign_negative(), 0), + } + } + + /// Returns a new `Decimal` representing the fractional portion of the number. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let pi = Decimal::new(3141, 3); + /// let fract = Decimal::new(141, 3); + /// // note that it returns a decimal + /// assert_eq!(pi.fract(), fract); + /// ``` + #[must_use] + pub fn fract(&self) -> Decimal { + // This is essentially the original number minus the integral. + // Could possibly be optimized in the future + *self - self.trunc() + } + + /// Computes the absolute value of `self`. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let num = Decimal::new(-3141, 3); + /// assert_eq!(num.abs().to_string(), "3.141"); + /// ``` + #[must_use] + pub fn abs(&self) -> Decimal { + let mut me = *self; + me.set_sign_positive(true); + me + } + + /// Returns the largest integer less than or equal to a number. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let num = Decimal::new(3641, 3); + /// assert_eq!(num.floor().to_string(), "3"); + /// ``` + #[must_use] + pub fn floor(&self) -> Decimal { + let scale = self.scale(); + if scale == 0 { + // Nothing to do + return *self; + } + + // Opportunity for optimization here + let floored = self.trunc(); + if self.is_sign_negative() && !self.fract().is_zero() { + floored - ONE + } else { + floored + } + } + + /// Returns the smallest integer greater than or equal to a number. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let num = Decimal::new(3141, 3); + /// assert_eq!(num.ceil().to_string(), "4"); + /// let num = Decimal::new(3, 0); + /// assert_eq!(num.ceil().to_string(), "3"); + /// ``` + #[must_use] + pub fn ceil(&self) -> Decimal { + let scale = self.scale(); + if scale == 0 { + // Nothing to do + return *self; + } + + // Opportunity for optimization here + if self.is_sign_positive() && !self.fract().is_zero() { + self.trunc() + ONE + } else { + self.trunc() + } + } + + /// Returns the maximum of the two numbers. + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let x = Decimal::new(1, 0); + /// let y = Decimal::new(2, 0); + /// assert_eq!(y, x.max(y)); + /// ``` + #[must_use] + pub fn max(self, other: Decimal) -> Decimal { + if self < other { + other + } else { + self + } + } + + /// Returns the minimum of the two numbers. + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// let x = Decimal::new(1, 0); + /// let y = Decimal::new(2, 0); + /// assert_eq!(x, x.min(y)); + /// ``` + #[must_use] + pub fn min(self, other: Decimal) -> Decimal { + if self > other { + other + } else { + self + } + } + + /// Strips any trailing zero's from a `Decimal` and converts -0 to 0. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # fn main() -> Result<(), rust_decimal::Error> { + /// let number = Decimal::from_str("3.100")?; + /// assert_eq!(number.normalize().to_string(), "3.1"); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn normalize(&self) -> Decimal { + let mut result = *self; + result.normalize_assign(); + result + } + + /// An in place version of `normalize`. Strips any trailing zero's from a `Decimal` and converts -0 to 0. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # fn main() -> Result<(), rust_decimal::Error> { + /// let mut number = Decimal::from_str("3.100")?; + /// assert_eq!(number.to_string(), "3.100"); + /// number.normalize_assign(); + /// assert_eq!(number.to_string(), "3.1"); + /// # Ok(()) + /// # } + /// ``` + pub fn normalize_assign(&mut self) { + if self.is_zero() { + self.flags = 0; + return; + } + + let mut scale = self.scale(); + if scale == 0 { + return; + } + + let mut result = self.mantissa_array3(); + let mut working = self.mantissa_array3(); + while scale > 0 { + if ops::array::div_by_u32(&mut working, 10) > 0 { + break; + } + scale -= 1; + result.copy_from_slice(&working); + } + self.lo = result[0]; + self.mid = result[1]; + self.hi = result[2]; + self.flags = flags(self.is_sign_negative(), scale); + } + + /// Returns a new `Decimal` number with no fractional portion (i.e. an integer). + /// Rounding currently follows "Bankers Rounding" rules. e.g. 6.5 -> 6, 7.5 -> 8 + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # + /// // Demonstrating bankers rounding... + /// let number_down = Decimal::new(65, 1); + /// let number_up = Decimal::new(75, 1); + /// assert_eq!(number_down.round().to_string(), "6"); + /// assert_eq!(number_up.round().to_string(), "8"); + /// ``` + #[must_use] + pub fn round(&self) -> Decimal { + self.round_dp(0) + } + + /// Returns a new `Decimal` number with the specified number of decimal points for fractional + /// portion. + /// Rounding is performed using the provided [`RoundingStrategy`] + /// + /// # Arguments + /// * `dp`: the number of decimal points to round to. + /// * `strategy`: the [`RoundingStrategy`] to use. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::{Decimal, RoundingStrategy}; + /// # use rust_decimal_macros::dec; + /// # + /// let tax = dec!(3.4395); + /// assert_eq!(tax.round_dp_with_strategy(2, RoundingStrategy::MidpointAwayFromZero).to_string(), "3.44"); + /// ``` + #[must_use] + pub fn round_dp_with_strategy(&self, dp: u32, strategy: RoundingStrategy) -> Decimal { + // Short circuit for zero + if self.is_zero() { + return Decimal { + lo: 0, + mid: 0, + hi: 0, + flags: flags(self.is_sign_negative(), dp), + }; + } + + let old_scale = self.scale(); + + // return early if decimal has a smaller number of fractional places than dp + // e.g. 2.51 rounded to 3 decimal places is 2.51 + if old_scale <= dp { + return *self; + } + + let mut value = [self.lo, self.mid, self.hi]; + let mut value_scale = self.scale(); + let negative = self.is_sign_negative(); + + value_scale -= dp; + + // Rescale to zero so it's easier to work with + while value_scale > 0 { + if value_scale < 10 { + ops::array::div_by_u32(&mut value, POWERS_10[value_scale as usize]); + value_scale = 0; + } else { + ops::array::div_by_u32(&mut value, POWERS_10[9]); + value_scale -= 9; + } + } + + // Do some midpoint rounding checks + // We're actually doing two things here. + // 1. Figuring out midpoint rounding when we're right on the boundary. e.g. 2.50000 + // 2. Figuring out whether to add one or not e.g. 2.51 + // For this, we need to figure out the fractional portion that is additional to + // the rounded number. e.g. for 0.12345 rounding to 2dp we'd want 345. + // We're doing the equivalent of losing precision (e.g. to get 0.12) + // then increasing the precision back up to 0.12000 + let mut offset = [self.lo, self.mid, self.hi]; + let mut diff = old_scale - dp; + + while diff > 0 { + if diff < 10 { + ops::array::div_by_u32(&mut offset, POWERS_10[diff as usize]); + break; + } else { + ops::array::div_by_u32(&mut offset, POWERS_10[9]); + // Only 9 as this array starts with 1 + diff -= 9; + } + } + + let mut diff = old_scale - dp; + + while diff > 0 { + if diff < 10 { + ops::array::mul_by_u32(&mut offset, POWERS_10[diff as usize]); + break; + } else { + ops::array::mul_by_u32(&mut offset, POWERS_10[9]); + // Only 9 as this array starts with 1 + diff -= 9; + } + } + + let mut decimal_portion = [self.lo, self.mid, self.hi]; + ops::array::sub_by_internal(&mut decimal_portion, &offset); + + // If the decimal_portion is zero then we round based on the other data + let mut cap = [5, 0, 0]; + for _ in 0..(old_scale - dp - 1) { + ops::array::mul_by_u32(&mut cap, 10); + } + let order = ops::array::cmp_internal(&decimal_portion, &cap); + + #[allow(deprecated)] + match strategy { + RoundingStrategy::BankersRounding | RoundingStrategy::MidpointNearestEven => { + match order { + Ordering::Equal => { + if (value[0] & 1) == 1 { + ops::array::add_one_internal(&mut value); + } + } + Ordering::Greater => { + // Doesn't matter about the decimal portion + ops::array::add_one_internal(&mut value); + } + _ => {} + } + } + RoundingStrategy::RoundHalfDown | RoundingStrategy::MidpointTowardZero => { + if let Ordering::Greater = order { + ops::array::add_one_internal(&mut value); + } + } + RoundingStrategy::RoundHalfUp | RoundingStrategy::MidpointAwayFromZero => { + // when Ordering::Equal, decimal_portion is 0.5 exactly + // when Ordering::Greater, decimal_portion is > 0.5 + match order { + Ordering::Equal => { + ops::array::add_one_internal(&mut value); + } + Ordering::Greater => { + // Doesn't matter about the decimal portion + ops::array::add_one_internal(&mut value); + } + _ => {} + } + } + RoundingStrategy::RoundUp | RoundingStrategy::AwayFromZero => { + if !ops::array::is_all_zero(&decimal_portion) { + ops::array::add_one_internal(&mut value); + } + } + RoundingStrategy::ToPositiveInfinity => { + if !negative && !ops::array::is_all_zero(&decimal_portion) { + ops::array::add_one_internal(&mut value); + } + } + RoundingStrategy::ToNegativeInfinity => { + if negative && !ops::array::is_all_zero(&decimal_portion) { + ops::array::add_one_internal(&mut value); + } + } + RoundingStrategy::RoundDown | RoundingStrategy::ToZero => (), + } + + Decimal::from_parts(value[0], value[1], value[2], negative, dp) + } + + /// Returns a new `Decimal` number with the specified number of decimal points for fractional portion. + /// Rounding currently follows "Bankers Rounding" rules. e.g. 6.5 -> 6, 7.5 -> 8 + /// + /// # Arguments + /// * `dp`: the number of decimal points to round to. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// # use rust_decimal_macros::dec; + /// # + /// let pi = dec!(3.1415926535897932384626433832); + /// assert_eq!(pi.round_dp(2).to_string(), "3.14"); + /// ``` + #[must_use] + pub fn round_dp(&self, dp: u32) -> Decimal { + self.round_dp_with_strategy(dp, RoundingStrategy::MidpointNearestEven) + } + + /// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If + /// the resulting number is unable to be represented by the `Decimal` number then `None` will + /// be returned. + /// When the number of significant figures of the `Decimal` being rounded is greater than the requested + /// number of significant digits then rounding will be performed using `MidpointNearestEven` strategy. + /// + /// # Arguments + /// * `digits`: the number of significant digits to round to. + /// + /// # Remarks + /// A significant figure is determined using the following rules: + /// 1. Non-zero digits are always significant. + /// 2. Zeros between non-zero digits are always significant. + /// 3. Leading zeros are never significant. + /// 4. Trailing zeros are only significant if the number contains a decimal point. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// use rust_decimal_macros::dec; + /// + /// let value = dec!(305.459); + /// assert_eq!(value.round_sf(0), Some(dec!(0))); + /// assert_eq!(value.round_sf(1), Some(dec!(300))); + /// assert_eq!(value.round_sf(2), Some(dec!(310))); + /// assert_eq!(value.round_sf(3), Some(dec!(305))); + /// assert_eq!(value.round_sf(4), Some(dec!(305.5))); + /// assert_eq!(value.round_sf(5), Some(dec!(305.46))); + /// assert_eq!(value.round_sf(6), Some(dec!(305.459))); + /// assert_eq!(value.round_sf(7), Some(dec!(305.4590))); + /// assert_eq!(Decimal::MAX.round_sf(1), None); + /// + /// let value = dec!(0.012301); + /// assert_eq!(value.round_sf(3), Some(dec!(0.0123))); + /// ``` + #[must_use] + pub fn round_sf(&self, digits: u32) -> Option<Decimal> { + self.round_sf_with_strategy(digits, RoundingStrategy::MidpointNearestEven) + } + + /// Returns `Some(Decimal)` number rounded to the specified number of significant digits. If + /// the resulting number is unable to be represented by the `Decimal` number then `None` will + /// be returned. + /// When the number of significant figures of the `Decimal` being rounded is greater than the requested + /// number of significant digits then rounding will be performed using the provided [RoundingStrategy]. + /// + /// # Arguments + /// * `digits`: the number of significant digits to round to. + /// * `strategy`: if required, the rounding strategy to use. + /// + /// # Remarks + /// A significant figure is determined using the following rules: + /// 1. Non-zero digits are always significant. + /// 2. Zeros between non-zero digits are always significant. + /// 3. Leading zeros are never significant. + /// 4. Trailing zeros are only significant if the number contains a decimal point. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::{Decimal, RoundingStrategy}; + /// use rust_decimal_macros::dec; + /// + /// let value = dec!(305.459); + /// assert_eq!(value.round_sf_with_strategy(0, RoundingStrategy::ToZero), Some(dec!(0))); + /// assert_eq!(value.round_sf_with_strategy(1, RoundingStrategy::ToZero), Some(dec!(300))); + /// assert_eq!(value.round_sf_with_strategy(2, RoundingStrategy::ToZero), Some(dec!(300))); + /// assert_eq!(value.round_sf_with_strategy(3, RoundingStrategy::ToZero), Some(dec!(305))); + /// assert_eq!(value.round_sf_with_strategy(4, RoundingStrategy::ToZero), Some(dec!(305.4))); + /// assert_eq!(value.round_sf_with_strategy(5, RoundingStrategy::ToZero), Some(dec!(305.45))); + /// assert_eq!(value.round_sf_with_strategy(6, RoundingStrategy::ToZero), Some(dec!(305.459))); + /// assert_eq!(value.round_sf_with_strategy(7, RoundingStrategy::ToZero), Some(dec!(305.4590))); + /// assert_eq!(Decimal::MAX.round_sf_with_strategy(1, RoundingStrategy::ToZero), Some(dec!(70000000000000000000000000000))); + /// + /// let value = dec!(0.012301); + /// assert_eq!(value.round_sf_with_strategy(3, RoundingStrategy::AwayFromZero), Some(dec!(0.0124))); + /// ``` + #[must_use] + pub fn round_sf_with_strategy(&self, digits: u32, strategy: RoundingStrategy) -> Option<Decimal> { + if self.is_zero() || digits == 0 { + return Some(Decimal::ZERO); + } + + // We start by grabbing the mantissa and figuring out how many significant figures it is + // made up of. We do this by just dividing by 10 and checking remainders - effectively + // we're performing a naive log10. + let mut working = self.mantissa_array3(); + let mut mantissa_sf = 0; + while !ops::array::is_all_zero(&working) { + let _remainder = ops::array::div_by_u32(&mut working, 10u32); + mantissa_sf += 1; + if working[2] == 0 && working[1] == 0 && working[0] == 1 { + mantissa_sf += 1; + break; + } + } + let scale = self.scale(); + + match digits.cmp(&mantissa_sf) { + Ordering::Greater => { + // If we're requesting a higher number of significant figures, we rescale + let mut array = [self.lo, self.mid, self.hi]; + let mut value_scale = scale; + ops::array::rescale_internal(&mut array, &mut value_scale, scale + digits - mantissa_sf); + Some(Decimal { + lo: array[0], + mid: array[1], + hi: array[2], + flags: flags(self.is_sign_negative(), value_scale), + }) + } + Ordering::Less => { + // We're requesting a lower number of significant digits. + let diff = mantissa_sf - digits; + // If the diff is greater than the scale we're focused on the integral. Otherwise, we can + // just round. + if diff > scale { + use crate::constants::BIG_POWERS_10; + // We need to adjust the integral portion. This also should be rounded, consequently + // we reduce the number down, round it, and then scale back up. + // E.g. If we have 305.459 scaling to a sf of 2 - we first reduce the number + // down to 30.5459, round it to 31 and then scale it back up to 310. + // Likewise, if we have 12301 scaling to a sf of 3 - we first reduce the number + // down to 123.01, round it to 123 and then scale it back up to 12300. + let mut num = *self; + let mut exp = (diff - scale) as usize; + while exp > 0 { + let pow; + if exp >= BIG_POWERS_10.len() { + pow = Decimal::from(BIG_POWERS_10[BIG_POWERS_10.len() - 1]); + exp -= BIG_POWERS_10.len(); + } else { + pow = Decimal::from(BIG_POWERS_10[exp - 1]); + exp = 0; + } + num = num.checked_div(pow)?; + } + let mut num = num.round_dp_with_strategy(0, strategy).trunc(); + let mut exp = (mantissa_sf - digits - scale) as usize; + while exp > 0 { + let pow; + if exp >= BIG_POWERS_10.len() { + pow = Decimal::from(BIG_POWERS_10[BIG_POWERS_10.len() - 1]); + exp -= BIG_POWERS_10.len(); + } else { + pow = Decimal::from(BIG_POWERS_10[exp - 1]); + exp = 0; + } + num = num.checked_mul(pow)?; + } + Some(num) + } else { + Some(self.round_dp_with_strategy(scale - diff, strategy)) + } + } + Ordering::Equal => { + // Case where significant figures = requested significant digits. + Some(*self) + } + } + } + + /// Convert `Decimal` to an internal representation of the underlying struct. This is useful + /// for debugging the internal state of the object. + /// + /// # Important Disclaimer + /// This is primarily intended for library maintainers. The internal representation of a + /// `Decimal` is considered "unstable" for public use. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::Decimal; + /// use rust_decimal_macros::dec; + /// + /// let pi = dec!(3.1415926535897932384626433832); + /// assert_eq!(format!("{:?}", pi), "3.1415926535897932384626433832"); + /// assert_eq!(format!("{:?}", pi.unpack()), "UnpackedDecimal { \ + /// negative: false, scale: 28, hi: 1703060790, mid: 185874565, lo: 1102470952 \ + /// }"); + /// ``` + #[must_use] + pub const fn unpack(&self) -> UnpackedDecimal { + UnpackedDecimal { + negative: self.is_sign_negative(), + scale: self.scale(), + hi: self.hi, + lo: self.lo, + mid: self.mid, + } + } + + #[inline(always)] + pub(crate) const fn lo(&self) -> u32 { + self.lo + } + + #[inline(always)] + pub(crate) const fn mid(&self) -> u32 { + self.mid + } + + #[inline(always)] + pub(crate) const fn hi(&self) -> u32 { + self.hi + } + + #[inline(always)] + pub(crate) const fn flags(&self) -> u32 { + self.flags + } + + #[inline(always)] + pub(crate) const fn mantissa_array3(&self) -> [u32; 3] { + [self.lo, self.mid, self.hi] + } + + #[inline(always)] + pub(crate) const fn mantissa_array4(&self) -> [u32; 4] { + [self.lo, self.mid, self.hi, 0] + } + + /// Parses a 32-bit float into a Decimal number whilst retaining any non-guaranteed precision. + /// + /// Typically when a float is parsed in Rust Decimal, any excess bits (after ~7.22 decimal points for + /// f32 as per IEEE-754) are removed due to any digits following this are considered an approximation + /// at best. This function bypasses this additional step and retains these excess bits. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// // Usually floats are parsed leveraging float guarantees. i.e. 0.1_f32 => 0.1 + /// assert_eq!("0.1", Decimal::from_f32(0.1_f32).unwrap().to_string()); + /// + /// // Sometimes, we may want to represent the approximation exactly. + /// assert_eq!("0.100000001490116119384765625", Decimal::from_f32_retain(0.1_f32).unwrap().to_string()); + /// ``` + pub fn from_f32_retain(n: f32) -> Option<Self> { + from_f32(n, false) + } + + /// Parses a 64-bit float into a Decimal number whilst retaining any non-guaranteed precision. + /// + /// Typically when a float is parsed in Rust Decimal, any excess bits (after ~15.95 decimal points for + /// f64 as per IEEE-754) are removed due to any digits following this are considered an approximation + /// at best. This function bypasses this additional step and retains these excess bits. + /// + /// # Example + /// + /// ``` + /// # use rust_decimal::prelude::*; + /// # + /// // Usually floats are parsed leveraging float guarantees. i.e. 0.1_f64 => 0.1 + /// assert_eq!("0.1", Decimal::from_f64(0.1_f64).unwrap().to_string()); + /// + /// // Sometimes, we may want to represent the approximation exactly. + /// assert_eq!("0.1000000000000000055511151231", Decimal::from_f64_retain(0.1_f64).unwrap().to_string()); + /// ``` + pub fn from_f64_retain(n: f64) -> Option<Self> { + from_f64(n, false) + } +} + +impl Default for Decimal { + /// Returns the default value for a `Decimal` (equivalent to `Decimal::ZERO`). [Read more] + /// + /// [Read more]: core::default::Default#tymethod.default + #[inline] + fn default() -> Self { + ZERO + } +} + +pub(crate) enum CalculationResult { + Ok(Decimal), + Overflow, + DivByZero, +} + +#[inline] +const fn flags(neg: bool, scale: u32) -> u32 { + (scale << SCALE_SHIFT) | ((neg as u32) << SIGN_SHIFT) +} + +macro_rules! integer_docs { + ( true ) => { + " by truncating and returning the integer component" + }; + ( false ) => { + "" + }; +} + +// #[doc] attributes are formatted poorly with rustfmt so skip for now. +// See https://github.com/rust-lang/rustfmt/issues/5062 for more information. +#[rustfmt::skip] +macro_rules! impl_try_from_decimal { + ($TInto:ty, $conversion_fn:path, $additional_docs:expr) => { + #[doc = concat!( + "Try to convert a `Decimal` to `", + stringify!($TInto), + "`", + $additional_docs, + ".\n\nCan fail if the `Decimal` is out of range for `", + stringify!($TInto), + "`.", + )] + impl TryFrom<Decimal> for $TInto { + type Error = crate::Error; + + #[inline] + fn try_from(t: Decimal) -> Result<Self, Error> { + $conversion_fn(&t).ok_or_else(|| Error::ConversionTo(stringify!($TInto).into())) + } + } + }; +} + +impl_try_from_decimal!(f32, Decimal::to_f32, integer_docs!(false)); +impl_try_from_decimal!(f64, Decimal::to_f64, integer_docs!(false)); +impl_try_from_decimal!(isize, Decimal::to_isize, integer_docs!(true)); +impl_try_from_decimal!(i8, Decimal::to_i8, integer_docs!(true)); +impl_try_from_decimal!(i16, Decimal::to_i16, integer_docs!(true)); +impl_try_from_decimal!(i32, Decimal::to_i32, integer_docs!(true)); +impl_try_from_decimal!(i64, Decimal::to_i64, integer_docs!(true)); +impl_try_from_decimal!(i128, Decimal::to_i128, integer_docs!(true)); +impl_try_from_decimal!(usize, Decimal::to_usize, integer_docs!(true)); +impl_try_from_decimal!(u8, Decimal::to_u8, integer_docs!(true)); +impl_try_from_decimal!(u16, Decimal::to_u16, integer_docs!(true)); +impl_try_from_decimal!(u32, Decimal::to_u32, integer_docs!(true)); +impl_try_from_decimal!(u64, Decimal::to_u64, integer_docs!(true)); +impl_try_from_decimal!(u128, Decimal::to_u128, integer_docs!(true)); + +// #[doc] attributes are formatted poorly with rustfmt so skip for now. +// See https://github.com/rust-lang/rustfmt/issues/5062 for more information. +#[rustfmt::skip] +macro_rules! impl_try_from_primitive { + ($TFrom:ty, $conversion_fn:path $(, $err:expr)?) => { + #[doc = concat!( + "Try to convert a `", + stringify!($TFrom), + "` into a `Decimal`.\n\nCan fail if the value is out of range for `Decimal`." + )] + impl TryFrom<$TFrom> for Decimal { + type Error = crate::Error; + + #[inline] + fn try_from(t: $TFrom) -> Result<Self, Error> { + $conversion_fn(t) $( .ok_or_else(|| $err) )? + } + } + }; +} + +impl_try_from_primitive!(f32, Self::from_f32, Error::ConversionTo("Decimal".into())); +impl_try_from_primitive!(f64, Self::from_f64, Error::ConversionTo("Decimal".into())); +impl_try_from_primitive!(&str, core::str::FromStr::from_str); + +macro_rules! impl_from { + ($T:ty, $from_ty:path) => { + /// + /// Conversion to `Decimal`. + /// + impl core::convert::From<$T> for Decimal { + #[inline] + fn from(t: $T) -> Self { + $from_ty(t).unwrap() + } + } + }; +} + +impl_from!(isize, FromPrimitive::from_isize); +impl_from!(i8, FromPrimitive::from_i8); +impl_from!(i16, FromPrimitive::from_i16); +impl_from!(i32, FromPrimitive::from_i32); +impl_from!(i64, FromPrimitive::from_i64); +impl_from!(usize, FromPrimitive::from_usize); +impl_from!(u8, FromPrimitive::from_u8); +impl_from!(u16, FromPrimitive::from_u16); +impl_from!(u32, FromPrimitive::from_u32); +impl_from!(u64, FromPrimitive::from_u64); + +impl_from!(i128, FromPrimitive::from_i128); +impl_from!(u128, FromPrimitive::from_u128); + +impl Zero for Decimal { + fn zero() -> Decimal { + ZERO + } + + fn is_zero(&self) -> bool { + self.is_zero() + } +} + +impl One for Decimal { + fn one() -> Decimal { + ONE + } +} + +impl Signed for Decimal { + fn abs(&self) -> Self { + self.abs() + } + + fn abs_sub(&self, other: &Self) -> Self { + if self <= other { + ZERO + } else { + self.abs() + } + } + + fn signum(&self) -> Self { + if self.is_zero() { + ZERO + } else { + let mut value = ONE; + if self.is_sign_negative() { + value.set_sign_negative(true); + } + value + } + } + + fn is_positive(&self) -> bool { + self.is_sign_positive() + } + + fn is_negative(&self) -> bool { + self.is_sign_negative() + } +} + +impl Num for Decimal { + type FromStrRadixErr = Error; + + fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> { + Decimal::from_str_radix(str, radix) + } +} + +impl FromStr for Decimal { + type Err = Error; + + fn from_str(value: &str) -> Result<Decimal, Self::Err> { + crate::str::parse_str_radix_10(value) + } +} + +impl FromPrimitive for Decimal { + fn from_i32(n: i32) -> Option<Decimal> { + let flags: u32; + let value_copy: i64; + if n >= 0 { + flags = 0; + value_copy = n as i64; + } else { + flags = SIGN_MASK; + value_copy = -(n as i64); + } + Some(Decimal { + flags, + lo: value_copy as u32, + mid: 0, + hi: 0, + }) + } + + fn from_i64(n: i64) -> Option<Decimal> { + let flags: u32; + let value_copy: i128; + if n >= 0 { + flags = 0; + value_copy = n as i128; + } else { + flags = SIGN_MASK; + value_copy = -(n as i128); + } + Some(Decimal { + flags, + lo: value_copy as u32, + mid: (value_copy >> 32) as u32, + hi: 0, + }) + } + + fn from_i128(n: i128) -> Option<Decimal> { + let flags; + let unsigned; + if n >= 0 { + unsigned = n as u128; + flags = 0; + } else { + unsigned = -n as u128; + flags = SIGN_MASK; + }; + // Check if we overflow + if unsigned >> 96 != 0 { + return None; + } + Some(Decimal { + flags, + lo: unsigned as u32, + mid: (unsigned >> 32) as u32, + hi: (unsigned >> 64) as u32, + }) + } + + fn from_u32(n: u32) -> Option<Decimal> { + Some(Decimal { + flags: 0, + lo: n, + mid: 0, + hi: 0, + }) + } + + fn from_u64(n: u64) -> Option<Decimal> { + Some(Decimal { + flags: 0, + lo: n as u32, + mid: (n >> 32) as u32, + hi: 0, + }) + } + + fn from_u128(n: u128) -> Option<Decimal> { + // Check if we overflow + if n >> 96 != 0 { + return None; + } + Some(Decimal { + flags: 0, + lo: n as u32, + mid: (n >> 32) as u32, + hi: (n >> 64) as u32, + }) + } + + fn from_f32(n: f32) -> Option<Decimal> { + // By default, we remove excess bits. This allows 0.1_f64 == dec!(0.1). + from_f32(n, true) + } + + fn from_f64(n: f64) -> Option<Decimal> { + // By default, we remove excess bits. This allows 0.1_f64 == dec!(0.1). + from_f64(n, true) + } +} + +#[inline] +fn from_f64(n: f64, remove_excess_bits: bool) -> Option<Decimal> { + // Handle the case if it is NaN, Infinity or -Infinity + if !n.is_finite() { + return None; + } + + // It's a shame we can't use a union for this due to it being broken up by bits + // i.e. 1/11/52 (sign, exponent, mantissa) + // See https://en.wikipedia.org/wiki/IEEE_754-1985 + // n = (sign*-1) * 2^exp * mantissa + // Decimal of course stores this differently... 10^-exp * significand + let raw = n.to_bits(); + let positive = (raw >> 63) == 0; + let biased_exponent = ((raw >> 52) & 0x7FF) as i32; + let mantissa = raw & 0x000F_FFFF_FFFF_FFFF; + + // Handle the special zero case + if biased_exponent == 0 && mantissa == 0 { + let mut zero = ZERO; + if !positive { + zero.set_sign_negative(true); + } + return Some(zero); + } + + // Get the bits and exponent2 + let mut exponent2 = biased_exponent - 1023; + let mut bits = [ + (mantissa & 0xFFFF_FFFF) as u32, + ((mantissa >> 32) & 0xFFFF_FFFF) as u32, + 0u32, + ]; + if biased_exponent == 0 { + // Denormalized number - correct the exponent + exponent2 += 1; + } else { + // Add extra hidden bit to mantissa + bits[1] |= 0x0010_0000; + } + + // The act of copying a mantissa as integer bits is equivalent to shifting + // left the mantissa 52 bits. The exponent is reduced to compensate. + exponent2 -= 52; + + // Convert to decimal + base2_to_decimal(&mut bits, exponent2, positive, true, remove_excess_bits) +} + +#[inline] +fn from_f32(n: f32, remove_excess_bits: bool) -> Option<Decimal> { + // Handle the case if it is NaN, Infinity or -Infinity + if !n.is_finite() { + return None; + } + + // It's a shame we can't use a union for this due to it being broken up by bits + // i.e. 1/8/23 (sign, exponent, mantissa) + // See https://en.wikipedia.org/wiki/IEEE_754-1985 + // n = (sign*-1) * 2^exp * mantissa + // Decimal of course stores this differently... 10^-exp * significand + let raw = n.to_bits(); + let positive = (raw >> 31) == 0; + let biased_exponent = ((raw >> 23) & 0xFF) as i32; + let mantissa = raw & 0x007F_FFFF; + + // Handle the special zero case + if biased_exponent == 0 && mantissa == 0 { + let mut zero = ZERO; + if !positive { + zero.set_sign_negative(true); + } + return Some(zero); + } + + // Get the bits and exponent2 + let mut exponent2 = biased_exponent - 127; + let mut bits = [mantissa, 0u32, 0u32]; + if biased_exponent == 0 { + // Denormalized number - correct the exponent + exponent2 += 1; + } else { + // Add extra hidden bit to mantissa + bits[0] |= 0x0080_0000; + } + + // The act of copying a mantissa as integer bits is equivalent to shifting + // left the mantissa 23 bits. The exponent is reduced to compensate. + exponent2 -= 23; + + // Convert to decimal + base2_to_decimal(&mut bits, exponent2, positive, false, remove_excess_bits) +} + +fn base2_to_decimal( + bits: &mut [u32; 3], + exponent2: i32, + positive: bool, + is64: bool, + remove_excess_bits: bool, +) -> Option<Decimal> { + // 2^exponent2 = (10^exponent2)/(5^exponent2) + // = (5^-exponent2)*(10^exponent2) + let mut exponent5 = -exponent2; + let mut exponent10 = exponent2; // Ultimately, we want this for the scale + + while exponent5 > 0 { + // Check to see if the mantissa is divisible by 2 + if bits[0] & 0x1 == 0 { + exponent10 += 1; + exponent5 -= 1; + + // We can divide by 2 without losing precision + let hi_carry = bits[2] & 0x1 == 1; + bits[2] >>= 1; + let mid_carry = bits[1] & 0x1 == 1; + bits[1] = (bits[1] >> 1) | if hi_carry { SIGN_MASK } else { 0 }; + bits[0] = (bits[0] >> 1) | if mid_carry { SIGN_MASK } else { 0 }; + } else { + // The mantissa is NOT divisible by 2. Therefore the mantissa should + // be multiplied by 5, unless the multiplication overflows. + exponent5 -= 1; + + let mut temp = [bits[0], bits[1], bits[2]]; + if ops::array::mul_by_u32(&mut temp, 5) == 0 { + // Multiplication succeeded without overflow, so copy result back + bits[0] = temp[0]; + bits[1] = temp[1]; + bits[2] = temp[2]; + } else { + // Multiplication by 5 overflows. The mantissa should be divided + // by 2, and therefore will lose significant digits. + exponent10 += 1; + + // Shift right + let hi_carry = bits[2] & 0x1 == 1; + bits[2] >>= 1; + let mid_carry = bits[1] & 0x1 == 1; + bits[1] = (bits[1] >> 1) | if hi_carry { SIGN_MASK } else { 0 }; + bits[0] = (bits[0] >> 1) | if mid_carry { SIGN_MASK } else { 0 }; + } + } + } + + // In order to divide the value by 5, it is best to multiply by 2/10. + // Therefore, exponent10 is decremented, and the mantissa should be multiplied by 2 + while exponent5 < 0 { + if bits[2] & SIGN_MASK == 0 { + // No far left bit, the mantissa can withstand a shift-left without overflowing + exponent10 -= 1; + exponent5 += 1; + ops::array::shl1_internal(bits, 0); + } else { + // The mantissa would overflow if shifted. Therefore it should be + // directly divided by 5. This will lose significant digits, unless + // by chance the mantissa happens to be divisible by 5. + exponent5 += 1; + ops::array::div_by_u32(bits, 5); + } + } + + // At this point, the mantissa has assimilated the exponent5, but + // exponent10 might not be suitable for assignment. exponent10 must be + // in the range [-MAX_PRECISION..0], so the mantissa must be scaled up or + // down appropriately. + while exponent10 > 0 { + // In order to bring exponent10 down to 0, the mantissa should be + // multiplied by 10 to compensate. If the exponent10 is too big, this + // will cause the mantissa to overflow. + if ops::array::mul_by_u32(bits, 10) == 0 { + exponent10 -= 1; + } else { + // Overflowed - return? + return None; + } + } + + // In order to bring exponent up to -MAX_PRECISION, the mantissa should + // be divided by 10 to compensate. If the exponent10 is too small, this + // will cause the mantissa to underflow and become 0. + while exponent10 < -(MAX_PRECISION_U32 as i32) { + let rem10 = ops::array::div_by_u32(bits, 10); + exponent10 += 1; + if ops::array::is_all_zero(bits) { + // Underflow, unable to keep dividing + exponent10 = 0; + } else if rem10 >= 5 { + ops::array::add_one_internal(bits); + } + } + + if remove_excess_bits { + // This step is required in order to remove excess bits of precision from the + // end of the bit representation, down to the precision guaranteed by the + // floating point number (see IEEE-754). + if is64 { + // Guaranteed to approx 15/16 dp + while exponent10 < 0 && (bits[2] != 0 || (bits[1] & 0xFFF0_0000) != 0) { + let rem10 = ops::array::div_by_u32(bits, 10); + exponent10 += 1; + if rem10 >= 5 { + ops::array::add_one_internal(bits); + } + } + } else { + // Guaranteed to about 7/8 dp + while exponent10 < 0 && ((bits[0] & 0xFF00_0000) != 0 || bits[1] != 0 || bits[2] != 0) { + let rem10 = ops::array::div_by_u32(bits, 10); + exponent10 += 1; + if rem10 >= 5 { + ops::array::add_one_internal(bits); + } + } + } + + // Remove multiples of 10 from the representation + while exponent10 < 0 { + let mut temp = [bits[0], bits[1], bits[2]]; + let remainder = ops::array::div_by_u32(&mut temp, 10); + if remainder == 0 { + exponent10 += 1; + bits[0] = temp[0]; + bits[1] = temp[1]; + bits[2] = temp[2]; + } else { + break; + } + } + } + + Some(Decimal { + lo: bits[0], + mid: bits[1], + hi: bits[2], + flags: flags(!positive, -exponent10 as u32), + }) +} + +impl ToPrimitive for Decimal { + fn to_i64(&self) -> Option<i64> { + let d = self.trunc(); + // If it is in the hi bit then it is a clear overflow. + if d.hi != 0 { + // Overflow + return None; + } + let negative = self.is_sign_negative(); + + // A bit more convoluted in terms of checking when it comes to the hi bit due to twos-complement + if d.mid & 0x8000_0000 > 0 { + if negative && d.mid == 0x8000_0000 && d.lo == 0 { + // We do this because below we try to convert the i64 to a positive first - of which + // doesn't fit into an i64. + return Some(i64::MIN); + } + return None; + } + + let raw: i64 = (i64::from(d.mid) << 32) | i64::from(d.lo); + if negative { + Some(raw.neg()) + } else { + Some(raw) + } + } + + fn to_i128(&self) -> Option<i128> { + let d = self.trunc(); + let raw: i128 = ((i128::from(d.hi) << 64) | i128::from(d.mid) << 32) | i128::from(d.lo); + if self.is_sign_negative() { + Some(-raw) + } else { + Some(raw) + } + } + + fn to_u64(&self) -> Option<u64> { + if self.is_sign_negative() { + return None; + } + + let d = self.trunc(); + if d.hi != 0 { + // Overflow + return None; + } + + Some((u64::from(d.mid) << 32) | u64::from(d.lo)) + } + + fn to_u128(&self) -> Option<u128> { + if self.is_sign_negative() { + return None; + } + + let d = self.trunc(); + Some((u128::from(d.hi) << 64) | (u128::from(d.mid) << 32) | u128::from(d.lo)) + } + + fn to_f64(&self) -> Option<f64> { + if self.scale() == 0 { + // If scale is zero, we are storing a 96-bit integer value, that would + // always fit into i128, which in turn is always representable as f64, + // albeit with loss of precision for values outside of -2^53..2^53 range. + let integer = self.to_i128(); + integer.map(|i| i as f64) + } else { + let sign: f64 = if self.is_sign_negative() { -1.0 } else { 1.0 }; + let mut mantissa: u128 = self.lo.into(); + mantissa |= (self.mid as u128) << 32; + mantissa |= (self.hi as u128) << 64; + // scale is at most 28, so this fits comfortably into a u128. + let scale = self.scale(); + let precision: u128 = 10_u128.pow(scale); + let integral_part = mantissa / precision; + let frac_part = mantissa % precision; + let frac_f64 = (frac_part as f64) / (precision as f64); + let value = sign * ((integral_part as f64) + frac_f64); + let round_to = 10f64.powi(self.scale() as i32); + Some((value * round_to).round() / round_to) + } + } +} + +impl fmt::Display for Decimal { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + let (rep, additional) = crate::str::to_str_internal(self, false, f.precision()); + if let Some(additional) = additional { + let value = [rep.as_str(), "0".repeat(additional).as_str()].concat(); + f.pad_integral(self.is_sign_positive(), "", value.as_str()) + } else { + f.pad_integral(self.is_sign_positive(), "", rep.as_str()) + } + } +} + +impl fmt::Debug for Decimal { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + fmt::Display::fmt(self, f) + } +} + +impl fmt::LowerExp for Decimal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + crate::str::fmt_scientific_notation(self, "e", f) + } +} + +impl fmt::UpperExp for Decimal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + crate::str::fmt_scientific_notation(self, "E", f) + } +} + +impl Neg for Decimal { + type Output = Decimal; + + fn neg(self) -> Decimal { + let mut copy = self; + copy.set_sign_negative(self.is_sign_positive()); + copy + } +} + +impl<'a> Neg for &'a Decimal { + type Output = Decimal; + + fn neg(self) -> Decimal { + Decimal { + flags: flags(!self.is_sign_negative(), self.scale()), + hi: self.hi, + lo: self.lo, + mid: self.mid, + } + } +} + +impl AddAssign for Decimal { + fn add_assign(&mut self, other: Decimal) { + let result = self.add(other); + self.lo = result.lo; + self.mid = result.mid; + self.hi = result.hi; + self.flags = result.flags; + } +} + +impl<'a> AddAssign<&'a Decimal> for Decimal { + fn add_assign(&mut self, other: &'a Decimal) { + Decimal::add_assign(self, *other) + } +} + +impl<'a> AddAssign<Decimal> for &'a mut Decimal { + fn add_assign(&mut self, other: Decimal) { + Decimal::add_assign(*self, other) + } +} + +impl<'a> AddAssign<&'a Decimal> for &'a mut Decimal { + fn add_assign(&mut self, other: &'a Decimal) { + Decimal::add_assign(*self, *other) + } +} + +impl SubAssign for Decimal { + fn sub_assign(&mut self, other: Decimal) { + let result = self.sub(other); + self.lo = result.lo; + self.mid = result.mid; + self.hi = result.hi; + self.flags = result.flags; + } +} + +impl<'a> SubAssign<&'a Decimal> for Decimal { + fn sub_assign(&mut self, other: &'a Decimal) { + Decimal::sub_assign(self, *other) + } +} + +impl<'a> SubAssign<Decimal> for &'a mut Decimal { + fn sub_assign(&mut self, other: Decimal) { + Decimal::sub_assign(*self, other) + } +} + +impl<'a> SubAssign<&'a Decimal> for &'a mut Decimal { + fn sub_assign(&mut self, other: &'a Decimal) { + Decimal::sub_assign(*self, *other) + } +} + +impl MulAssign for Decimal { + fn mul_assign(&mut self, other: Decimal) { + let result = self.mul(other); + self.lo = result.lo; + self.mid = result.mid; + self.hi = result.hi; + self.flags = result.flags; + } +} + +impl<'a> MulAssign<&'a Decimal> for Decimal { + fn mul_assign(&mut self, other: &'a Decimal) { + Decimal::mul_assign(self, *other) + } +} + +impl<'a> MulAssign<Decimal> for &'a mut Decimal { + fn mul_assign(&mut self, other: Decimal) { + Decimal::mul_assign(*self, other) + } +} + +impl<'a> MulAssign<&'a Decimal> for &'a mut Decimal { + fn mul_assign(&mut self, other: &'a Decimal) { + Decimal::mul_assign(*self, *other) + } +} + +impl DivAssign for Decimal { + fn div_assign(&mut self, other: Decimal) { + let result = self.div(other); + self.lo = result.lo; + self.mid = result.mid; + self.hi = result.hi; + self.flags = result.flags; + } +} + +impl<'a> DivAssign<&'a Decimal> for Decimal { + fn div_assign(&mut self, other: &'a Decimal) { + Decimal::div_assign(self, *other) + } +} + +impl<'a> DivAssign<Decimal> for &'a mut Decimal { + fn div_assign(&mut self, other: Decimal) { + Decimal::div_assign(*self, other) + } +} + +impl<'a> DivAssign<&'a Decimal> for &'a mut Decimal { + fn div_assign(&mut self, other: &'a Decimal) { + Decimal::div_assign(*self, *other) + } +} + +impl RemAssign for Decimal { + fn rem_assign(&mut self, other: Decimal) { + let result = self.rem(other); + self.lo = result.lo; + self.mid = result.mid; + self.hi = result.hi; + self.flags = result.flags; + } +} + +impl<'a> RemAssign<&'a Decimal> for Decimal { + fn rem_assign(&mut self, other: &'a Decimal) { + Decimal::rem_assign(self, *other) + } +} + +impl<'a> RemAssign<Decimal> for &'a mut Decimal { + fn rem_assign(&mut self, other: Decimal) { + Decimal::rem_assign(*self, other) + } +} + +impl<'a> RemAssign<&'a Decimal> for &'a mut Decimal { + fn rem_assign(&mut self, other: &'a Decimal) { + Decimal::rem_assign(*self, *other) + } +} + +impl PartialEq for Decimal { + #[inline] + fn eq(&self, other: &Decimal) -> bool { + self.cmp(other) == Equal + } +} + +impl Eq for Decimal {} + +impl Hash for Decimal { + fn hash<H: Hasher>(&self, state: &mut H) { + let n = self.normalize(); + n.lo.hash(state); + n.mid.hash(state); + n.hi.hash(state); + n.flags.hash(state); + } +} + +impl PartialOrd for Decimal { + #[inline] + fn partial_cmp(&self, other: &Decimal) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for Decimal { + fn cmp(&self, other: &Decimal) -> Ordering { + ops::cmp_impl(self, other) + } +} + +impl Product for Decimal { + /// Panics if out-of-bounds + fn product<I: Iterator<Item = Decimal>>(iter: I) -> Self { + let mut product = ONE; + for i in iter { + product *= i; + } + product + } +} + +impl<'a> Product<&'a Decimal> for Decimal { + /// Panics if out-of-bounds + fn product<I: Iterator<Item = &'a Decimal>>(iter: I) -> Self { + let mut product = ONE; + for i in iter { + product *= i; + } + product + } +} + +impl Sum for Decimal { + fn sum<I: Iterator<Item = Decimal>>(iter: I) -> Self { + let mut sum = ZERO; + for i in iter { + sum += i; + } + sum + } +} + +impl<'a> Sum<&'a Decimal> for Decimal { + fn sum<I: Iterator<Item = &'a Decimal>>(iter: I) -> Self { + let mut sum = ZERO; + for i in iter { + sum += i; + } + sum + } +} diff --git a/third_party/rust/rust_decimal/src/error.rs b/third_party/rust/rust_decimal/src/error.rs new file mode 100644 index 0000000000..5e5969e42b --- /dev/null +++ b/third_party/rust/rust_decimal/src/error.rs @@ -0,0 +1,69 @@ +use crate::{constants::MAX_PRECISION_U32, Decimal}; +use alloc::string::String; +use core::fmt; + +/// Error type for the library. +#[derive(Clone, Debug, PartialEq)] +pub enum Error { + /// A generic error from Rust Decimal with the `String` containing more information as to what + /// went wrong. + /// + /// This is a legacy/deprecated error type retained for backwards compatibility. + ErrorString(String), + /// The value provided exceeds `Decimal::MAX`. + ExceedsMaximumPossibleValue, + /// The value provided is less than `Decimal::MIN`. + LessThanMinimumPossibleValue, + /// An underflow is when there are more fractional digits than can be represented within `Decimal`. + Underflow, + /// The scale provided exceeds the maximum scale that `Decimal` can represent. + ScaleExceedsMaximumPrecision(u32), + /// Represents a failure to convert to/from `Decimal` to the specified type. This is typically + /// due to type constraints (e.g. `Decimal::MAX` cannot be converted into `i32`). + ConversionTo(String), +} + +impl<S> From<S> for Error +where + S: Into<String>, +{ + #[inline] + fn from(from: S) -> Self { + Self::ErrorString(from.into()) + } +} + +#[cold] +pub(crate) fn tail_error(from: &'static str) -> Result<Decimal, Error> { + Err(from.into()) +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Self::ErrorString(ref err) => f.pad(err), + Self::ExceedsMaximumPossibleValue => { + write!(f, "Number exceeds maximum value that can be represented.") + } + Self::LessThanMinimumPossibleValue => { + write!(f, "Number less than minimum value that can be represented.") + } + Self::Underflow => { + write!(f, "Number has a high precision that can not be represented.") + } + Self::ScaleExceedsMaximumPrecision(ref scale) => { + write!( + f, + "Scale exceeds the maximum precision allowed: {} > {}", + scale, MAX_PRECISION_U32 + ) + } + Self::ConversionTo(ref type_name) => { + write!(f, "Error while converting to {}", type_name) + } + } + } +} diff --git a/third_party/rust/rust_decimal/src/fuzz.rs b/third_party/rust/rust_decimal/src/fuzz.rs new file mode 100644 index 0000000000..14f2e6004e --- /dev/null +++ b/third_party/rust/rust_decimal/src/fuzz.rs @@ -0,0 +1,14 @@ +use crate::Decimal; + +use arbitrary::{Arbitrary, Result as ArbitraryResult, Unstructured}; + +impl Arbitrary<'_> for crate::Decimal { + fn arbitrary(u: &mut Unstructured<'_>) -> ArbitraryResult<Self> { + let lo = u32::arbitrary(u)?; + let mid = u32::arbitrary(u)?; + let hi = u32::arbitrary(u)?; + let negative = bool::arbitrary(u)?; + let scale = u32::arbitrary(u)?; + Ok(Decimal::from_parts(lo, mid, hi, negative, scale)) + } +} diff --git a/third_party/rust/rust_decimal/src/lib.rs b/third_party/rust/rust_decimal/src/lib.rs new file mode 100644 index 0000000000..84239fd6a8 --- /dev/null +++ b/third_party/rust/rust_decimal/src/lib.rs @@ -0,0 +1,79 @@ +#![doc = include_str!(concat!(env!("OUT_DIR"), "/README-lib.md"))] +#![forbid(unsafe_code)] +#![deny(clippy::print_stdout, clippy::print_stderr)] +#![cfg_attr(not(feature = "std"), no_std)] +extern crate alloc; + +mod constants; +mod decimal; +mod error; +mod ops; +mod str; + +// We purposely place this here for documentation ordering +mod arithmetic_impls; + +#[cfg(feature = "rust-fuzz")] +mod fuzz; +#[cfg(feature = "maths")] +mod maths; +#[cfg(any(feature = "db-diesel1-mysql", feature = "db-diesel2-mysql"))] +mod mysql; +#[cfg(any( + feature = "db-tokio-postgres", + feature = "db-postgres", + feature = "db-diesel1-postgres", + feature = "db-diesel2-postgres", +))] +mod postgres; +#[cfg(feature = "rand")] +mod rand; +#[cfg(feature = "rocket-traits")] +mod rocket; +#[cfg(all( + feature = "serde", + not(any( + feature = "serde-with-str", + feature = "serde-with-float", + feature = "serde-with-arbitrary-precision" + )) +))] +mod serde; +/// Serde specific functionality to customize how a decimal is serialized/deserialized (`serde_with`) +#[cfg(all( + feature = "serde", + any( + feature = "serde-with-str", + feature = "serde-with-float", + feature = "serde-with-arbitrary-precision" + ) +))] +pub mod serde; + +pub use decimal::{Decimal, RoundingStrategy}; +pub use error::Error; +#[cfg(feature = "maths")] +pub use maths::MathematicalOps; + +/// A convenience module appropriate for glob imports (`use rust_decimal::prelude::*;`). +pub mod prelude { + #[cfg(feature = "maths")] + pub use crate::maths::MathematicalOps; + pub use crate::{Decimal, RoundingStrategy}; + pub use core::str::FromStr; + pub use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero}; +} + +#[cfg(all(feature = "diesel1", not(feature = "diesel2")))] +#[macro_use] +extern crate diesel1 as diesel; + +#[cfg(feature = "diesel2")] +extern crate diesel2 as diesel; + +/// Shortcut for `core::result::Result<T, rust_decimal::Error>`. Useful to distinguish +/// between `rust_decimal` and `std` types. +pub type Result<T> = core::result::Result<T, Error>; + +// #[cfg(feature = "legacy-ops")] +// compiler_error!("legacy-ops has been removed as 1.x"); diff --git a/third_party/rust/rust_decimal/src/maths.rs b/third_party/rust/rust_decimal/src/maths.rs new file mode 100644 index 0000000000..c402453002 --- /dev/null +++ b/third_party/rust/rust_decimal/src/maths.rs @@ -0,0 +1,785 @@ +use crate::prelude::*; +use num_traits::pow::Pow; + +// Tolerance for inaccuracies when calculating exp +const EXP_TOLERANCE: Decimal = Decimal::from_parts(2, 0, 0, false, 7); +// Approximation of 1/ln(10) = 0.4342944819032518276511289189 +const LN10_INVERSE: Decimal = Decimal::from_parts_raw(1763037029, 1670682625, 235431510, 1835008); +// Total iterations of taylor series for Trig. +const TRIG_SERIES_UPPER_BOUND: usize = 6; +// PI / 8 +const EIGHTH_PI: Decimal = Decimal::from_parts_raw(2822163429, 3244459792, 212882598, 1835008); + +// Table representing {index}! +const FACTORIAL: [Decimal; 28] = [ + Decimal::from_parts(1, 0, 0, false, 0), + Decimal::from_parts(1, 0, 0, false, 0), + Decimal::from_parts(2, 0, 0, false, 0), + Decimal::from_parts(6, 0, 0, false, 0), + Decimal::from_parts(24, 0, 0, false, 0), + // 5! + Decimal::from_parts(120, 0, 0, false, 0), + Decimal::from_parts(720, 0, 0, false, 0), + Decimal::from_parts(5040, 0, 0, false, 0), + Decimal::from_parts(40320, 0, 0, false, 0), + Decimal::from_parts(362880, 0, 0, false, 0), + // 10! + Decimal::from_parts(3628800, 0, 0, false, 0), + Decimal::from_parts(39916800, 0, 0, false, 0), + Decimal::from_parts(479001600, 0, 0, false, 0), + Decimal::from_parts(1932053504, 1, 0, false, 0), + Decimal::from_parts(1278945280, 20, 0, false, 0), + // 15! + Decimal::from_parts(2004310016, 304, 0, false, 0), + Decimal::from_parts(2004189184, 4871, 0, false, 0), + Decimal::from_parts(4006445056, 82814, 0, false, 0), + Decimal::from_parts(3396534272, 1490668, 0, false, 0), + Decimal::from_parts(109641728, 28322707, 0, false, 0), + // 20! + Decimal::from_parts(2192834560, 566454140, 0, false, 0), + Decimal::from_parts(3099852800, 3305602358, 2, false, 0), + Decimal::from_parts(3772252160, 4003775155, 60, false, 0), + Decimal::from_parts(862453760, 1892515369, 1401, false, 0), + Decimal::from_parts(3519021056, 2470695900, 33634, false, 0), + // 25! + Decimal::from_parts(2076180480, 1637855376, 840864, false, 0), + Decimal::from_parts(2441084928, 3929534124, 21862473, false, 0), + Decimal::from_parts(1484783616, 3018206259, 590286795, false, 0), +]; + +/// Trait exposing various mathematical operations that can be applied using a Decimal. This is only +/// present when the `maths` feature has been enabled. +pub trait MathematicalOps { + /// The estimated exponential function, e<sup>x</sup>. Stops calculating when it is within + /// tolerance of roughly `0.0000002`. + fn exp(&self) -> Decimal; + + /// The estimated exponential function, e<sup>x</sup>. Stops calculating when it is within + /// tolerance of roughly `0.0000002`. Returns `None` on overflow. + fn checked_exp(&self) -> Option<Decimal>; + + /// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint + /// as to when to stop calculating. A larger tolerance will cause the number to stop calculating + /// sooner at the potential cost of a slightly less accurate result. + fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal; + + /// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint + /// as to when to stop calculating. A larger tolerance will cause the number to stop calculating + /// sooner at the potential cost of a slightly less accurate result. + /// Returns `None` on overflow. + fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal>; + + /// Raise self to the given integer exponent: x<sup>y</sup> + fn powi(&self, exp: i64) -> Decimal; + + /// Raise self to the given integer exponent x<sup>y</sup> returning `None` on overflow. + fn checked_powi(&self, exp: i64) -> Option<Decimal>; + + /// Raise self to the given unsigned integer exponent: x<sup>y</sup> + fn powu(&self, exp: u64) -> Decimal; + + /// Raise self to the given unsigned integer exponent x<sup>y</sup> returning `None` on overflow. + fn checked_powu(&self, exp: u64) -> Option<Decimal>; + + /// Raise self to the given floating point exponent: x<sup>y</sup> + fn powf(&self, exp: f64) -> Decimal; + + /// Raise self to the given floating point exponent x<sup>y</sup> returning `None` on overflow. + fn checked_powf(&self, exp: f64) -> Option<Decimal>; + + /// Raise self to the given Decimal exponent: x<sup>y</sup>. If `exp` is not whole then the approximation + /// e<sup>y*ln(x)</sup> is used. + fn powd(&self, exp: Decimal) -> Decimal; + + /// Raise self to the given Decimal exponent x<sup>y</sup> returning `None` on overflow. + /// If `exp` is not whole then the approximation e<sup>y*ln(x)</sup> is used. + fn checked_powd(&self, exp: Decimal) -> Option<Decimal>; + + /// The square root of a Decimal. Uses a standard Babylonian method. + fn sqrt(&self) -> Option<Decimal>; + + /// Calculates the natural logarithm for a Decimal calculated using Taylor's series. + fn ln(&self) -> Decimal; + + /// Calculates the checked natural logarithm for a Decimal calculated using Taylor's series. + /// Returns `None` for negative numbers or zero. + fn checked_ln(&self) -> Option<Decimal>; + + /// Calculates the base 10 logarithm of a specified Decimal number. + fn log10(&self) -> Decimal; + + /// Calculates the checked base 10 logarithm of a specified Decimal number. + /// Returns `None` for negative numbers or zero. + fn checked_log10(&self) -> Option<Decimal>; + + /// Abramowitz Approximation of Error Function from [wikipedia](https://en.wikipedia.org/wiki/Error_function#Numerical_approximations) + fn erf(&self) -> Decimal; + + /// The Cumulative distribution function for a Normal distribution + fn norm_cdf(&self) -> Decimal; + + /// The Probability density function for a Normal distribution. + fn norm_pdf(&self) -> Decimal; + + /// The Probability density function for a Normal distribution returning `None` on overflow. + fn checked_norm_pdf(&self) -> Option<Decimal>; + + /// Computes the sine of a number (in radians). + /// Panics upon overflow. + fn sin(&self) -> Decimal; + + /// Computes the checked sine of a number (in radians). + fn checked_sin(&self) -> Option<Decimal>; + + /// Computes the cosine of a number (in radians). + /// Panics upon overflow. + fn cos(&self) -> Decimal; + + /// Computes the checked cosine of a number (in radians). + fn checked_cos(&self) -> Option<Decimal>; + + /// Computes the tangent of a number (in radians). + /// Panics upon overflow or upon approaching a limit. + fn tan(&self) -> Decimal; + + /// Computes the checked tangent of a number (in radians). + /// Returns None on limit. + fn checked_tan(&self) -> Option<Decimal>; +} + +impl MathematicalOps for Decimal { + fn exp(&self) -> Decimal { + self.exp_with_tolerance(EXP_TOLERANCE) + } + + fn checked_exp(&self) -> Option<Decimal> { + self.checked_exp_with_tolerance(EXP_TOLERANCE) + } + + fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal { + match self.checked_exp_with_tolerance(tolerance) { + Some(d) => d, + None => { + if self.is_sign_negative() { + panic!("Exp underflowed") + } else { + panic!("Exp overflowed") + } + } + } + } + + fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal> { + if self.is_zero() { + return Some(Decimal::ONE); + } + if self.is_sign_negative() { + let mut flipped = *self; + flipped.set_sign_positive(true); + let exp = flipped.checked_exp_with_tolerance(tolerance)?; + return Decimal::ONE.checked_div(exp); + } + + let mut term = *self; + let mut result = self.checked_add(Decimal::ONE)?; + + for factorial in FACTORIAL.iter().skip(2) { + term = self.checked_mul(term)?; + let next = result + (term / factorial); + let diff = (next - result).abs(); + result = next; + if diff <= tolerance { + break; + } + } + + Some(result) + } + + fn powi(&self, exp: i64) -> Decimal { + match self.checked_powi(exp) { + Some(result) => result, + None => panic!("Pow overflowed"), + } + } + + fn checked_powi(&self, exp: i64) -> Option<Decimal> { + // For negative exponents we change x^-y into 1 / x^y. + // Otherwise, we calculate a standard unsigned exponent + if exp >= 0 { + return self.checked_powu(exp as u64); + } + + // Get the unsigned exponent + let exp = exp.unsigned_abs(); + let pow = match self.checked_powu(exp) { + Some(v) => v, + None => return None, + }; + Decimal::ONE.checked_div(pow) + } + + fn powu(&self, exp: u64) -> Decimal { + match self.checked_powu(exp) { + Some(result) => result, + None => panic!("Pow overflowed"), + } + } + + fn checked_powu(&self, exp: u64) -> Option<Decimal> { + match exp { + 0 => Some(Decimal::ONE), + 1 => Some(*self), + 2 => self.checked_mul(*self), + _ => { + // Get the squared value + let squared = match self.checked_mul(*self) { + Some(s) => s, + None => return None, + }; + // Square self once and make an infinite sized iterator of the square. + let iter = core::iter::repeat(squared); + + // We then take half of the exponent to create a finite iterator and then multiply those together. + let mut product = Decimal::ONE; + for x in iter.take((exp >> 1) as usize) { + match product.checked_mul(x) { + Some(r) => product = r, + None => return None, + }; + } + + // If the exponent is odd we still need to multiply once more + if exp & 0x1 > 0 { + match self.checked_mul(product) { + Some(p) => product = p, + None => return None, + } + } + product.normalize_assign(); + Some(product) + } + } + } + + fn powf(&self, exp: f64) -> Decimal { + match self.checked_powf(exp) { + Some(result) => result, + None => panic!("Pow overflowed"), + } + } + + fn checked_powf(&self, exp: f64) -> Option<Decimal> { + let exp = match Decimal::from_f64(exp) { + Some(f) => f, + None => return None, + }; + self.checked_powd(exp) + } + + fn powd(&self, exp: Decimal) -> Decimal { + match self.checked_powd(exp) { + Some(result) => result, + None => panic!("Pow overflowed"), + } + } + + fn checked_powd(&self, exp: Decimal) -> Option<Decimal> { + if exp.is_zero() { + return Some(Decimal::ONE); + } + if self.is_zero() { + return Some(Decimal::ZERO); + } + if self.is_one() { + return Some(Decimal::ONE); + } + if exp.is_one() { + return Some(*self); + } + + // If the scale is 0 then it's a trivial calculation + let exp = exp.normalize(); + if exp.scale() == 0 { + if exp.mid() != 0 || exp.hi() != 0 { + // Exponent way too big + return None; + } + + return if exp.is_sign_negative() { + self.checked_powi(-(exp.lo() as i64)) + } else { + self.checked_powu(exp.lo() as u64) + }; + } + + // We do some approximations since we've got a decimal exponent. + // For positive bases: a^b = exp(b*ln(a)) + let negative = self.is_sign_negative(); + let e = match self.abs().ln().checked_mul(exp) { + Some(e) => e, + None => return None, + }; + let mut result = e.checked_exp()?; + result.set_sign_negative(negative); + Some(result) + } + + fn sqrt(&self) -> Option<Decimal> { + if self.is_sign_negative() { + return None; + } + + if self.is_zero() { + return Some(Decimal::ZERO); + } + + // Start with an arbitrary number as the first guess + let mut result = self / Decimal::TWO; + // Too small to represent, so we start with self + // Future iterations could actually avoid using a decimal altogether and use a buffered + // vector, only combining back into a decimal on return + if result.is_zero() { + result = *self; + } + let mut last = result + Decimal::ONE; + + // Keep going while the difference is larger than the tolerance + let mut circuit_breaker = 0; + while last != result { + circuit_breaker += 1; + assert!(circuit_breaker < 1000, "geo mean circuit breaker"); + + last = result; + result = (result + self / result) / Decimal::TWO; + } + + Some(result) + } + + #[cfg(feature = "maths-nopanic")] + fn ln(&self) -> Decimal { + match self.checked_ln() { + Some(result) => result, + None => Decimal::ZERO, + } + } + + #[cfg(not(feature = "maths-nopanic"))] + fn ln(&self) -> Decimal { + match self.checked_ln() { + Some(result) => result, + None => { + if self.is_sign_negative() { + panic!("Unable to calculate ln for negative numbers") + } else if self.is_zero() { + panic!("Unable to calculate ln for zero") + } else { + panic!("Calculation of ln failed for unknown reasons") + } + } + } + } + + fn checked_ln(&self) -> Option<Decimal> { + if self.is_sign_negative() || self.is_zero() { + return None; + } + if self.is_one() { + return Some(Decimal::ZERO); + } + + // Approximate using Taylor Series + let mut x = *self; + let mut count = 0; + while x >= Decimal::ONE { + x *= Decimal::E_INVERSE; + count += 1; + } + while x <= Decimal::E_INVERSE { + x *= Decimal::E; + count -= 1; + } + x -= Decimal::ONE; + if x.is_zero() { + return Some(Decimal::new(count, 0)); + } + let mut result = Decimal::ZERO; + let mut iteration = 0; + let mut y = Decimal::ONE; + let mut last = Decimal::ONE; + while last != result && iteration < 100 { + iteration += 1; + last = result; + y *= -x; + result += y / Decimal::new(iteration, 0); + } + Some(Decimal::new(count, 0) - result) + } + + #[cfg(feature = "maths-nopanic")] + fn log10(&self) -> Decimal { + match self.checked_log10() { + Some(result) => result, + None => Decimal::ZERO, + } + } + + #[cfg(not(feature = "maths-nopanic"))] + fn log10(&self) -> Decimal { + match self.checked_log10() { + Some(result) => result, + None => { + if self.is_sign_negative() { + panic!("Unable to calculate log10 for negative numbers") + } else if self.is_zero() { + panic!("Unable to calculate log10 for zero") + } else { + panic!("Calculation of log10 failed for unknown reasons") + } + } + } + } + + fn checked_log10(&self) -> Option<Decimal> { + use crate::ops::array::{div_by_u32, is_all_zero}; + // Early exits + if self.is_sign_negative() || self.is_zero() { + return None; + } + if self.is_one() { + return Some(Decimal::ZERO); + } + + // This uses a very basic method for calculating log10. We know the following is true: + // log10(n) = ln(n) / ln(10) + // From this we can perform some small optimizations: + // 1. ln(10) is a constant + // 2. Multiplication is faster than division, so we can pre-calculate the constant 1/ln(10) + // This allows us to then simplify log10(n) to: + // log10(n) = C * ln(n) + + // Before doing all of this however, we see if there are simple calculations to be made. + let scale = self.scale(); + let mut working = self.mantissa_array3(); + + // Check for scales less than 1 as an early exit + if scale > 0 && working[2] == 0 && working[1] == 0 && working[0] == 1 { + return Some(Decimal::from_parts(scale, 0, 0, true, 0)); + } + + // Loop for detecting bordering base 10 values + let mut result = 0; + let mut base10 = true; + while !is_all_zero(&working) { + let remainder = div_by_u32(&mut working, 10u32); + if remainder != 0 { + base10 = false; + break; + } + result += 1; + if working[2] == 0 && working[1] == 0 && working[0] == 1 { + break; + } + } + if base10 { + return Some((result - scale as i32).into()); + } + + self.checked_ln().map(|result| LN10_INVERSE * result) + } + + fn erf(&self) -> Decimal { + if self.is_sign_positive() { + let one = &Decimal::ONE; + + let xa1 = self * Decimal::from_parts(705230784, 0, 0, false, 10); + let xa2 = self.powi(2) * Decimal::from_parts(422820123, 0, 0, false, 10); + let xa3 = self.powi(3) * Decimal::from_parts(92705272, 0, 0, false, 10); + let xa4 = self.powi(4) * Decimal::from_parts(1520143, 0, 0, false, 10); + let xa5 = self.powi(5) * Decimal::from_parts(2765672, 0, 0, false, 10); + let xa6 = self.powi(6) * Decimal::from_parts(430638, 0, 0, false, 10); + + let sum = one + xa1 + xa2 + xa3 + xa4 + xa5 + xa6; + one - (one / sum.powi(16)) + } else { + -self.abs().erf() + } + } + + fn norm_cdf(&self) -> Decimal { + (Decimal::ONE + (self / Decimal::from_parts(2318911239, 3292722, 0, false, 16)).erf()) / Decimal::TWO + } + + fn norm_pdf(&self) -> Decimal { + match self.checked_norm_pdf() { + Some(d) => d, + None => panic!("Norm Pdf overflowed"), + } + } + + fn checked_norm_pdf(&self) -> Option<Decimal> { + let sqrt2pi = Decimal::from_parts_raw(2133383024, 2079885984, 1358845910, 1835008); + let factor = -self.checked_powi(2)?; + let factor = factor.checked_div(Decimal::TWO)?; + factor.checked_exp()?.checked_div(sqrt2pi) + } + + fn sin(&self) -> Decimal { + match self.checked_sin() { + Some(x) => x, + None => panic!("Sin overflowed"), + } + } + + fn checked_sin(&self) -> Option<Decimal> { + if self.is_zero() { + return Some(Decimal::ZERO); + } + if self.is_sign_negative() { + // -Sin(-x) + return (-self).checked_sin().map(|x| -x); + } + if self >= &Decimal::TWO_PI { + // Reduce large numbers early - we can do this using rem to constrain to a range + let adjusted = self.checked_rem(Decimal::TWO_PI)?; + return adjusted.checked_sin(); + } + if self >= &Decimal::PI { + // -Sin(x-π) + return (self - Decimal::PI).checked_sin().map(|x| -x); + } + if self >= &Decimal::QUARTER_PI { + // Cos(π2-x) + return (Decimal::HALF_PI - self).checked_cos(); + } + + // Taylor series: + // ∑(n=0 to ∞) : ((−1)^n / (2n + 1)!) * x^(2n + 1) , x∈R + // First few expansions: + // x^1/1! - x^3/3! + x^5/5! - x^7/7! + x^9/9! + let mut result = Decimal::ZERO; + for n in 0..TRIG_SERIES_UPPER_BOUND { + let x = 2 * n + 1; + let element = self.checked_powi(x as i64)?.checked_div(FACTORIAL[x])?; + if n & 0x1 == 0 { + result += element; + } else { + result -= element; + } + } + Some(result) + } + + fn cos(&self) -> Decimal { + match self.checked_cos() { + Some(x) => x, + None => panic!("Cos overflowed"), + } + } + + fn checked_cos(&self) -> Option<Decimal> { + if self.is_zero() { + return Some(Decimal::ONE); + } + if self.is_sign_negative() { + // Cos(-x) + return (-self).checked_cos(); + } + if self >= &Decimal::TWO_PI { + // Reduce large numbers early - we can do this using rem to constrain to a range + let adjusted = self.checked_rem(Decimal::TWO_PI)?; + return adjusted.checked_cos(); + } + if self >= &Decimal::PI { + // -Cos(x-π) + return (self - Decimal::PI).checked_cos().map(|x| -x); + } + if self >= &Decimal::QUARTER_PI { + // Sin(π2-x) + return (Decimal::HALF_PI - self).checked_sin(); + } + + // Taylor series: + // ∑(n=0 to ∞) : ((−1)^n / (2n)!) * x^(2n) , x∈R + // First few expansions: + // x^0/0! - x^2/2! + x^4/4! - x^6/6! + x^8/8! + let mut result = Decimal::ZERO; + for n in 0..TRIG_SERIES_UPPER_BOUND { + let x = 2 * n; + let element = self.checked_powi(x as i64)?.checked_div(FACTORIAL[x])?; + if n & 0x1 == 0 { + result += element; + } else { + result -= element; + } + } + Some(result) + } + + fn tan(&self) -> Decimal { + match self.checked_tan() { + Some(x) => x, + None => panic!("Tan overflowed"), + } + } + + fn checked_tan(&self) -> Option<Decimal> { + if self.is_zero() { + return Some(Decimal::ZERO); + } + if self.is_sign_negative() { + // -Tan(-x) + return (-self).checked_tan().map(|x| -x); + } + if self >= &Decimal::TWO_PI { + // Reduce large numbers early - we can do this using rem to constrain to a range + let adjusted = self.checked_rem(Decimal::TWO_PI)?; + return adjusted.checked_tan(); + } + // Reduce to 0 <= x <= PI + if self >= &Decimal::PI { + // Tan(x-π) + return (self - Decimal::PI).checked_tan(); + } + // Reduce to 0 <= x <= PI/2 + if self > &Decimal::HALF_PI { + // We can use the symmetrical function inside the first quadrant + // e.g. tan(x) = -tan((PI/2 - x) + PI/2) + return ((Decimal::HALF_PI - self) + Decimal::HALF_PI).checked_tan().map(|x| -x); + } + + // It has now been reduced to 0 <= x <= PI/2. If it is >= PI/4 we can make it even smaller + // by calculating tan(PI/2 - x) and taking the reciprocal + if self > &Decimal::QUARTER_PI { + return match (Decimal::HALF_PI - self).checked_tan() { + Some(x) => Decimal::ONE.checked_div(x), + None => None, + }; + } + + // Due the way that tan(x) sharply tends towards infinity, we try to optimize + // the resulting accuracy by using Trigonometric identity when > PI/8. We do this by + // replacing the angle with one that is half as big. + if self > &EIGHTH_PI { + // Work out tan(x/2) + let tan_half = (self / Decimal::TWO).checked_tan()?; + // Work out the dividend i.e. 2tan(x/2) + let dividend = Decimal::TWO.checked_mul(tan_half)?; + + // Work out the divisor i.e. 1 - tan^2(x/2) + let squared = tan_half.checked_mul(tan_half)?; + let divisor = Decimal::ONE - squared; + // Treat this as infinity + if divisor.is_zero() { + return None; + } + return dividend.checked_div(divisor); + } + + // Do a polynomial approximation based upon the Maclaurin series. + // This can be simplified to something like: + // + // ∑(n=1,3,5,7,9)(f(n)(0)/n!)x^n + // + // First few expansions (which we leverage): + // (f'(0)/1!)x^1 + (f'''(0)/3!)x^3 + (f'''''(0)/5!)x^5 + (f'''''''/7!)x^7 + // + // x + (1/3)x^3 + (2/15)x^5 + (17/315)x^7 + (62/2835)x^9 + (1382/155925)x^11 + // + // (Generated by https://www.wolframalpha.com/widgets/view.jsp?id=fe1ad8d4f5dbb3cb866d0c89beb527a6) + // The more terms, the better the accuracy. This generates accuracy within approx 10^-8 for angles + // less than PI/8. + const SERIES: [(Decimal, u64); 6] = [ + // 1 / 3 + (Decimal::from_parts_raw(89478485, 347537611, 180700362, 1835008), 3), + // 2 / 15 + (Decimal::from_parts_raw(894784853, 3574988881, 72280144, 1835008), 5), + // 17 / 315 + (Decimal::from_parts_raw(905437054, 3907911371, 2925624, 1769472), 7), + // 62 / 2835 + (Decimal::from_parts_raw(3191872741, 2108928381, 11855473, 1835008), 9), + // 1382 / 155925 + (Decimal::from_parts_raw(3482645539, 2612995122, 4804769, 1835008), 11), + // 21844 / 6081075 + (Decimal::from_parts_raw(4189029078, 2192791200, 1947296, 1835008), 13), + ]; + let mut result = *self; + for (fraction, pow) in SERIES { + result += fraction * self.powu(pow); + } + Some(result) + } +} + +impl Pow<Decimal> for Decimal { + type Output = Decimal; + + fn pow(self, rhs: Decimal) -> Self::Output { + MathematicalOps::powd(&self, rhs) + } +} + +impl Pow<u64> for Decimal { + type Output = Decimal; + + fn pow(self, rhs: u64) -> Self::Output { + MathematicalOps::powu(&self, rhs) + } +} + +impl Pow<i64> for Decimal { + type Output = Decimal; + + fn pow(self, rhs: i64) -> Self::Output { + MathematicalOps::powi(&self, rhs) + } +} + +impl Pow<f64> for Decimal { + type Output = Decimal; + + fn pow(self, rhs: f64) -> Self::Output { + MathematicalOps::powf(&self, rhs) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(not(feature = "std"))] + use alloc::string::ToString; + + #[test] + fn test_factorials() { + assert_eq!("1", FACTORIAL[0].to_string(), "0!"); + assert_eq!("1", FACTORIAL[1].to_string(), "1!"); + assert_eq!("2", FACTORIAL[2].to_string(), "2!"); + assert_eq!("6", FACTORIAL[3].to_string(), "3!"); + assert_eq!("24", FACTORIAL[4].to_string(), "4!"); + assert_eq!("120", FACTORIAL[5].to_string(), "5!"); + assert_eq!("720", FACTORIAL[6].to_string(), "6!"); + assert_eq!("5040", FACTORIAL[7].to_string(), "7!"); + assert_eq!("40320", FACTORIAL[8].to_string(), "8!"); + assert_eq!("362880", FACTORIAL[9].to_string(), "9!"); + assert_eq!("3628800", FACTORIAL[10].to_string(), "10!"); + assert_eq!("39916800", FACTORIAL[11].to_string(), "11!"); + assert_eq!("479001600", FACTORIAL[12].to_string(), "12!"); + assert_eq!("6227020800", FACTORIAL[13].to_string(), "13!"); + assert_eq!("87178291200", FACTORIAL[14].to_string(), "14!"); + assert_eq!("1307674368000", FACTORIAL[15].to_string(), "15!"); + assert_eq!("20922789888000", FACTORIAL[16].to_string(), "16!"); + assert_eq!("355687428096000", FACTORIAL[17].to_string(), "17!"); + assert_eq!("6402373705728000", FACTORIAL[18].to_string(), "18!"); + assert_eq!("121645100408832000", FACTORIAL[19].to_string(), "19!"); + assert_eq!("2432902008176640000", FACTORIAL[20].to_string(), "20!"); + assert_eq!("51090942171709440000", FACTORIAL[21].to_string(), "21!"); + assert_eq!("1124000727777607680000", FACTORIAL[22].to_string(), "22!"); + assert_eq!("25852016738884976640000", FACTORIAL[23].to_string(), "23!"); + assert_eq!("620448401733239439360000", FACTORIAL[24].to_string(), "24!"); + assert_eq!("15511210043330985984000000", FACTORIAL[25].to_string(), "25!"); + assert_eq!("403291461126605635584000000", FACTORIAL[26].to_string(), "26!"); + assert_eq!("10888869450418352160768000000", FACTORIAL[27].to_string(), "27!"); + } +} diff --git a/third_party/rust/rust_decimal/src/mysql.rs b/third_party/rust/rust_decimal/src/mysql.rs new file mode 100644 index 0000000000..6dc0db253d --- /dev/null +++ b/third_party/rust/rust_decimal/src/mysql.rs @@ -0,0 +1,241 @@ +use crate::Decimal; +use diesel::{ + deserialize::{self, FromSql}, + mysql::Mysql, + serialize::{self, IsNull, Output, ToSql}, + sql_types::Numeric, +}; +use std::io::Write; +use std::str::FromStr; + +#[cfg(all(feature = "diesel1", not(feature = "diesel2")))] +impl ToSql<Numeric, Mysql> for Decimal { + fn to_sql<W: Write>(&self, out: &mut Output<W, Mysql>) -> serialize::Result { + write!(out, "{}", *self).map(|_| IsNull::No).map_err(|e| e.into()) + } +} + +#[cfg(feature = "diesel2")] +impl ToSql<Numeric, Mysql> for Decimal { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result { + write!(out, "{}", *self).map(|_| IsNull::No).map_err(|e| e.into()) + } +} + +#[cfg(all(feature = "diesel1", not(feature = "diesel2")))] +impl FromSql<Numeric, Mysql> for Decimal { + fn from_sql(numeric: Option<&[u8]>) -> deserialize::Result<Self> { + // From what I can ascertain, MySQL simply reads from a string format for the Decimal type. + // Explicitly, it looks like it is length followed by the string. Regardless, we can leverage + // internal types. + let bytes = numeric.ok_or("Invalid decimal")?; + let s = std::str::from_utf8(bytes)?; + Decimal::from_str(s).map_err(|e| e.into()) + } +} + +#[cfg(feature = "diesel2")] +impl FromSql<Numeric, Mysql> for Decimal { + fn from_sql(numeric: diesel::mysql::MysqlValue) -> deserialize::Result<Self> { + // From what I can ascertain, MySQL simply reads from a string format for the Decimal type. + // Explicitly, it looks like it is length followed by the string. Regardless, we can leverage + // internal types. + let s = std::str::from_utf8(numeric.as_bytes())?; + Decimal::from_str(s).map_err(|e| e.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use diesel::deserialize::QueryableByName; + use diesel::prelude::*; + use diesel::row::NamedRow; + use diesel::sql_query; + use diesel::sql_types::Text; + + struct Test { + value: Decimal, + } + + struct NullableTest { + value: Option<Decimal>, + } + + pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[ + // precision, scale, sent, expected + (1, 0, "1", "1"), + (6, 2, "1", "1.00"), + (6, 2, "9999.99", "9999.99"), + (35, 6, "3950.123456", "3950.123456"), + (10, 2, "3950.123456", "3950.12"), + (35, 6, "3950", "3950.000000"), + (4, 0, "3950", "3950"), + (35, 6, "0.1", "0.100000"), + (35, 6, "0.01", "0.010000"), + (35, 6, "0.001", "0.001000"), + (35, 6, "0.0001", "0.000100"), + (35, 6, "0.00001", "0.000010"), + (35, 6, "0.000001", "0.000001"), + (35, 6, "1", "1.000000"), + (35, 6, "-100", "-100.000000"), + (35, 6, "-123.456", "-123.456000"), + (35, 6, "119996.25", "119996.250000"), + (35, 6, "1000000", "1000000.000000"), + (35, 6, "9999999.99999", "9999999.999990"), + (35, 6, "12340.56789", "12340.567890"), + ]; + + /// Gets the URL for connecting to MySQL for testing. Set the MYSQL_URL + /// environment variable to change from the default of "mysql://root@localhost/mysql". + fn get_mysql_url() -> String { + if let Ok(url) = std::env::var("MYSQL_URL") { + return url; + } + "mysql://root@127.0.0.1/mysql".to_string() + } + + #[cfg(all(feature = "diesel1", not(feature = "diesel2")))] + mod diesel1 { + use super::*; + + impl QueryableByName<Mysql> for Test { + fn build<R: NamedRow<Mysql>>(row: &R) -> deserialize::Result<Self> { + let value = row.get("value")?; + Ok(Test { value }) + } + } + + impl QueryableByName<Mysql> for NullableTest { + fn build<R: NamedRow<Mysql>>(row: &R) -> deserialize::Result<Self> { + let value = row.get("value")?; + Ok(NullableTest { value }) + } + } + + #[test] + fn test_null() { + let connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + + // Test NULL + let items: Vec<NullableTest> = sql_query("SELECT CAST(NULL AS DECIMAL) AS value") + .load(&connection) + .expect("Unable to query value"); + let result = items.first().unwrap().value; + assert_eq!(None, result); + } + + #[test] + fn read_numeric_type() { + let connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let items: Vec<Test> = sql_query(format!( + "SELECT CAST('{}' AS DECIMAL({}, {})) AS value", + sent, precision, scale + )) + .load(&connection) + .expect("Unable to query value"); + assert_eq!( + expected, + items.first().unwrap().value.to_string(), + "DECIMAL({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + + #[test] + fn write_numeric_type() { + let connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let items: Vec<Test> = + sql_query(format!("SELECT CAST(? AS DECIMAL({}, {})) AS value", precision, scale)) + .bind::<Text, _>(sent) + .load(&connection) + .expect("Unable to query value"); + assert_eq!( + expected, + items.first().unwrap().value.to_string(), + "DECIMAL({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + } + + #[cfg(feature = "diesel2")] + mod diesel2 { + use super::*; + + impl QueryableByName<Mysql> for Test { + fn build<'a>(row: &impl NamedRow<'a, Mysql>) -> deserialize::Result<Self> { + let value = NamedRow::get(row, "value")?; + Ok(Test { value }) + } + } + + impl QueryableByName<Mysql> for NullableTest { + fn build<'a>(row: &impl NamedRow<'a, Mysql>) -> deserialize::Result<Self> { + let value = NamedRow::get(row, "value")?; + Ok(NullableTest { value }) + } + } + + #[test] + fn test_null() { + let mut connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + + // Test NULL + let items: Vec<NullableTest> = sql_query("SELECT CAST(NULL AS DECIMAL) AS value") + .load(&mut connection) + .expect("Unable to query value"); + let result = items.first().unwrap().value; + assert_eq!(None, result); + } + + #[test] + fn read_numeric_type() { + let mut connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let items: Vec<Test> = sql_query(format!( + "SELECT CAST('{}' AS DECIMAL({}, {})) AS value", + sent, precision, scale + )) + .load(&mut connection) + .expect("Unable to query value"); + assert_eq!( + expected, + items.first().unwrap().value.to_string(), + "DECIMAL({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + + #[test] + fn write_numeric_type() { + let mut connection = diesel::MysqlConnection::establish(&get_mysql_url()).expect("Establish connection"); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let items: Vec<Test> = + sql_query(format!("SELECT CAST(? AS DECIMAL({}, {})) AS value", precision, scale)) + .bind::<Text, _>(sent) + .load(&mut connection) + .expect("Unable to query value"); + assert_eq!( + expected, + items.first().unwrap().value.to_string(), + "DECIMAL({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + } +} diff --git a/third_party/rust/rust_decimal/src/ops.rs b/third_party/rust/rust_decimal/src/ops.rs new file mode 100644 index 0000000000..6bda140e89 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops.rs @@ -0,0 +1,34 @@ +// This code (in fact, this library) is heavily inspired by the dotnet Decimal number library +// implementation. Consequently, a huge thank you for to all the contributors to that project +// whose work has also inspired the solutions found here. + +pub(crate) mod array; + +#[cfg(feature = "legacy-ops")] +mod legacy; +#[cfg(feature = "legacy-ops")] +pub(crate) use legacy::{add_impl, cmp_impl, div_impl, mul_impl, rem_impl, sub_impl}; + +#[cfg(not(feature = "legacy-ops"))] +mod add; +#[cfg(not(feature = "legacy-ops"))] +mod cmp; +#[cfg(not(feature = "legacy-ops"))] +pub(in crate::ops) mod common; +#[cfg(not(feature = "legacy-ops"))] +mod div; +#[cfg(not(feature = "legacy-ops"))] +mod mul; +#[cfg(not(feature = "legacy-ops"))] +mod rem; + +#[cfg(not(feature = "legacy-ops"))] +pub(crate) use add::{add_impl, sub_impl}; +#[cfg(not(feature = "legacy-ops"))] +pub(crate) use cmp::cmp_impl; +#[cfg(not(feature = "legacy-ops"))] +pub(crate) use div::div_impl; +#[cfg(not(feature = "legacy-ops"))] +pub(crate) use mul::mul_impl; +#[cfg(not(feature = "legacy-ops"))] +pub(crate) use rem::rem_impl; diff --git a/third_party/rust/rust_decimal/src/ops/add.rs b/third_party/rust/rust_decimal/src/ops/add.rs new file mode 100644 index 0000000000..52ba675f23 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/add.rs @@ -0,0 +1,382 @@ +use crate::constants::{MAX_I32_SCALE, POWERS_10, SCALE_MASK, SCALE_SHIFT, SIGN_MASK, U32_MASK, U32_MAX}; +use crate::decimal::{CalculationResult, Decimal}; +use crate::ops::common::{Buf24, Dec64}; + +pub(crate) fn add_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + add_sub_internal(d1, d2, false) +} + +pub(crate) fn sub_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + add_sub_internal(d1, d2, true) +} + +#[inline] +fn add_sub_internal(d1: &Decimal, d2: &Decimal, subtract: bool) -> CalculationResult { + if d1.is_zero() { + // 0 - x or 0 + x + let mut result = *d2; + if subtract && !d2.is_zero() { + result.set_sign_negative(d2.is_sign_positive()); + } + return CalculationResult::Ok(result); + } + if d2.is_zero() { + // x - 0 or x + 0 + return CalculationResult::Ok(*d1); + } + + // Work out whether we need to rescale and/or if it's a subtract still given the signs of the + // numbers. + let flags = d1.flags() ^ d2.flags(); + let subtract = subtract ^ ((flags & SIGN_MASK) != 0); + let rescale = (flags & SCALE_MASK) > 0; + + // We optimize towards using 32 bit logic as much as possible. It's noticeably faster at + // scale, even on 64 bit machines + if d1.mid() | d1.hi() == 0 && d2.mid() | d2.hi() == 0 { + // We'll try to rescale, however we may end up with 64 bit (or more) numbers + // If we do, we'll choose a different flow than fast_add + if rescale { + // This is less optimized if we scale to a 64 bit integer. We can add some further logic + // here later on. + let rescale_factor = ((d2.flags() & SCALE_MASK) as i32 - (d1.flags() & SCALE_MASK) as i32) >> SCALE_SHIFT; + if rescale_factor < 0 { + // We try to rescale the rhs + if let Some(rescaled) = rescale32(d2.lo(), -rescale_factor) { + return fast_add(d1.lo(), rescaled, d1.flags(), subtract); + } + } else { + // We try to rescale the lhs + if let Some(rescaled) = rescale32(d1.lo(), rescale_factor) { + return fast_add( + rescaled, + d2.lo(), + (d2.flags() & SCALE_MASK) | (d1.flags() & SIGN_MASK), + subtract, + ); + } + } + } else { + return fast_add(d1.lo(), d2.lo(), d1.flags(), subtract); + } + } + + // Continue on with the slower 64 bit method + let d1 = Dec64::new(d1); + let d2 = Dec64::new(d2); + + // If we're not the same scale then make sure we're there first before starting addition + if rescale { + let rescale_factor = d2.scale as i32 - d1.scale as i32; + if rescale_factor < 0 { + let negative = subtract ^ d1.negative; + let scale = d1.scale; + unaligned_add(d2, d1, negative, scale, -rescale_factor, subtract) + } else { + let negative = d1.negative; + let scale = d2.scale; + unaligned_add(d1, d2, negative, scale, rescale_factor, subtract) + } + } else { + let neg = d1.negative; + let scale = d1.scale; + aligned_add(d1, d2, neg, scale, subtract) + } +} + +#[inline(always)] +fn rescale32(num: u32, rescale_factor: i32) -> Option<u32> { + if rescale_factor > MAX_I32_SCALE { + return None; + } + num.checked_mul(POWERS_10[rescale_factor as usize]) +} + +fn fast_add(lo1: u32, lo2: u32, flags: u32, subtract: bool) -> CalculationResult { + if subtract { + // Sub can't overflow because we're ensuring the bigger number always subtracts the smaller number + if lo1 < lo2 { + return CalculationResult::Ok(Decimal::from_parts_raw(lo2 - lo1, 0, 0, flags ^ SIGN_MASK)); + } + return CalculationResult::Ok(Decimal::from_parts_raw(lo1 - lo2, 0, 0, flags)); + } + // Add can overflow however, so we check for that explicitly + let lo = lo1.wrapping_add(lo2); + let mid = if lo < lo1 { 1 } else { 0 }; + CalculationResult::Ok(Decimal::from_parts_raw(lo, mid, 0, flags)) +} + +fn aligned_add(lhs: Dec64, rhs: Dec64, negative: bool, scale: u32, subtract: bool) -> CalculationResult { + if subtract { + // Signs differ, so subtract + let mut result = Dec64 { + negative, + scale, + low64: lhs.low64.wrapping_sub(rhs.low64), + hi: lhs.hi.wrapping_sub(rhs.hi), + }; + + // Check for carry + if result.low64 > lhs.low64 { + result.hi = result.hi.wrapping_sub(1); + if result.hi >= lhs.hi { + flip_sign(&mut result); + } + } else if result.hi > lhs.hi { + flip_sign(&mut result); + } + CalculationResult::Ok(result.to_decimal()) + } else { + // Signs are the same, so add + let mut result = Dec64 { + negative, + scale, + low64: lhs.low64.wrapping_add(rhs.low64), + hi: lhs.hi.wrapping_add(rhs.hi), + }; + + // Check for carry + if result.low64 < lhs.low64 { + result.hi = result.hi.wrapping_add(1); + if result.hi <= lhs.hi { + if result.scale == 0 { + return CalculationResult::Overflow; + } + reduce_scale(&mut result); + } + } else if result.hi < lhs.hi { + if result.scale == 0 { + return CalculationResult::Overflow; + } + reduce_scale(&mut result); + } + CalculationResult::Ok(result.to_decimal()) + } +} + +fn flip_sign(result: &mut Dec64) { + // Bitwise not the high portion + result.hi = !result.hi; + let low64 = ((result.low64 as i64).wrapping_neg()) as u64; + if low64 == 0 { + result.hi += 1; + } + result.low64 = low64; + result.negative = !result.negative; +} + +fn reduce_scale(result: &mut Dec64) { + let mut low64 = result.low64; + let mut hi = result.hi; + + let mut num = (hi as u64) + (1u64 << 32); + hi = (num / 10u64) as u32; + num = ((num - (hi as u64) * 10u64) << 32) + (low64 >> 32); + let mut div = (num / 10) as u32; + num = ((num - (div as u64) * 10u64) << 32) + (low64 & U32_MASK); + low64 = (div as u64) << 32; + div = (num / 10u64) as u32; + low64 = low64.wrapping_add(div as u64); + let remainder = (num as u32).wrapping_sub(div.wrapping_mul(10)); + + // Finally, round. This is optimizing slightly toward non-rounded numbers + if remainder >= 5 && (remainder > 5 || (low64 & 1) > 0) { + low64 = low64.wrapping_add(1); + if low64 == 0 { + hi += 1; + } + } + + result.low64 = low64; + result.hi = hi; + result.scale -= 1; +} + +// Assumption going into this function is that the LHS is the larger number and will "absorb" the +// smaller number. +fn unaligned_add( + lhs: Dec64, + rhs: Dec64, + negative: bool, + scale: u32, + rescale_factor: i32, + subtract: bool, +) -> CalculationResult { + let mut lhs = lhs; + let mut low64 = lhs.low64; + let mut high = lhs.hi; + let mut rescale_factor = rescale_factor; + + // First off, we see if we can get away with scaling small amounts (or none at all) + if high == 0 { + if low64 <= U32_MAX { + // We know it's not zero, so we start scaling. + // Start with reducing the scale down for the low portion + while low64 <= U32_MAX { + if rescale_factor <= MAX_I32_SCALE { + low64 *= POWERS_10[rescale_factor as usize] as u64; + lhs.low64 = low64; + return aligned_add(lhs, rhs, negative, scale, subtract); + } + rescale_factor -= MAX_I32_SCALE; + low64 *= POWERS_10[9] as u64; + } + } + + // Reduce the scale for the high portion + while high == 0 { + let power = if rescale_factor <= MAX_I32_SCALE { + POWERS_10[rescale_factor as usize] as u64 + } else { + POWERS_10[9] as u64 + }; + + let tmp_low = (low64 & U32_MASK) * power; + let tmp_hi = (low64 >> 32) * power + (tmp_low >> 32); + low64 = (tmp_low & U32_MASK) + (tmp_hi << 32); + high = (tmp_hi >> 32) as u32; + rescale_factor -= MAX_I32_SCALE; + if rescale_factor <= 0 { + lhs.low64 = low64; + lhs.hi = high; + return aligned_add(lhs, rhs, negative, scale, subtract); + } + } + } + + // See if we can get away with keeping it in the 96 bits. Otherwise, we need a buffer + let mut tmp64: u64; + loop { + let power = if rescale_factor <= MAX_I32_SCALE { + POWERS_10[rescale_factor as usize] as u64 + } else { + POWERS_10[9] as u64 + }; + + let tmp_low = (low64 & U32_MASK) * power; + tmp64 = (low64 >> 32) * power + (tmp_low >> 32); + low64 = (tmp_low & U32_MASK) + (tmp64 << 32); + tmp64 >>= 32; + tmp64 += (high as u64) * power; + + rescale_factor -= MAX_I32_SCALE; + + if tmp64 > U32_MAX { + break; + } else { + high = tmp64 as u32; + if rescale_factor <= 0 { + lhs.low64 = low64; + lhs.hi = high; + return aligned_add(lhs, rhs, negative, scale, subtract); + } + } + } + + let mut buffer = Buf24::zero(); + buffer.set_low64(low64); + buffer.set_mid64(tmp64); + + let mut upper_word = buffer.upper_word(); + while rescale_factor > 0 { + let power = if rescale_factor <= MAX_I32_SCALE { + POWERS_10[rescale_factor as usize] as u64 + } else { + POWERS_10[9] as u64 + }; + tmp64 = 0; + for (index, part) in buffer.data.iter_mut().enumerate() { + tmp64 = tmp64.wrapping_add((*part as u64) * power); + *part = tmp64 as u32; + tmp64 >>= 32; + if index + 1 > upper_word { + break; + } + } + + if tmp64 & U32_MASK > 0 { + // Extend the result + upper_word += 1; + buffer.data[upper_word] = tmp64 as u32; + } + + rescale_factor -= MAX_I32_SCALE; + } + + // Do the add + tmp64 = buffer.low64(); + low64 = rhs.low64; + let tmp_hi = buffer.data[2]; + high = rhs.hi; + + if subtract { + low64 = tmp64.wrapping_sub(low64); + high = tmp_hi.wrapping_sub(high); + + // Check for carry + let carry = if low64 > tmp64 { + high = high.wrapping_sub(1); + high >= tmp_hi + } else { + high > tmp_hi + }; + + if carry { + for part in buffer.data.iter_mut().skip(3) { + *part = part.wrapping_sub(1); + if *part > 0 { + break; + } + } + + if buffer.data[upper_word] == 0 && upper_word < 3 { + return CalculationResult::Ok(Decimal::from_parts( + low64 as u32, + (low64 >> 32) as u32, + high, + negative, + scale, + )); + } + } + } else { + low64 = low64.wrapping_add(tmp64); + high = high.wrapping_add(tmp_hi); + + // Check for carry + let carry = if low64 < tmp64 { + high = high.wrapping_add(1); + high <= tmp_hi + } else { + high < tmp_hi + }; + + if carry { + for (index, part) in buffer.data.iter_mut().enumerate().skip(3) { + if upper_word < index { + *part = 1; + upper_word = index; + break; + } + *part = part.wrapping_add(1); + if *part > 0 { + break; + } + } + } + } + + buffer.set_low64(low64); + buffer.data[2] = high; + if let Some(scale) = buffer.rescale(upper_word, scale) { + CalculationResult::Ok(Decimal::from_parts( + buffer.data[0], + buffer.data[1], + buffer.data[2], + negative, + scale, + )) + } else { + CalculationResult::Overflow + } +} diff --git a/third_party/rust/rust_decimal/src/ops/array.rs b/third_party/rust/rust_decimal/src/ops/array.rs new file mode 100644 index 0000000000..2ab58b4dfb --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/array.rs @@ -0,0 +1,381 @@ +use crate::constants::{MAX_PRECISION_U32, POWERS_10, U32_MASK}; + +/// Rescales the given decimal to new scale. +/// e.g. with 1.23 and new scale 3 rescale the value to 1.230 +#[inline(always)] +pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_scale: u32) { + if *value_scale == new_scale { + // Nothing to do + return; + } + + if is_all_zero(value) { + *value_scale = new_scale.min(MAX_PRECISION_U32); + return; + } + + if *value_scale > new_scale { + let mut diff = value_scale.wrapping_sub(new_scale); + // Scaling further isn't possible since we got an overflow + // In this case we need to reduce the accuracy of the "side to keep" + + // Now do the necessary rounding + let mut remainder = 0; + while let Some(diff_minus_one) = diff.checked_sub(1) { + if is_all_zero(value) { + *value_scale = new_scale; + return; + } + + diff = diff_minus_one; + + // Any remainder is discarded if diff > 0 still (i.e. lost precision) + remainder = div_by_u32(value, 10); + } + if remainder >= 5 { + for part in value.iter_mut() { + let digit = u64::from(*part) + 1u64; + remainder = if digit > U32_MASK { 1 } else { 0 }; + *part = (digit & U32_MASK) as u32; + if remainder == 0 { + break; + } + } + } + *value_scale = new_scale; + } else { + let mut diff = new_scale.wrapping_sub(*value_scale); + let mut working = [value[0], value[1], value[2]]; + while let Some(diff_minus_one) = diff.checked_sub(1) { + if mul_by_10(&mut working) == 0 { + value.copy_from_slice(&working); + diff = diff_minus_one; + } else { + break; + } + } + *value_scale = new_scale.wrapping_sub(diff); + } +} + +#[cfg(feature = "legacy-ops")] +pub(crate) fn add_by_internal(value: &mut [u32], by: &[u32]) -> u32 { + let mut carry: u64 = 0; + let vl = value.len(); + let bl = by.len(); + if vl >= bl { + let mut sum: u64; + for i in 0..bl { + sum = u64::from(value[i]) + u64::from(by[i]) + carry; + value[i] = (sum & U32_MASK) as u32; + carry = sum >> 32; + } + if vl > bl && carry > 0 { + for i in value.iter_mut().skip(bl) { + sum = u64::from(*i) + carry; + *i = (sum & U32_MASK) as u32; + carry = sum >> 32; + if carry == 0 { + break; + } + } + } + } else if vl + 1 == bl { + // Overflow, by default, is anything in the high portion of by + let mut sum: u64; + for i in 0..vl { + sum = u64::from(value[i]) + u64::from(by[i]) + carry; + value[i] = (sum & U32_MASK) as u32; + carry = sum >> 32; + } + if by[vl] > 0 { + carry += u64::from(by[vl]); + } + } else { + panic!("Internal error: add using incompatible length arrays. {} <- {}", vl, bl); + } + carry as u32 +} + +pub(crate) fn add_by_internal_flattened(value: &mut [u32; 3], by: u32) -> u32 { + manage_add_by_internal(by, value) +} + +#[inline] +pub(crate) fn add_one_internal(value: &mut [u32; 3]) -> u32 { + manage_add_by_internal(1, value) +} + +// `u64 as u32` are safe because of widening and 32bits shifts +#[inline] +pub(crate) fn manage_add_by_internal<const N: usize>(initial_carry: u32, value: &mut [u32; N]) -> u32 { + let mut carry = u64::from(initial_carry); + let mut iter = 0..value.len(); + let mut sum = 0; + + let mut sum_fn = |local_carry: &mut u64, idx| { + sum = u64::from(value[idx]).wrapping_add(*local_carry); + value[idx] = (sum & U32_MASK) as u32; + *local_carry = sum.wrapping_shr(32); + }; + + if let Some(idx) = iter.next() { + sum_fn(&mut carry, idx); + } + + for idx in iter { + if carry > 0 { + sum_fn(&mut carry, idx); + } + } + + carry as u32 +} + +pub(crate) fn sub_by_internal(value: &mut [u32], by: &[u32]) -> u32 { + // The way this works is similar to long subtraction + // Let's assume we're working with bytes for simplicity in an example: + // 257 - 8 = 249 + // 0000_0001 0000_0001 - 0000_0000 0000_1000 = 0000_0000 1111_1001 + // We start by doing the first byte... + // Overflow = 0 + // Left = 0000_0001 (1) + // Right = 0000_1000 (8) + // Firstly, we make sure the left and right are scaled up to twice the size + // Left = 0000_0000 0000_0001 + // Right = 0000_0000 0000_1000 + // We then subtract right from left + // Result = Left - Right = 1111_1111 1111_1001 + // We subtract the overflow, which in this case is 0. + // Because left < right (1 < 8) we invert the high part. + // Lo = 1111_1001 + // Hi = 1111_1111 -> 0000_0001 + // Lo is the field, hi is the overflow. + // We do the same for the second byte... + // Overflow = 1 + // Left = 0000_0001 + // Right = 0000_0000 + // Result = Left - Right = 0000_0000 0000_0001 + // We subtract the overflow... + // Result = 0000_0000 0000_0001 - 1 = 0 + // And we invert the high, just because (invert 0 = 0). + // So our result is: + // 0000_0000 1111_1001 + let mut overflow = 0; + let vl = value.len(); + let bl = by.len(); + for i in 0..vl { + if i >= bl { + break; + } + let (lo, hi) = sub_part(value[i], by[i], overflow); + value[i] = lo; + overflow = hi; + } + overflow +} + +fn sub_part(left: u32, right: u32, overflow: u32) -> (u32, u32) { + let part = 0x1_0000_0000u64 + u64::from(left) - (u64::from(right) + u64::from(overflow)); + let lo = part as u32; + let hi = 1 - ((part >> 32) as u32); + (lo, hi) +} + +// Returns overflow +#[inline] +pub(crate) fn mul_by_10(bits: &mut [u32; 3]) -> u32 { + let mut overflow = 0u64; + for b in bits.iter_mut() { + let result = u64::from(*b) * 10u64 + overflow; + let hi = (result >> 32) & U32_MASK; + let lo = (result & U32_MASK) as u32; + *b = lo; + overflow = hi; + } + + overflow as u32 +} + +// Returns overflow +pub(crate) fn mul_by_u32(bits: &mut [u32], m: u32) -> u32 { + let mut overflow = 0; + for b in bits.iter_mut() { + let (lo, hi) = mul_part(*b, m, overflow); + *b = lo; + overflow = hi; + } + overflow +} + +pub(crate) fn mul_part(left: u32, right: u32, high: u32) -> (u32, u32) { + let result = u64::from(left) * u64::from(right) + u64::from(high); + let hi = ((result >> 32) & U32_MASK) as u32; + let lo = (result & U32_MASK) as u32; + (lo, hi) +} + +// Returns remainder +pub(crate) fn div_by_u32<const N: usize>(bits: &mut [u32; N], divisor: u32) -> u32 { + if divisor == 0 { + // Divide by zero + panic!("Internal error: divide by zero"); + } else if divisor == 1 { + // dividend remains unchanged + 0 + } else { + let mut remainder = 0u32; + let divisor = u64::from(divisor); + for part in bits.iter_mut().rev() { + let temp = (u64::from(remainder) << 32) + u64::from(*part); + remainder = (temp % divisor) as u32; + *part = (temp / divisor) as u32; + } + + remainder + } +} + +pub(crate) fn div_by_1x(bits: &mut [u32; 3], power: usize) -> u32 { + let mut remainder = 0u32; + let divisor = POWERS_10[power] as u64; + let temp = ((remainder as u64) << 32) + (bits[2] as u64); + remainder = (temp % divisor) as u32; + bits[2] = (temp / divisor) as u32; + let temp = ((remainder as u64) << 32) + (bits[1] as u64); + remainder = (temp % divisor) as u32; + bits[1] = (temp / divisor) as u32; + let temp = ((remainder as u64) << 32) + (bits[0] as u64); + remainder = (temp % divisor) as u32; + bits[0] = (temp / divisor) as u32; + remainder +} + +#[inline] +pub(crate) fn shl1_internal(bits: &mut [u32], carry: u32) -> u32 { + let mut carry = carry; + for part in bits.iter_mut() { + let b = *part >> 31; + *part = (*part << 1) | carry; + carry = b; + } + carry +} + +#[inline] +pub(crate) fn cmp_internal(left: &[u32; 3], right: &[u32; 3]) -> core::cmp::Ordering { + let left_hi: u32 = left[2]; + let right_hi: u32 = right[2]; + let left_lo: u64 = u64::from(left[1]) << 32 | u64::from(left[0]); + let right_lo: u64 = u64::from(right[1]) << 32 | u64::from(right[0]); + if left_hi < right_hi || (left_hi <= right_hi && left_lo < right_lo) { + core::cmp::Ordering::Less + } else if left_hi == right_hi && left_lo == right_lo { + core::cmp::Ordering::Equal + } else { + core::cmp::Ordering::Greater + } +} + +#[inline] +pub(crate) fn is_all_zero<const N: usize>(bits: &[u32; N]) -> bool { + bits.iter().all(|b| *b == 0) +} + +#[cfg(test)] +mod test { + // Tests on private methods. + // + // All public tests should go under `tests/`. + + use super::*; + use crate::prelude::*; + + #[test] + fn it_can_rescale_internal() { + fn extract(value: &str) -> ([u32; 3], u32) { + let v = Decimal::from_str(value).unwrap(); + (v.mantissa_array3(), v.scale()) + } + + let tests = &[ + ("1", 0, "1", 0), + ("1", 1, "1.0", 1), + ("1", 5, "1.00000", 5), + ("1", 10, "1.0000000000", 10), + ("1", 20, "1.00000000000000000000", 20), + ( + "0.6386554621848739495798319328", + 27, + "0.638655462184873949579831933", + 27, + ), + ( + "843.65000000", // Scale 8 + 25, + "843.6500000000000000000000000", + 25, + ), + ( + "843.65000000", // Scale 8 + 30, + "843.6500000000000000000000000", + 25, // Only fits 25 + ), + ("0", 130, "0.000000000000000000000000000000", 28), + ]; + + for &(value_raw, new_scale, expected_value, expected_scale) in tests { + let (expected_value, _) = extract(expected_value); + let (mut value, mut value_scale) = extract(value_raw); + rescale_internal(&mut value, &mut value_scale, new_scale); + assert_eq!(value, expected_value); + assert_eq!( + value_scale, expected_scale, + "value: {}, requested scale: {}", + value_raw, new_scale + ); + } + } + + #[test] + fn test_shl1_internal() { + struct TestCase { + // One thing to be cautious of is that the structure of a number here for shifting left is + // the reverse of how you may conceive this mentally. i.e. a[2] contains the higher order + // bits: a[2] a[1] a[0] + given: [u32; 3], + given_carry: u32, + expected: [u32; 3], + expected_carry: u32, + } + let tests = [ + TestCase { + given: [1, 0, 0], + given_carry: 0, + expected: [2, 0, 0], + expected_carry: 0, + }, + TestCase { + given: [1, 0, 2147483648], + given_carry: 1, + expected: [3, 0, 0], + expected_carry: 1, + }, + ]; + for case in &tests { + let mut test = [case.given[0], case.given[1], case.given[2]]; + let carry = shl1_internal(&mut test, case.given_carry); + assert_eq!( + test, case.expected, + "Bits: {:?} << 1 | {}", + case.given, case.given_carry + ); + assert_eq!( + carry, case.expected_carry, + "Carry: {:?} << 1 | {}", + case.given, case.given_carry + ) + } + } +} diff --git a/third_party/rust/rust_decimal/src/ops/cmp.rs b/third_party/rust/rust_decimal/src/ops/cmp.rs new file mode 100644 index 0000000000..636085bff5 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/cmp.rs @@ -0,0 +1,101 @@ +use crate::constants::{MAX_I32_SCALE, POWERS_10, U32_MASK, U32_MAX}; +use crate::decimal::Decimal; +use crate::ops::common::Dec64; + +use core::cmp::Ordering; + +pub(crate) fn cmp_impl(d1: &Decimal, d2: &Decimal) -> Ordering { + if d2.is_zero() { + return if d1.is_zero() { + return Ordering::Equal; + } else if d1.is_sign_negative() { + Ordering::Less + } else { + Ordering::Greater + }; + } + if d1.is_zero() { + return if d2.is_sign_negative() { + Ordering::Greater + } else { + Ordering::Less + }; + } + // If the sign is different, then it's an easy answer + if d1.is_sign_negative() != d2.is_sign_negative() { + return if d1.is_sign_negative() { + Ordering::Less + } else { + Ordering::Greater + }; + } + + // Otherwise, do a deep comparison + let d1 = Dec64::new(d1); + let d2 = Dec64::new(d2); + // We know both signs are the same here so flip it here. + // Negative is handled differently. i.e. 0.5 > 0.01 however -0.5 < -0.01 + if d1.negative { + cmp_internal(&d2, &d1) + } else { + cmp_internal(&d1, &d2) + } +} + +pub(in crate::ops) fn cmp_internal(d1: &Dec64, d2: &Dec64) -> Ordering { + // This function ignores sign + let mut d1_low = d1.low64; + let mut d1_high = d1.hi; + let mut d2_low = d2.low64; + let mut d2_high = d2.hi; + + // If the scale factors aren't equal then + if d1.scale != d2.scale { + let mut diff = d2.scale as i32 - d1.scale as i32; + if diff < 0 { + diff = -diff; + if !rescale(&mut d2_low, &mut d2_high, diff as u32) { + return Ordering::Less; + } + } else if !rescale(&mut d1_low, &mut d1_high, diff as u32) { + return Ordering::Greater; + } + } + + // They're the same scale, do a standard bitwise comparison + let hi_order = d1_high.cmp(&d2_high); + if hi_order != Ordering::Equal { + return hi_order; + } + d1_low.cmp(&d2_low) +} + +fn rescale(low64: &mut u64, high: &mut u32, diff: u32) -> bool { + let mut diff = diff as i32; + // We need to modify d1 by 10^diff to get it to the same scale as d2 + loop { + let power = if diff >= MAX_I32_SCALE { + POWERS_10[9] + } else { + POWERS_10[diff as usize] + } as u64; + let tmp_lo_32 = (*low64 & U32_MASK) * power; + let mut tmp = (*low64 >> 32) * power + (tmp_lo_32 >> 32); + *low64 = (tmp_lo_32 & U32_MASK) + (tmp << 32); + tmp >>= 32; + tmp = tmp.wrapping_add((*high as u64) * power); + // Indicates > 96 bits + if tmp > U32_MAX { + return false; + } + *high = tmp as u32; + + // Keep scaling if there is more to go + diff -= MAX_I32_SCALE; + if diff <= 0 { + break; + } + } + + true +} diff --git a/third_party/rust/rust_decimal/src/ops/common.rs b/third_party/rust/rust_decimal/src/ops/common.rs new file mode 100644 index 0000000000..c29362d824 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/common.rs @@ -0,0 +1,455 @@ +use crate::constants::{MAX_I32_SCALE, MAX_PRECISION_I32, POWERS_10}; +use crate::Decimal; + +#[derive(Debug)] +pub struct Buf12 { + pub data: [u32; 3], +} + +impl Buf12 { + pub(super) const fn from_dec64(value: &Dec64) -> Self { + Buf12 { + data: [value.low64 as u32, (value.low64 >> 32) as u32, value.hi], + } + } + + pub(super) const fn from_decimal(value: &Decimal) -> Self { + Buf12 { + data: value.mantissa_array3(), + } + } + + #[inline(always)] + pub const fn lo(&self) -> u32 { + self.data[0] + } + #[inline(always)] + pub const fn mid(&self) -> u32 { + self.data[1] + } + #[inline(always)] + pub const fn hi(&self) -> u32 { + self.data[2] + } + #[inline(always)] + pub fn set_lo(&mut self, value: u32) { + self.data[0] = value; + } + #[inline(always)] + pub fn set_mid(&mut self, value: u32) { + self.data[1] = value; + } + #[inline(always)] + pub fn set_hi(&mut self, value: u32) { + self.data[2] = value; + } + + #[inline(always)] + pub const fn low64(&self) -> u64 { + ((self.data[1] as u64) << 32) | (self.data[0] as u64) + } + + #[inline(always)] + pub fn set_low64(&mut self, value: u64) { + self.data[1] = (value >> 32) as u32; + self.data[0] = value as u32; + } + + #[inline(always)] + pub const fn high64(&self) -> u64 { + ((self.data[2] as u64) << 32) | (self.data[1] as u64) + } + + #[inline(always)] + pub fn set_high64(&mut self, value: u64) { + self.data[2] = (value >> 32) as u32; + self.data[1] = value as u32; + } + + // Determine the maximum value of x that ensures that the quotient when scaled up by 10^x + // still fits in 96 bits. Ultimately, we want to make scale positive - if we can't then + // we're going to overflow. Because x is ultimately used to lookup inside the POWERS array, it + // must be a valid value 0 <= x <= 9 + pub fn find_scale(&self, scale: i32) -> Option<usize> { + const OVERFLOW_MAX_9_HI: u32 = 4; + const OVERFLOW_MAX_8_HI: u32 = 42; + const OVERFLOW_MAX_7_HI: u32 = 429; + const OVERFLOW_MAX_6_HI: u32 = 4294; + const OVERFLOW_MAX_5_HI: u32 = 42949; + const OVERFLOW_MAX_4_HI: u32 = 429496; + const OVERFLOW_MAX_3_HI: u32 = 4294967; + const OVERFLOW_MAX_2_HI: u32 = 42949672; + const OVERFLOW_MAX_1_HI: u32 = 429496729; + const OVERFLOW_MAX_9_LOW64: u64 = 5441186219426131129; + + let hi = self.data[2]; + let low64 = self.low64(); + let mut x = 0usize; + + // Quick check to stop us from trying to scale any more. + // + if hi > OVERFLOW_MAX_1_HI { + // If it's less than 0, which it probably is - overflow. We can't do anything. + if scale < 0 { + return None; + } + return Some(x); + } + + if scale > MAX_PRECISION_I32 - 9 { + // We can't scale by 10^9 without exceeding the max scale factor. + // Instead, we'll try to scale by the most that we can and see if that works. + // This is safe to do due to the check above. e.g. scale > 19 in the above, so it will + // evaluate to 9 or less below. + x = (MAX_PRECISION_I32 - scale) as usize; + if hi < POWER_OVERFLOW_VALUES[x - 1].data[2] { + if x as i32 + scale < 0 { + // We still overflow + return None; + } + return Some(x); + } + } else if hi < OVERFLOW_MAX_9_HI || hi == OVERFLOW_MAX_9_HI && low64 <= OVERFLOW_MAX_9_LOW64 { + return Some(9); + } + + // Do a binary search to find a power to scale by that is less than 9 + x = if hi > OVERFLOW_MAX_5_HI { + if hi > OVERFLOW_MAX_3_HI { + if hi > OVERFLOW_MAX_2_HI { + 1 + } else { + 2 + } + } else if hi > OVERFLOW_MAX_4_HI { + 3 + } else { + 4 + } + } else if hi > OVERFLOW_MAX_7_HI { + if hi > OVERFLOW_MAX_6_HI { + 5 + } else { + 6 + } + } else if hi > OVERFLOW_MAX_8_HI { + 7 + } else { + 8 + }; + + // Double check what we've found won't overflow. Otherwise, we go one below. + if hi == POWER_OVERFLOW_VALUES[x - 1].data[2] && low64 > POWER_OVERFLOW_VALUES[x - 1].low64() { + x -= 1; + } + + // Confirm we've actually resolved things + if x as i32 + scale < 0 { + None + } else { + Some(x) + } + } +} + +// This is a table of the largest values that will not overflow when multiplied +// by a given power as represented by the index. +static POWER_OVERFLOW_VALUES: [Buf12; 8] = [ + Buf12 { + data: [2576980377, 2576980377, 429496729], + }, + Buf12 { + data: [687194767, 4123168604, 42949672], + }, + Buf12 { + data: [2645699854, 1271310319, 4294967], + }, + Buf12 { + data: [694066715, 3133608139, 429496], + }, + Buf12 { + data: [2216890319, 2890341191, 42949], + }, + Buf12 { + data: [2369172679, 4154504685, 4294], + }, + Buf12 { + data: [4102387834, 2133437386, 429], + }, + Buf12 { + data: [410238783, 4078814305, 42], + }, +]; + +pub(super) struct Dec64 { + pub negative: bool, + pub scale: u32, + pub hi: u32, + pub low64: u64, +} + +impl Dec64 { + pub(super) const fn new(d: &Decimal) -> Dec64 { + let m = d.mantissa_array3(); + if m[1] == 0 { + Dec64 { + negative: d.is_sign_negative(), + scale: d.scale(), + hi: m[2], + low64: m[0] as u64, + } + } else { + Dec64 { + negative: d.is_sign_negative(), + scale: d.scale(), + hi: m[2], + low64: ((m[1] as u64) << 32) | (m[0] as u64), + } + } + } + + #[inline(always)] + pub(super) const fn lo(&self) -> u32 { + self.low64 as u32 + } + #[inline(always)] + pub(super) const fn mid(&self) -> u32 { + (self.low64 >> 32) as u32 + } + + #[inline(always)] + pub(super) const fn high64(&self) -> u64 { + (self.low64 >> 32) | ((self.hi as u64) << 32) + } + + pub(super) const fn to_decimal(&self) -> Decimal { + Decimal::from_parts( + self.low64 as u32, + (self.low64 >> 32) as u32, + self.hi, + self.negative, + self.scale, + ) + } +} + +pub struct Buf16 { + pub data: [u32; 4], +} + +impl Buf16 { + pub const fn zero() -> Self { + Buf16 { data: [0, 0, 0, 0] } + } + + pub const fn low64(&self) -> u64 { + ((self.data[1] as u64) << 32) | (self.data[0] as u64) + } + + pub fn set_low64(&mut self, value: u64) { + self.data[1] = (value >> 32) as u32; + self.data[0] = value as u32; + } + + pub const fn mid64(&self) -> u64 { + ((self.data[2] as u64) << 32) | (self.data[1] as u64) + } + + pub fn set_mid64(&mut self, value: u64) { + self.data[2] = (value >> 32) as u32; + self.data[1] = value as u32; + } + + pub const fn high64(&self) -> u64 { + ((self.data[3] as u64) << 32) | (self.data[2] as u64) + } + + pub fn set_high64(&mut self, value: u64) { + self.data[3] = (value >> 32) as u32; + self.data[2] = value as u32; + } +} + +#[derive(Debug)] +pub struct Buf24 { + pub data: [u32; 6], +} + +impl Buf24 { + pub const fn zero() -> Self { + Buf24 { + data: [0, 0, 0, 0, 0, 0], + } + } + + pub const fn low64(&self) -> u64 { + ((self.data[1] as u64) << 32) | (self.data[0] as u64) + } + + pub fn set_low64(&mut self, value: u64) { + self.data[1] = (value >> 32) as u32; + self.data[0] = value as u32; + } + + #[allow(dead_code)] + pub const fn mid64(&self) -> u64 { + ((self.data[3] as u64) << 32) | (self.data[2] as u64) + } + + pub fn set_mid64(&mut self, value: u64) { + self.data[3] = (value >> 32) as u32; + self.data[2] = value as u32; + } + + #[allow(dead_code)] + pub const fn high64(&self) -> u64 { + ((self.data[5] as u64) << 32) | (self.data[4] as u64) + } + + pub fn set_high64(&mut self, value: u64) { + self.data[5] = (value >> 32) as u32; + self.data[4] = value as u32; + } + + pub const fn upper_word(&self) -> usize { + if self.data[5] > 0 { + return 5; + } + if self.data[4] > 0 { + return 4; + } + if self.data[3] > 0 { + return 3; + } + if self.data[2] > 0 { + return 2; + } + if self.data[1] > 0 { + return 1; + } + 0 + } + + // Attempt to rescale the number into 96 bits. If successful, the scale is returned wrapped + // in an Option. If it failed due to overflow, we return None. + // * `upper` - Index of last non-zero value in self. + // * `scale` - Current scale factor for this value. + pub fn rescale(&mut self, upper: usize, scale: u32) -> Option<u32> { + let mut scale = scale as i32; + let mut upper = upper; + + // Determine a rescale target to start with + let mut rescale_target = 0i32; + if upper > 2 { + rescale_target = upper as i32 * 32 - 64 - 1; + rescale_target -= self.data[upper].leading_zeros() as i32; + rescale_target = ((rescale_target * 77) >> 8) + 1; + if rescale_target > scale { + return None; + } + } + + // Make sure we scale enough to bring it into a valid range + if rescale_target < scale - MAX_PRECISION_I32 { + rescale_target = scale - MAX_PRECISION_I32; + } + + if rescale_target > 0 { + // We're going to keep reducing by powers of 10. So, start by reducing the scale by + // that amount. + scale -= rescale_target; + let mut sticky = 0; + let mut remainder = 0; + loop { + sticky |= remainder; + let mut power = if rescale_target > 8 { + POWERS_10[9] + } else { + POWERS_10[rescale_target as usize] + }; + + let high = self.data[upper]; + let high_quotient = high / power; + remainder = high - high_quotient * power; + + for item in self.data.iter_mut().rev().skip(6 - upper) { + let num = (*item as u64).wrapping_add((remainder as u64) << 32); + *item = (num / power as u64) as u32; + remainder = (num as u32).wrapping_sub(item.wrapping_mul(power)); + } + + self.data[upper] = high_quotient; + + // If the high quotient was zero then decrease the upper bound + if high_quotient == 0 && upper > 0 { + upper -= 1; + } + if rescale_target > MAX_I32_SCALE { + // Scale some more + rescale_target -= MAX_I32_SCALE; + continue; + } + + // If we fit into 96 bits then we've scaled enough. Otherwise, scale once more. + if upper > 2 { + if scale == 0 { + return None; + } + // Equivalent to scaling down by 10 + rescale_target = 1; + scale -= 1; + continue; + } + + // Round the final result. + power >>= 1; + let carried = if power <= remainder { + // If we're less than half then we're fine. Otherwise, we round if odd or if the + // sticky bit is set. + if power < remainder || ((self.data[0] & 1) | sticky) != 0 { + // Round up + self.data[0] = self.data[0].wrapping_add(1); + // Check if we carried + self.data[0] == 0 + } else { + false + } + } else { + false + }; + + // If we carried then propagate through the portions + if carried { + let mut pos = 0; + for (index, value) in self.data.iter_mut().enumerate().skip(1) { + pos = index; + *value = value.wrapping_add(1); + if *value != 0 { + break; + } + } + + // If we ended up rounding over the 96 bits then we'll try to rescale down (again) + if pos > 2 { + // Nothing to scale down from will cause overflow + if scale == 0 { + return None; + } + + // Loop back around using scale of 10. + // Reset the sticky bit and remainder before looping. + upper = pos; + sticky = 0; + remainder = 0; + rescale_target = 1; + scale -= 1; + continue; + } + } + break; + } + } + + Some(scale as u32) + } +} diff --git a/third_party/rust/rust_decimal/src/ops/div.rs b/third_party/rust/rust_decimal/src/ops/div.rs new file mode 100644 index 0000000000..3b5ec577b2 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/div.rs @@ -0,0 +1,658 @@ +use crate::constants::{MAX_PRECISION_I32, POWERS_10}; +use crate::decimal::{CalculationResult, Decimal}; +use crate::ops::common::{Buf12, Buf16, Dec64}; + +use core::cmp::Ordering; +use core::ops::BitXor; + +impl Buf12 { + // Returns true if successful, else false for an overflow + fn add32(&mut self, value: u32) -> Result<(), DivError> { + let value = value as u64; + let new = self.low64().wrapping_add(value); + self.set_low64(new); + if new < value { + self.data[2] = self.data[2].wrapping_add(1); + if self.data[2] == 0 { + return Err(DivError::Overflow); + } + } + Ok(()) + } + + // Divide a Decimal union by a 32 bit divisor. + // Self is overwritten with the quotient. + // Return value is a 32 bit remainder. + fn div32(&mut self, divisor: u32) -> u32 { + let divisor64 = divisor as u64; + // See if we can get by using a simple u64 division + if self.data[2] != 0 { + let mut temp = self.high64(); + let q64 = temp / divisor64; + self.set_high64(q64); + + // Calculate the "remainder" + temp = ((temp - q64 * divisor64) << 32) | (self.data[0] as u64); + if temp == 0 { + return 0; + } + let q32 = (temp / divisor64) as u32; + self.data[0] = q32; + ((temp as u32).wrapping_sub(q32.wrapping_mul(divisor))) as u32 + } else { + // Super easy divisor + let low64 = self.low64(); + if low64 == 0 { + // Nothing to do + return 0; + } + // Do the calc + let quotient = low64 / divisor64; + self.set_low64(quotient); + // Remainder is the leftover that wasn't used + (low64.wrapping_sub(quotient.wrapping_mul(divisor64))) as u32 + } + } + + // Divide the number by a power constant + // Returns true if division was successful + fn div32_const(&mut self, pow: u32) -> bool { + let pow64 = pow as u64; + let high64 = self.high64(); + let lo = self.data[0] as u64; + let div64: u64 = high64 / pow64; + let div = ((((high64 - div64 * pow64) << 32) + lo) / pow64) as u32; + if self.data[0] == div.wrapping_mul(pow) { + self.set_high64(div64); + self.data[0] = div; + true + } else { + false + } + } +} + +impl Buf16 { + // Does a partial divide with a 64 bit divisor. The divisor in this case must be 64 bits + // otherwise various assumptions fail (e.g. 32 bit quotient). + // To assist, the upper 64 bits must be greater than the divisor for this to succeed. + // Consequently, it will return the quotient as a 32 bit number and overwrite self with the + // 64 bit remainder. + pub(super) fn partial_divide_64(&mut self, divisor: u64) -> u32 { + // We make this assertion here, however below we pivot based on the data + debug_assert!(divisor > self.mid64()); + + // If we have an empty high bit, then divisor must be greater than the dividend due to + // the assumption that the divisor REQUIRES 64 bits. + if self.data[2] == 0 { + let low64 = self.low64(); + if low64 < divisor { + // We can't divide at at all so result is 0. The dividend remains untouched since + // the full amount is the remainder. + return 0; + } + + let quotient = low64 / divisor; + self.set_low64(low64 - (quotient * divisor)); + return quotient as u32; + } + + // Do a simple check to see if the hi portion of the dividend is greater than the hi + // portion of the divisor. + let divisor_hi32 = (divisor >> 32) as u32; + if self.data[2] >= divisor_hi32 { + // We know that the divisor goes into this at MOST u32::max times. + // So we kick things off, with that assumption + let mut low64 = self.low64(); + low64 = low64.wrapping_sub(divisor << 32).wrapping_add(divisor); + let mut quotient = u32::MAX; + + // If we went negative then keep adding it back in + loop { + if low64 < divisor { + break; + } + quotient = quotient.wrapping_sub(1); + low64 = low64.wrapping_add(divisor); + } + self.set_low64(low64); + return quotient; + } + + let mid64 = self.mid64(); + let divisor_hi32_64 = divisor_hi32 as u64; + if mid64 < divisor_hi32_64 as u64 { + // similar situation as above where we've got nothing left to divide + return 0; + } + + let mut quotient = mid64 / divisor_hi32_64; + let mut remainder = self.data[0] as u64 | ((mid64 - quotient * divisor_hi32_64) << 32); + + // Do quotient * lo divisor + let product = quotient * (divisor & 0xFFFF_FFFF); + remainder = remainder.wrapping_sub(product); + + // Check if we've gone negative. If so, add it back + if remainder > product.bitxor(u64::MAX) { + loop { + quotient = quotient.wrapping_sub(1); + remainder = remainder.wrapping_add(divisor); + if remainder < divisor { + break; + } + } + } + + self.set_low64(remainder); + quotient as u32 + } + + // Does a partial divide with a 96 bit divisor. The divisor in this case must require 96 bits + // otherwise various assumptions fail (e.g. 32 bit quotient). + pub(super) fn partial_divide_96(&mut self, divisor: &Buf12) -> u32 { + let dividend = self.high64(); + let divisor_hi = divisor.data[2]; + if dividend < divisor_hi as u64 { + // Dividend is too small - entire number is remainder + return 0; + } + + let mut quo = (dividend / divisor_hi as u64) as u32; + let mut remainder = (dividend as u32).wrapping_sub(quo.wrapping_mul(divisor_hi)); + + // Compute full remainder + let mut prod1 = quo as u64 * divisor.data[0] as u64; + let mut prod2 = quo as u64 * divisor.data[1] as u64; + prod2 += prod1 >> 32; + prod1 = (prod1 & 0xFFFF_FFFF) | (prod2 << 32); + prod2 >>= 32; + + let mut num = self.low64(); + num = num.wrapping_sub(prod1); + remainder = remainder.wrapping_sub(prod2 as u32); + + // If there are carries make sure they are propagated + if num > prod1.bitxor(u64::MAX) { + remainder = remainder.wrapping_sub(1); + if remainder < (prod2 as u32).bitxor(u32::MAX) { + self.set_low64(num); + self.data[2] = remainder; + return quo; + } + } else if remainder <= (prod2 as u32).bitxor(u32::MAX) { + self.set_low64(num); + self.data[2] = remainder; + return quo; + } + + // Remainder went negative, add divisor back until it's positive + prod1 = divisor.low64(); + loop { + quo = quo.wrapping_sub(1); + num = num.wrapping_add(prod1); + remainder = remainder.wrapping_add(divisor_hi); + + if num < prod1 { + // Detected carry. + let tmp = remainder; + remainder = remainder.wrapping_add(1); + if tmp < divisor_hi { + break; + } + } + if remainder < divisor_hi { + break; // detected carry + } + } + + self.set_low64(num); + self.data[2] = remainder; + quo + } +} + +enum DivError { + Overflow, +} + +pub(crate) fn div_impl(dividend: &Decimal, divisor: &Decimal) -> CalculationResult { + if divisor.is_zero() { + return CalculationResult::DivByZero; + } + if dividend.is_zero() { + return CalculationResult::Ok(Decimal::ZERO); + } + let dividend = Dec64::new(dividend); + let divisor = Dec64::new(divisor); + + // Pre calculate the scale and the sign + let mut scale = (dividend.scale as i32) - (divisor.scale as i32); + let sign_negative = dividend.negative ^ divisor.negative; + + // Set up some variables for modification throughout + let mut require_unscale = false; + let mut quotient = Buf12::from_dec64(÷nd); + let divisor = Buf12::from_dec64(&divisor); + + // Branch depending on the complexity of the divisor + if divisor.data[2] | divisor.data[1] == 0 { + // We have a simple(r) divisor (32 bit) + let divisor32 = divisor.data[0]; + + // Remainder can only be 32 bits since the divisor is 32 bits. + let mut remainder = quotient.div32(divisor32); + let mut power_scale = 0; + + // Figure out how to apply the remainder (i.e. we may have performed something like 10/3 or 8/5) + loop { + // Remainder is 0 so we have a simple situation + if remainder == 0 { + // If the scale is positive then we're actually done + if scale >= 0 { + break; + } + power_scale = 9usize.min((-scale) as usize); + } else { + // We may need to normalize later, so set the flag appropriately + require_unscale = true; + + // We have a remainder so we effectively want to try to adjust the quotient and add + // the remainder into the quotient. We do this below, however first of all we want + // to try to avoid overflowing so we do that check first. + let will_overflow = if scale == MAX_PRECISION_I32 { + true + } else { + // Figure out how much we can scale by + if let Some(s) = quotient.find_scale(scale) { + power_scale = s; + } else { + return CalculationResult::Overflow; + } + // If it comes back as 0 (i.e. 10^0 = 1) then we're going to overflow since + // we're doing nothing. + power_scale == 0 + }; + if will_overflow { + // No more scaling can be done, but remainder is non-zero so we round if necessary. + let tmp = remainder << 1; + let round = if tmp < remainder { + // We round if we wrapped around + true + } else if tmp >= divisor32 { + // If we're greater than the divisor (i.e. underflow) + // or if there is a lo bit set, we round + tmp > divisor32 || (quotient.data[0] & 0x1) > 0 + } else { + false + }; + + // If we need to round, try to do so. + if round { + if let Ok(new_scale) = round_up(&mut quotient, scale) { + scale = new_scale; + } else { + // Overflowed + return CalculationResult::Overflow; + } + } + break; + } + } + + // Do some scaling + let power = POWERS_10[power_scale]; + scale += power_scale as i32; + // Increase the quotient by the power that was looked up + let overflow = increase_scale(&mut quotient, power as u64); + if overflow > 0 { + return CalculationResult::Overflow; + } + + let remainder_scaled = (remainder as u64) * (power as u64); + let remainder_quotient = (remainder_scaled / (divisor32 as u64)) as u32; + remainder = (remainder_scaled - remainder_quotient as u64 * divisor32 as u64) as u32; + if let Err(DivError::Overflow) = quotient.add32(remainder_quotient) { + if let Ok(adj) = unscale_from_overflow(&mut quotient, scale, remainder != 0) { + scale = adj; + } else { + // Still overflowing + return CalculationResult::Overflow; + } + break; + } + } + } else { + // We have a divisor greater than 32 bits. Both of these share some quick calculation wins + // so we'll do those before branching into separate logic. + // The win we can do is shifting the bits to the left as much as possible. We do this to both + // the dividend and the divisor to ensure the quotient is not changed. + // As a simple contrived example: if we have 4 / 2 then we could bit shift all the way to the + // left meaning that the lo portion would have nothing inside of it. Of course, shifting these + // left one has the same result (8/4) etc. + // The advantage is that we may be able to write off lower portions of the number making things + // easier. + let mut power_scale = if divisor.data[2] == 0 { + divisor.data[1].leading_zeros() + } else { + divisor.data[2].leading_zeros() + } as usize; + let mut remainder = Buf16::zero(); + remainder.set_low64(quotient.low64() << power_scale); + let tmp_high = ((quotient.data[1] as u64) + ((quotient.data[2] as u64) << 32)) >> (32 - power_scale); + remainder.set_high64(tmp_high); + + // Work out the divisor after it's shifted + let divisor64 = divisor.low64() << power_scale; + // Check if the divisor is 64 bit or the full 96 bits + if divisor.data[2] == 0 { + // It's 64 bits + quotient.data[2] = 0; + + // Calc mid/lo by shifting accordingly + let rem_lo = remainder.data[0]; + remainder.data[0] = remainder.data[1]; + remainder.data[1] = remainder.data[2]; + remainder.data[2] = remainder.data[3]; + quotient.data[1] = remainder.partial_divide_64(divisor64); + + remainder.data[2] = remainder.data[1]; + remainder.data[1] = remainder.data[0]; + remainder.data[0] = rem_lo; + quotient.data[0] = remainder.partial_divide_64(divisor64); + + loop { + let rem_low64 = remainder.low64(); + if rem_low64 == 0 { + // If the scale is positive then we're actually done + if scale >= 0 { + break; + } + power_scale = 9usize.min((-scale) as usize); + } else { + // We may need to normalize later, so set the flag appropriately + require_unscale = true; + + // We have a remainder so we effectively want to try to adjust the quotient and add + // the remainder into the quotient. We do this below, however first of all we want + // to try to avoid overflowing so we do that check first. + let will_overflow = if scale == MAX_PRECISION_I32 { + true + } else { + // Figure out how much we can scale by + if let Some(s) = quotient.find_scale(scale) { + power_scale = s; + } else { + return CalculationResult::Overflow; + } + // If it comes back as 0 (i.e. 10^0 = 1) then we're going to overflow since + // we're doing nothing. + power_scale == 0 + }; + if will_overflow { + // No more scaling can be done, but remainder is non-zero so we round if necessary. + let mut tmp = remainder.low64(); + let round = if (tmp as i64) < 0 { + // We round if we wrapped around + true + } else { + tmp <<= 1; + if tmp > divisor64 { + true + } else { + tmp == divisor64 && quotient.data[0] & 0x1 != 0 + } + }; + + // If we need to round, try to do so. + if round { + if let Ok(new_scale) = round_up(&mut quotient, scale) { + scale = new_scale; + } else { + // Overflowed + return CalculationResult::Overflow; + } + } + break; + } + } + + // Do some scaling + let power = POWERS_10[power_scale]; + scale += power_scale as i32; + + // Increase the quotient by the power that was looked up + let overflow = increase_scale(&mut quotient, power as u64); + if overflow > 0 { + return CalculationResult::Overflow; + } + increase_scale64(&mut remainder, power as u64); + + let tmp = remainder.partial_divide_64(divisor64); + if let Err(DivError::Overflow) = quotient.add32(tmp) { + if let Ok(adj) = unscale_from_overflow(&mut quotient, scale, remainder.low64() != 0) { + scale = adj; + } else { + // Still overflowing + return CalculationResult::Overflow; + } + break; + } + } + } else { + // It's 96 bits + // Start by finishing the shift left + let divisor_mid = divisor.data[1]; + let divisor_hi = divisor.data[2]; + let mut divisor = divisor; + divisor.set_low64(divisor64); + divisor.data[2] = ((divisor_mid as u64 + ((divisor_hi as u64) << 32)) >> (32 - power_scale)) as u32; + + let quo = remainder.partial_divide_96(&divisor); + quotient.set_low64(quo as u64); + quotient.data[2] = 0; + + loop { + let mut rem_low64 = remainder.low64(); + if rem_low64 == 0 && remainder.data[2] == 0 { + // If the scale is positive then we're actually done + if scale >= 0 { + break; + } + power_scale = 9usize.min((-scale) as usize); + } else { + // We may need to normalize later, so set the flag appropriately + require_unscale = true; + + // We have a remainder so we effectively want to try to adjust the quotient and add + // the remainder into the quotient. We do this below, however first of all we want + // to try to avoid overflowing so we do that check first. + let will_overflow = if scale == MAX_PRECISION_I32 { + true + } else { + // Figure out how much we can scale by + if let Some(s) = quotient.find_scale(scale) { + power_scale = s; + } else { + return CalculationResult::Overflow; + } + // If it comes back as 0 (i.e. 10^0 = 1) then we're going to overflow since + // we're doing nothing. + power_scale == 0 + }; + if will_overflow { + // No more scaling can be done, but remainder is non-zero so we round if necessary. + let round = if (remainder.data[2] as i32) < 0 { + // We round if we wrapped around + true + } else { + let tmp = remainder.data[1] >> 31; + rem_low64 <<= 1; + remainder.set_low64(rem_low64); + remainder.data[2] = (&remainder.data[2] << 1) + tmp; + + match remainder.data[2].cmp(&divisor.data[2]) { + Ordering::Less => false, + Ordering::Equal => { + let divisor_low64 = divisor.low64(); + if rem_low64 > divisor_low64 { + true + } else { + rem_low64 == divisor_low64 && (quotient.data[0] & 1) != 0 + } + } + Ordering::Greater => true, + } + }; + + // If we need to round, try to do so. + if round { + if let Ok(new_scale) = round_up(&mut quotient, scale) { + scale = new_scale; + } else { + // Overflowed + return CalculationResult::Overflow; + } + } + break; + } + } + + // Do some scaling + let power = POWERS_10[power_scale]; + scale += power_scale as i32; + + // Increase the quotient by the power that was looked up + let overflow = increase_scale(&mut quotient, power as u64); + if overflow > 0 { + return CalculationResult::Overflow; + } + let mut tmp_remainder = Buf12 { + data: [remainder.data[0], remainder.data[1], remainder.data[2]], + }; + let overflow = increase_scale(&mut tmp_remainder, power as u64); + remainder.data[0] = tmp_remainder.data[0]; + remainder.data[1] = tmp_remainder.data[1]; + remainder.data[2] = tmp_remainder.data[2]; + remainder.data[3] = overflow; + + let tmp = remainder.partial_divide_96(&divisor); + if let Err(DivError::Overflow) = quotient.add32(tmp) { + if let Ok(adj) = + unscale_from_overflow(&mut quotient, scale, (remainder.low64() | remainder.high64()) != 0) + { + scale = adj; + } else { + // Still overflowing + return CalculationResult::Overflow; + } + break; + } + } + } + } + if require_unscale { + scale = unscale(&mut quotient, scale); + } + CalculationResult::Ok(Decimal::from_parts( + quotient.data[0], + quotient.data[1], + quotient.data[2], + sign_negative, + scale as u32, + )) +} + +// Multiply num by power (multiple of 10). Power must be 32 bits. +// Returns the overflow, if any +fn increase_scale(num: &mut Buf12, power: u64) -> u32 { + let mut tmp = (num.data[0] as u64) * power; + num.data[0] = tmp as u32; + tmp >>= 32; + tmp += (num.data[1] as u64) * power; + num.data[1] = tmp as u32; + tmp >>= 32; + tmp += (num.data[2] as u64) * power; + num.data[2] = tmp as u32; + (tmp >> 32) as u32 +} + +// Multiply num by power (multiple of 10). Power must be 32 bits. +fn increase_scale64(num: &mut Buf16, power: u64) { + let mut tmp = (num.data[0] as u64) * power; + num.data[0] = tmp as u32; + tmp >>= 32; + tmp += (num.data[1] as u64) * power; + num.set_mid64(tmp) +} + +// Adjust the number to deal with an overflow. This function follows being scaled up (i.e. multiplied +// by 10, so this effectively tries to reverse that by dividing by 10 then feeding in the high bit +// to undo the overflow and rounding instead. +// Returns the updated scale. +fn unscale_from_overflow(num: &mut Buf12, scale: i32, sticky: bool) -> Result<i32, DivError> { + let scale = scale - 1; + if scale < 0 { + return Err(DivError::Overflow); + } + + // This function is called when the hi portion has "overflowed" upon adding one and has wrapped + // back around to 0. Consequently, we need to "feed" that back in, but also rescaling down + // to reverse out the overflow. + const HIGH_BIT: u64 = 0x1_0000_0000; + num.data[2] = (HIGH_BIT / 10) as u32; + + // Calc the mid + let mut tmp = ((HIGH_BIT % 10) << 32) + (num.data[1] as u64); + let mut val = (tmp / 10) as u32; + num.data[1] = val; + + // Calc the lo using a similar method + tmp = ((tmp - (val as u64) * 10) << 32) + (num.data[0] as u64); + val = (tmp / 10) as u32; + num.data[0] = val; + + // Work out the remainder, and round if we have one (since it doesn't fit) + let remainder = (tmp - (val as u64) * 10) as u32; + if remainder > 5 || (remainder == 5 && (sticky || num.data[0] & 0x1 > 0)) { + let _ = num.add32(1); + } + Ok(scale) +} + +#[inline] +fn round_up(num: &mut Buf12, scale: i32) -> Result<i32, DivError> { + let low64 = num.low64().wrapping_add(1); + num.set_low64(low64); + if low64 != 0 { + return Ok(scale); + } + let hi = num.data[2].wrapping_add(1); + num.data[2] = hi; + if hi != 0 { + return Ok(scale); + } + unscale_from_overflow(num, scale, true) +} + +fn unscale(num: &mut Buf12, scale: i32) -> i32 { + // Since 10 = 2 * 5, there must be a factor of 2 for every power of 10 we can extract. + // We use this as a quick test on whether to try a given power. + let mut scale = scale; + while num.data[0] == 0 && scale >= 8 && num.div32_const(100000000) { + scale -= 8; + } + + if (num.data[0] & 0xF) == 0 && scale >= 4 && num.div32_const(10000) { + scale -= 4; + } + + if (num.data[0] & 0x3) == 0 && scale >= 2 && num.div32_const(100) { + scale -= 2; + } + + if (num.data[0] & 0x1) == 0 && scale >= 1 && num.div32_const(10) { + scale -= 1; + } + scale +} diff --git a/third_party/rust/rust_decimal/src/ops/legacy.rs b/third_party/rust/rust_decimal/src/ops/legacy.rs new file mode 100644 index 0000000000..49f39814f8 --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/legacy.rs @@ -0,0 +1,843 @@ +use crate::{ + constants::{MAX_PRECISION_U32, POWERS_10, U32_MASK}, + decimal::{CalculationResult, Decimal}, + ops::array::{ + add_by_internal, cmp_internal, div_by_u32, is_all_zero, mul_by_u32, mul_part, rescale_internal, shl1_internal, + }, +}; + +use core::cmp::Ordering; +use num_traits::Zero; + +pub(crate) fn add_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + // Convert to the same scale + let mut my = d1.mantissa_array3(); + let mut my_scale = d1.scale(); + let mut ot = d2.mantissa_array3(); + let mut other_scale = d2.scale(); + rescale_to_maximum_scale(&mut my, &mut my_scale, &mut ot, &mut other_scale); + let mut final_scale = my_scale.max(other_scale); + + // Add the items together + let my_negative = d1.is_sign_negative(); + let other_negative = d2.is_sign_negative(); + let mut negative = false; + let carry; + if !(my_negative ^ other_negative) { + negative = my_negative; + carry = add_by_internal3(&mut my, &ot); + } else { + let cmp = cmp_internal(&my, &ot); + // -x + y + // if x > y then it's negative (i.e. -2 + 1) + match cmp { + Ordering::Less => { + negative = other_negative; + sub_by_internal3(&mut ot, &my); + my[0] = ot[0]; + my[1] = ot[1]; + my[2] = ot[2]; + } + Ordering::Greater => { + negative = my_negative; + sub_by_internal3(&mut my, &ot); + } + Ordering::Equal => { + // -2 + 2 + my[0] = 0; + my[1] = 0; + my[2] = 0; + } + } + carry = 0; + } + + // If we have a carry we underflowed. + // We need to lose some significant digits (if possible) + if carry > 0 { + if final_scale == 0 { + return CalculationResult::Overflow; + } + + // Copy it over to a temp array for modification + let mut temp = [my[0], my[1], my[2], carry]; + while final_scale > 0 && temp[3] != 0 { + div_by_u32(&mut temp, 10); + final_scale -= 1; + } + + // If we still have a carry bit then we overflowed + if temp[3] > 0 { + return CalculationResult::Overflow; + } + + // Copy it back - we're done + my[0] = temp[0]; + my[1] = temp[1]; + my[2] = temp[2]; + } + + CalculationResult::Ok(Decimal::from_parts(my[0], my[1], my[2], negative, final_scale)) +} + +pub(crate) fn sub_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + add_impl(d1, &(-*d2)) +} + +pub(crate) fn div_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + if d2.is_zero() { + return CalculationResult::DivByZero; + } + if d1.is_zero() { + return CalculationResult::Ok(Decimal::zero()); + } + + let dividend = d1.mantissa_array3(); + let divisor = d2.mantissa_array3(); + let mut quotient = [0u32, 0u32, 0u32]; + let mut quotient_scale: i32 = d1.scale() as i32 - d2.scale() as i32; + + // We supply an extra overflow word for each of the dividend and the remainder + let mut working_quotient = [dividend[0], dividend[1], dividend[2], 0u32]; + let mut working_remainder = [0u32, 0u32, 0u32, 0u32]; + let mut working_scale = quotient_scale; + let mut remainder_scale = quotient_scale; + let mut underflow; + + loop { + div_internal(&mut working_quotient, &mut working_remainder, &divisor); + underflow = add_with_scale_internal( + &mut quotient, + &mut quotient_scale, + &mut working_quotient, + &mut working_scale, + ); + + // Multiply the remainder by 10 + let mut overflow = 0; + for part in working_remainder.iter_mut() { + let (lo, hi) = mul_part(*part, 10, overflow); + *part = lo; + overflow = hi; + } + // Copy temp remainder into the temp quotient section + working_quotient.copy_from_slice(&working_remainder); + + remainder_scale += 1; + working_scale = remainder_scale; + + if underflow || is_all_zero(&working_remainder) { + break; + } + } + + // If we have a really big number try to adjust the scale to 0 + while quotient_scale < 0 { + copy_array_diff_lengths(&mut working_quotient, "ient); + working_quotient[3] = 0; + working_remainder.iter_mut().for_each(|x| *x = 0); + + // Mul 10 + let mut overflow = 0; + for part in &mut working_quotient { + let (lo, hi) = mul_part(*part, 10, overflow); + *part = lo; + overflow = hi; + } + for part in &mut working_remainder { + let (lo, hi) = mul_part(*part, 10, overflow); + *part = lo; + overflow = hi; + } + if working_quotient[3] == 0 && is_all_zero(&working_remainder) { + quotient_scale += 1; + quotient[0] = working_quotient[0]; + quotient[1] = working_quotient[1]; + quotient[2] = working_quotient[2]; + } else { + // Overflow + return CalculationResult::Overflow; + } + } + + if quotient_scale > 255 { + quotient[0] = 0; + quotient[1] = 0; + quotient[2] = 0; + quotient_scale = 0; + } + + let mut quotient_negative = d1.is_sign_negative() ^ d2.is_sign_negative(); + + // Check for underflow + let mut final_scale: u32 = quotient_scale as u32; + if final_scale > MAX_PRECISION_U32 { + let mut remainder = 0; + + // Division underflowed. We must remove some significant digits over using + // an invalid scale. + while final_scale > MAX_PRECISION_U32 && !is_all_zero("ient) { + remainder = div_by_u32(&mut quotient, 10); + final_scale -= 1; + } + if final_scale > MAX_PRECISION_U32 { + // Result underflowed so set to zero + final_scale = 0; + quotient_negative = false; + } else if remainder >= 5 { + for part in &mut quotient { + if remainder == 0 { + break; + } + let digit: u64 = u64::from(*part) + 1; + remainder = if digit > 0xFFFF_FFFF { 1 } else { 0 }; + *part = (digit & 0xFFFF_FFFF) as u32; + } + } + } + + CalculationResult::Ok(Decimal::from_parts( + quotient[0], + quotient[1], + quotient[2], + quotient_negative, + final_scale, + )) +} + +pub(crate) fn mul_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + // Early exit if either is zero + if d1.is_zero() || d2.is_zero() { + return CalculationResult::Ok(Decimal::zero()); + } + + // We are only resulting in a negative if we have mismatched signs + let negative = d1.is_sign_negative() ^ d2.is_sign_negative(); + + // We get the scale of the result by adding the operands. This may be too big, however + // we'll correct later + let mut final_scale = d1.scale() + d2.scale(); + + // First of all, if ONLY the lo parts of both numbers is filled + // then we can simply do a standard 64 bit calculation. It's a minor + // optimization however prevents the need for long form multiplication + let my = d1.mantissa_array3(); + let ot = d2.mantissa_array3(); + if my[1] == 0 && my[2] == 0 && ot[1] == 0 && ot[2] == 0 { + // Simply multiplication + let mut u64_result = u64_to_array(u64::from(my[0]) * u64::from(ot[0])); + + // If we're above max precision then this is a very small number + if final_scale > MAX_PRECISION_U32 { + final_scale -= MAX_PRECISION_U32; + + // If the number is above 19 then this will equate to zero. + // This is because the max value in 64 bits is 1.84E19 + if final_scale > 19 { + return CalculationResult::Ok(Decimal::zero()); + } + + let mut rem_lo = 0; + let mut power; + if final_scale > 9 { + // Since 10^10 doesn't fit into u32, we divide by 10^10/4 + // and multiply the next divisor by 4. + rem_lo = div_by_u32(&mut u64_result, 2_500_000_000); + power = POWERS_10[final_scale as usize - 10] << 2; + } else { + power = POWERS_10[final_scale as usize]; + } + + // Divide fits in 32 bits + let rem_hi = div_by_u32(&mut u64_result, power); + + // Round the result. Since the divisor is a power of 10 + // we check to see if the remainder is >= 1/2 divisor + power >>= 1; + if rem_hi >= power && (rem_hi > power || (rem_lo | (u64_result[0] & 0x1)) != 0) { + u64_result[0] += 1; + } + + final_scale = MAX_PRECISION_U32; + } + return CalculationResult::Ok(Decimal::from_parts( + u64_result[0], + u64_result[1], + 0, + negative, + final_scale, + )); + } + + // We're using some of the high bits, so we essentially perform + // long form multiplication. We compute the 9 partial products + // into a 192 bit result array. + // + // [my-h][my-m][my-l] + // x [ot-h][ot-m][ot-l] + // -------------------------------------- + // 1. [r-hi][r-lo] my-l * ot-l [0, 0] + // 2. [r-hi][r-lo] my-l * ot-m [0, 1] + // 3. [r-hi][r-lo] my-m * ot-l [1, 0] + // 4. [r-hi][r-lo] my-m * ot-m [1, 1] + // 5. [r-hi][r-lo] my-l * ot-h [0, 2] + // 6. [r-hi][r-lo] my-h * ot-l [2, 0] + // 7. [r-hi][r-lo] my-m * ot-h [1, 2] + // 8. [r-hi][r-lo] my-h * ot-m [2, 1] + // 9.[r-hi][r-lo] my-h * ot-h [2, 2] + let mut product = [0u32, 0u32, 0u32, 0u32, 0u32, 0u32]; + + // We can perform a minor short circuit here. If the + // high portions are both 0 then we can skip portions 5-9 + let to = if my[2] == 0 && ot[2] == 0 { 2 } else { 3 }; + + for (my_index, my_item) in my.iter().enumerate().take(to) { + for (ot_index, ot_item) in ot.iter().enumerate().take(to) { + let (mut rlo, mut rhi) = mul_part(*my_item, *ot_item, 0); + + // Get the index for the lo portion of the product + for prod in product.iter_mut().skip(my_index + ot_index) { + let (res, overflow) = add_part(rlo, *prod); + *prod = res; + + // If we have something in rhi from before then promote that + if rhi > 0 { + // If we overflowed in the last add, add that with rhi + if overflow > 0 { + let (nlo, nhi) = add_part(rhi, overflow); + rlo = nlo; + rhi = nhi; + } else { + rlo = rhi; + rhi = 0; + } + } else if overflow > 0 { + rlo = overflow; + rhi = 0; + } else { + break; + } + + // If nothing to do next round then break out + if rlo == 0 { + break; + } + } + } + } + + // If our result has used up the high portion of the product + // then we either have an overflow or an underflow situation + // Overflow will occur if we can't scale it back, whereas underflow + // with kick in rounding + let mut remainder = 0; + while final_scale > 0 && (product[3] != 0 || product[4] != 0 || product[5] != 0) { + remainder = div_by_u32(&mut product, 10u32); + final_scale -= 1; + } + + // Round up the carry if we need to + if remainder >= 5 { + for part in product.iter_mut() { + if remainder == 0 { + break; + } + let digit: u64 = u64::from(*part) + 1; + remainder = if digit > 0xFFFF_FFFF { 1 } else { 0 }; + *part = (digit & 0xFFFF_FFFF) as u32; + } + } + + // If we're still above max precision then we'll try again to + // reduce precision - we may be dealing with a limit of "0" + if final_scale > MAX_PRECISION_U32 { + // We're in an underflow situation + // The easiest way to remove precision is to divide off the result + while final_scale > MAX_PRECISION_U32 && !is_all_zero(&product) { + div_by_u32(&mut product, 10); + final_scale -= 1; + } + // If we're still at limit then we can't represent any + // significant decimal digits and will return an integer only + // Can also be invoked while representing 0. + if final_scale > MAX_PRECISION_U32 { + final_scale = 0; + } + } else if !(product[3] == 0 && product[4] == 0 && product[5] == 0) { + // We're in an overflow situation - we're within our precision bounds + // but still have bits in overflow + return CalculationResult::Overflow; + } + + CalculationResult::Ok(Decimal::from_parts( + product[0], + product[1], + product[2], + negative, + final_scale, + )) +} + +pub(crate) fn rem_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + if d2.is_zero() { + return CalculationResult::DivByZero; + } + if d1.is_zero() { + return CalculationResult::Ok(Decimal::zero()); + } + + // Rescale so comparable + let initial_scale = d1.scale(); + let mut quotient = d1.mantissa_array3(); + let mut quotient_scale = initial_scale; + let mut divisor = d2.mantissa_array3(); + let mut divisor_scale = d2.scale(); + rescale_to_maximum_scale(&mut quotient, &mut quotient_scale, &mut divisor, &mut divisor_scale); + + // Working is the remainder + the quotient + // We use an aligned array since we'll be using it a lot. + let mut working_quotient = [quotient[0], quotient[1], quotient[2], 0u32]; + let mut working_remainder = [0u32, 0u32, 0u32, 0u32]; + div_internal(&mut working_quotient, &mut working_remainder, &divisor); + + // Round if necessary. This is for semantic correctness, but could feasibly be removed for + // performance improvements. + if quotient_scale > initial_scale { + let mut working = [ + working_remainder[0], + working_remainder[1], + working_remainder[2], + working_remainder[3], + ]; + while quotient_scale > initial_scale { + if div_by_u32(&mut working, 10) > 0 { + break; + } + quotient_scale -= 1; + working_remainder.copy_from_slice(&working); + } + } + + CalculationResult::Ok(Decimal::from_parts( + working_remainder[0], + working_remainder[1], + working_remainder[2], + d1.is_sign_negative(), + quotient_scale, + )) +} + +pub(crate) fn cmp_impl(d1: &Decimal, d2: &Decimal) -> Ordering { + // Quick exit if major differences + if d1.is_zero() && d2.is_zero() { + return Ordering::Equal; + } + let self_negative = d1.is_sign_negative(); + let other_negative = d2.is_sign_negative(); + if self_negative && !other_negative { + return Ordering::Less; + } else if !self_negative && other_negative { + return Ordering::Greater; + } + + // If we have 1.23 and 1.2345 then we have + // 123 scale 2 and 12345 scale 4 + // We need to convert the first to + // 12300 scale 4 so we can compare equally + let left: &Decimal; + let right: &Decimal; + if self_negative && other_negative { + // Both are negative, so reverse cmp + left = d2; + right = d1; + } else { + left = d1; + right = d2; + } + let mut left_scale = left.scale(); + let mut right_scale = right.scale(); + let mut left_raw = left.mantissa_array3(); + let mut right_raw = right.mantissa_array3(); + + if left_scale == right_scale { + // Fast path for same scale + if left_raw[2] != right_raw[2] { + return left_raw[2].cmp(&right_raw[2]); + } + if left_raw[1] != right_raw[1] { + return left_raw[1].cmp(&right_raw[1]); + } + return left_raw[0].cmp(&right_raw[0]); + } + + // Rescale and compare + rescale_to_maximum_scale(&mut left_raw, &mut left_scale, &mut right_raw, &mut right_scale); + cmp_internal(&left_raw, &right_raw) +} + +#[inline] +fn add_part(left: u32, right: u32) -> (u32, u32) { + let added = u64::from(left) + u64::from(right); + ((added & U32_MASK) as u32, (added >> 32 & U32_MASK) as u32) +} + +#[inline(always)] +fn sub_by_internal3(value: &mut [u32; 3], by: &[u32; 3]) { + let mut overflow = 0; + let vl = value.len(); + for i in 0..vl { + let part = (0x1_0000_0000u64 + u64::from(value[i])) - (u64::from(by[i]) + overflow); + value[i] = part as u32; + overflow = 1 - (part >> 32); + } +} + +fn div_internal(quotient: &mut [u32; 4], remainder: &mut [u32; 4], divisor: &[u32; 3]) { + // There are a couple of ways to do division on binary numbers: + // 1. Using long division + // 2. Using the complement method + // ref: http://paulmason.me/dividing-binary-numbers-part-2/ + // The complement method basically keeps trying to subtract the + // divisor until it can't anymore and placing the rest in remainder. + let mut complement = [ + divisor[0] ^ 0xFFFF_FFFF, + divisor[1] ^ 0xFFFF_FFFF, + divisor[2] ^ 0xFFFF_FFFF, + 0xFFFF_FFFF, + ]; + + // Add one onto the complement + add_one_internal4(&mut complement); + + // Make sure the remainder is 0 + remainder.iter_mut().for_each(|x| *x = 0); + + // If we have nothing in our hi+ block then shift over till we do + let mut blocks_to_process = 0; + while blocks_to_process < 4 && quotient[3] == 0 { + // memcpy would be useful here + quotient[3] = quotient[2]; + quotient[2] = quotient[1]; + quotient[1] = quotient[0]; + quotient[0] = 0; + + // Increment the counter + blocks_to_process += 1; + } + + // Let's try and do the addition... + let mut block = blocks_to_process << 5; + let mut working = [0u32, 0u32, 0u32, 0u32]; + while block < 128 { + // << 1 for quotient AND remainder. Moving the carry from the quotient to the bottom of the + // remainder. + let carry = shl1_internal(quotient, 0); + shl1_internal(remainder, carry); + + // Copy the remainder of working into sub + working.copy_from_slice(remainder); + + // Add the remainder with the complement + add_by_internal(&mut working, &complement); + + // Check for the significant bit - move over to the quotient + // as necessary + if (working[3] & 0x8000_0000) == 0 { + remainder.copy_from_slice(&working); + quotient[0] |= 1; + } + + // Increment our pointer + block += 1; + } +} + +#[inline] +fn copy_array_diff_lengths(into: &mut [u32], from: &[u32]) { + for i in 0..into.len() { + if i >= from.len() { + break; + } + into[i] = from[i]; + } +} + +#[inline] +fn add_one_internal4(value: &mut [u32; 4]) -> u32 { + let mut carry: u64 = 1; // Start with one, since adding one + let mut sum: u64; + for i in value.iter_mut() { + sum = (*i as u64) + carry; + *i = (sum & U32_MASK) as u32; + carry = sum >> 32; + } + + carry as u32 +} + +#[inline] +fn add_by_internal3(value: &mut [u32; 3], by: &[u32; 3]) -> u32 { + let mut carry: u32 = 0; + let bl = by.len(); + for i in 0..bl { + let res1 = value[i].overflowing_add(by[i]); + let res2 = res1.0.overflowing_add(carry); + value[i] = res2.0; + carry = (res1.1 | res2.1) as u32; + } + carry +} + +#[inline] +const fn u64_to_array(value: u64) -> [u32; 2] { + [(value & U32_MASK) as u32, (value >> 32 & U32_MASK) as u32] +} + +fn add_with_scale_internal( + quotient: &mut [u32; 3], + quotient_scale: &mut i32, + working_quotient: &mut [u32; 4], + working_scale: &mut i32, +) -> bool { + // Add quotient and the working (i.e. quotient = quotient + working) + if is_all_zero(quotient) { + // Quotient is zero so we can just copy the working quotient in directly + // First, make sure they are both 96 bit. + while working_quotient[3] != 0 { + div_by_u32(working_quotient, 10); + *working_scale -= 1; + } + copy_array_diff_lengths(quotient, working_quotient); + *quotient_scale = *working_scale; + return false; + } + + if is_all_zero(working_quotient) { + return false; + } + + // We have ensured that our working is not zero so we should do the addition + + // If our two quotients are different then + // try to scale down the one with the bigger scale + let mut temp3 = [0u32, 0u32, 0u32]; + let mut temp4 = [0u32, 0u32, 0u32, 0u32]; + if *quotient_scale != *working_scale { + // TODO: Remove necessity for temp (without performance impact) + fn div_by_10<const N: usize>(target: &mut [u32], temp: &mut [u32; N], scale: &mut i32, target_scale: i32) { + // Copy to the temp array + temp.copy_from_slice(target); + // divide by 10 until target scale is reached + while *scale > target_scale { + let remainder = div_by_u32(temp, 10); + if remainder == 0 { + *scale -= 1; + target.copy_from_slice(temp); + } else { + break; + } + } + } + + if *quotient_scale < *working_scale { + div_by_10(working_quotient, &mut temp4, working_scale, *quotient_scale); + } else { + div_by_10(quotient, &mut temp3, quotient_scale, *working_scale); + } + } + + // If our two quotients are still different then + // try to scale up the smaller scale + if *quotient_scale != *working_scale { + // TODO: Remove necessity for temp (without performance impact) + fn mul_by_10(target: &mut [u32], temp: &mut [u32], scale: &mut i32, target_scale: i32) { + temp.copy_from_slice(target); + let mut overflow = 0; + // Multiply by 10 until target scale reached or overflow + while *scale < target_scale && overflow == 0 { + overflow = mul_by_u32(temp, 10); + if overflow == 0 { + // Still no overflow + *scale += 1; + target.copy_from_slice(temp); + } + } + } + + if *quotient_scale > *working_scale { + mul_by_10(working_quotient, &mut temp4, working_scale, *quotient_scale); + } else { + mul_by_10(quotient, &mut temp3, quotient_scale, *working_scale); + } + } + + // If our two quotients are still different then + // try to scale down the one with the bigger scale + // (ultimately losing significant digits) + if *quotient_scale != *working_scale { + // TODO: Remove necessity for temp (without performance impact) + fn div_by_10_lossy<const N: usize>( + target: &mut [u32], + temp: &mut [u32; N], + scale: &mut i32, + target_scale: i32, + ) { + temp.copy_from_slice(target); + // divide by 10 until target scale is reached + while *scale > target_scale { + div_by_u32(temp, 10); + *scale -= 1; + target.copy_from_slice(temp); + } + } + if *quotient_scale < *working_scale { + div_by_10_lossy(working_quotient, &mut temp4, working_scale, *quotient_scale); + } else { + div_by_10_lossy(quotient, &mut temp3, quotient_scale, *working_scale); + } + } + + // If quotient or working are zero we have an underflow condition + if is_all_zero(quotient) || is_all_zero(working_quotient) { + // Underflow + return true; + } else { + // Both numbers have the same scale and can be added. + // We just need to know whether we can fit them in + let mut underflow = false; + let mut temp = [0u32, 0u32, 0u32]; + while !underflow { + temp.copy_from_slice(quotient); + + // Add the working quotient + let overflow = add_by_internal(&mut temp, working_quotient); + if overflow == 0 { + // addition was successful + quotient.copy_from_slice(&temp); + break; + } else { + // addition overflowed - remove significant digits and try again + div_by_u32(quotient, 10); + *quotient_scale -= 1; + div_by_u32(working_quotient, 10); + *working_scale -= 1; + // Check for underflow + underflow = is_all_zero(quotient) || is_all_zero(working_quotient); + } + } + if underflow { + return true; + } + } + false +} + +/// Rescales the given decimals to equivalent scales. +/// It will firstly try to scale both the left and the right side to +/// the maximum scale of left/right. If it is unable to do that it +/// will try to reduce the accuracy of the other argument. +/// e.g. with 1.23 and 2.345 it'll rescale the first arg to 1.230 +#[inline(always)] +fn rescale_to_maximum_scale(left: &mut [u32; 3], left_scale: &mut u32, right: &mut [u32; 3], right_scale: &mut u32) { + if left_scale == right_scale { + // Nothing to do + return; + } + + if is_all_zero(left) { + *left_scale = *right_scale; + return; + } else if is_all_zero(right) { + *right_scale = *left_scale; + return; + } + + if left_scale > right_scale { + rescale_internal(right, right_scale, *left_scale); + if right_scale != left_scale { + rescale_internal(left, left_scale, *right_scale); + } + } else { + rescale_internal(left, left_scale, *right_scale); + if right_scale != left_scale { + rescale_internal(right, right_scale, *left_scale); + } + } +} + +#[cfg(test)] +mod test { + // Tests on private methods. + // + // All public tests should go under `tests/`. + + use super::*; + use crate::prelude::*; + + #[test] + fn it_can_rescale_to_maximum_scale() { + fn extract(value: &str) -> ([u32; 3], u32) { + let v = Decimal::from_str(value).unwrap(); + (v.mantissa_array3(), v.scale()) + } + + let tests = &[ + ("1", "1", "1", "1"), + ("1", "1.0", "1.0", "1.0"), + ("1", "1.00000", "1.00000", "1.00000"), + ("1", "1.0000000000", "1.0000000000", "1.0000000000"), + ( + "1", + "1.00000000000000000000", + "1.00000000000000000000", + "1.00000000000000000000", + ), + ("1.1", "1.1", "1.1", "1.1"), + ("1.1", "1.10000", "1.10000", "1.10000"), + ("1.1", "1.1000000000", "1.1000000000", "1.1000000000"), + ( + "1.1", + "1.10000000000000000000", + "1.10000000000000000000", + "1.10000000000000000000", + ), + ( + "0.6386554621848739495798319328", + "11.815126050420168067226890757", + "0.638655462184873949579831933", + "11.815126050420168067226890757", + ), + ( + "0.0872727272727272727272727272", // Scale 28 + "843.65000000", // Scale 8 + "0.0872727272727272727272727", // 25 + "843.6500000000000000000000000", // 25 + ), + ]; + + for &(left_raw, right_raw, expected_left, expected_right) in tests { + // Left = the value to rescale + // Right = the new scale we're scaling to + // Expected = the expected left value after rescale + let (expected_left, expected_lscale) = extract(expected_left); + let (expected_right, expected_rscale) = extract(expected_right); + + let (mut left, mut left_scale) = extract(left_raw); + let (mut right, mut right_scale) = extract(right_raw); + rescale_to_maximum_scale(&mut left, &mut left_scale, &mut right, &mut right_scale); + assert_eq!(left, expected_left); + assert_eq!(left_scale, expected_lscale); + assert_eq!(right, expected_right); + assert_eq!(right_scale, expected_rscale); + + // Also test the transitive case + let (mut left, mut left_scale) = extract(left_raw); + let (mut right, mut right_scale) = extract(right_raw); + rescale_to_maximum_scale(&mut right, &mut right_scale, &mut left, &mut left_scale); + assert_eq!(left, expected_left); + assert_eq!(left_scale, expected_lscale); + assert_eq!(right, expected_right); + assert_eq!(right_scale, expected_rscale); + } + } +} diff --git a/third_party/rust/rust_decimal/src/ops/mul.rs b/third_party/rust/rust_decimal/src/ops/mul.rs new file mode 100644 index 0000000000..b36729599d --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/mul.rs @@ -0,0 +1,168 @@ +use crate::constants::{BIG_POWERS_10, MAX_I64_SCALE, MAX_PRECISION_U32, U32_MAX}; +use crate::decimal::{CalculationResult, Decimal}; +use crate::ops::common::Buf24; + +pub(crate) fn mul_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + if d1.is_zero() || d2.is_zero() { + // We should think about this - does zero need to maintain precision? This treats it like + // an absolute which I think is ok, especially since we have is_zero() functions etc. + return CalculationResult::Ok(Decimal::ZERO); + } + + let mut scale = d1.scale() + d2.scale(); + let negative = d1.is_sign_negative() ^ d2.is_sign_negative(); + let mut product = Buf24::zero(); + + // See if we can optimize this calculation depending on whether the hi bits are set + if d1.hi() | d1.mid() == 0 { + if d2.hi() | d2.mid() == 0 { + // We're multiplying two 32 bit integers, so we can take some liberties to optimize this. + let mut low64 = d1.lo() as u64 * d2.lo() as u64; + if scale > MAX_PRECISION_U32 { + // We've exceeded maximum scale so we need to start reducing the precision (aka + // rounding) until we have something that fits. + // If we're too big then we effectively round to zero. + if scale > MAX_PRECISION_U32 + MAX_I64_SCALE { + return CalculationResult::Ok(Decimal::ZERO); + } + + scale -= MAX_PRECISION_U32 + 1; + let mut power = BIG_POWERS_10[scale as usize]; + + let tmp = low64 / power; + let remainder = low64 - tmp * power; + low64 = tmp; + + // Round the result. Since the divisor was a power of 10, it's always even. + power >>= 1; + if remainder >= power && (remainder > power || (low64 as u32 & 1) > 0) { + low64 += 1; + } + + scale = MAX_PRECISION_U32; + } + + // Early exit + return CalculationResult::Ok(Decimal::from_parts( + low64 as u32, + (low64 >> 32) as u32, + 0, + negative, + scale, + )); + } + + // We know that the left hand side is just 32 bits but the right hand side is either + // 64 or 96 bits. + mul_by_32bit_lhs(d1.lo() as u64, d2, &mut product); + } else if d2.mid() | d2.hi() == 0 { + // We know that the right hand side is just 32 bits. + mul_by_32bit_lhs(d2.lo() as u64, d1, &mut product); + } else { + // We know we're not dealing with simple 32 bit operands on either side. + // We compute and accumulate the 9 partial products using long multiplication + + // 1: ll * rl + let mut tmp = d1.lo() as u64 * d2.lo() as u64; + product.data[0] = tmp as u32; + + // 2: ll * rm + let mut tmp2 = (d1.lo() as u64 * d2.mid() as u64).wrapping_add(tmp >> 32); + + // 3: lm * rl + tmp = d1.mid() as u64 * d2.lo() as u64; + tmp = tmp.wrapping_add(tmp2); + product.data[1] = tmp as u32; + + // Detect if carry happened from the wrapping add + if tmp < tmp2 { + tmp2 = (tmp >> 32) | (1u64 << 32); + } else { + tmp2 = tmp >> 32; + } + + // 4: lm * rm + tmp = (d1.mid() as u64 * d2.mid() as u64) + tmp2; + + // If the high bit isn't set then we can stop here. Otherwise, we need to continue calculating + // using the high bits. + if (d1.hi() | d2.hi()) > 0 { + // 5. ll * rh + tmp2 = d1.lo() as u64 * d2.hi() as u64; + tmp = tmp.wrapping_add(tmp2); + // Detect if we carried + let mut tmp3 = if tmp < tmp2 { 1 } else { 0 }; + + // 6. lh * rl + tmp2 = d1.hi() as u64 * d2.lo() as u64; + tmp = tmp.wrapping_add(tmp2); + product.data[2] = tmp as u32; + // Detect if we carried + if tmp < tmp2 { + tmp3 += 1; + } + tmp2 = (tmp3 << 32) | (tmp >> 32); + + // 7. lm * rh + tmp = d1.mid() as u64 * d2.hi() as u64; + tmp = tmp.wrapping_add(tmp2); + // Check for carry + tmp3 = if tmp < tmp2 { 1 } else { 0 }; + + // 8. lh * rm + tmp2 = d1.hi() as u64 * d2.mid() as u64; + tmp = tmp.wrapping_add(tmp2); + product.data[3] = tmp as u32; + // Check for carry + if tmp < tmp2 { + tmp3 += 1; + } + tmp = (tmp3 << 32) | (tmp >> 32); + + // 9. lh * rh + product.set_high64(d1.hi() as u64 * d2.hi() as u64 + tmp); + } else { + product.set_mid64(tmp); + } + } + + // We may want to "rescale". This is the case if the mantissa is > 96 bits or if the scale + // exceeds the maximum precision. + let upper_word = product.upper_word(); + if upper_word > 2 || scale > MAX_PRECISION_U32 { + scale = if let Some(new_scale) = product.rescale(upper_word, scale) { + new_scale + } else { + return CalculationResult::Overflow; + } + } + + CalculationResult::Ok(Decimal::from_parts( + product.data[0], + product.data[1], + product.data[2], + negative, + scale, + )) +} + +#[inline(always)] +fn mul_by_32bit_lhs(d1: u64, d2: &Decimal, product: &mut Buf24) { + let mut tmp = d1 * d2.lo() as u64; + product.data[0] = tmp as u32; + tmp = (d1 * d2.mid() as u64).wrapping_add(tmp >> 32); + product.data[1] = tmp as u32; + tmp >>= 32; + + // If we're multiplying by a 96 bit integer then continue the calculation + if d2.hi() > 0 { + tmp = tmp.wrapping_add(d1 * d2.hi() as u64); + if tmp > U32_MAX { + product.set_mid64(tmp); + } else { + product.data[2] = tmp as u32; + } + } else { + product.data[2] = tmp as u32; + } +} diff --git a/third_party/rust/rust_decimal/src/ops/rem.rs b/third_party/rust/rust_decimal/src/ops/rem.rs new file mode 100644 index 0000000000..a79334e04b --- /dev/null +++ b/third_party/rust/rust_decimal/src/ops/rem.rs @@ -0,0 +1,285 @@ +use crate::constants::{MAX_I32_SCALE, MAX_PRECISION_I32, POWERS_10}; +use crate::decimal::{CalculationResult, Decimal}; +use crate::ops::common::{Buf12, Buf16, Buf24, Dec64}; + +pub(crate) fn rem_impl(d1: &Decimal, d2: &Decimal) -> CalculationResult { + if d2.is_zero() { + return CalculationResult::DivByZero; + } + if d1.is_zero() { + return CalculationResult::Ok(Decimal::ZERO); + } + + // We handle the structs a bit different here. Firstly, we ignore both the sign/scale of d2. + // This is because during a remainder operation we do not care about the sign of the divisor + // and only concern ourselves with that of the dividend. + let mut d1 = Dec64::new(d1); + let d2_scale = d2.scale(); + let mut d2 = Buf12::from_decimal(d2); + + let cmp = crate::ops::cmp::cmp_internal( + &d1, + &Dec64 { + negative: d1.negative, + scale: d2_scale, + hi: d2.hi(), + low64: d2.low64(), + }, + ); + match cmp { + core::cmp::Ordering::Equal => { + // Same numbers meaning that remainder is zero + return CalculationResult::Ok(Decimal::ZERO); + } + core::cmp::Ordering::Less => { + // d1 < d2, e.g. 1/2. This means that the result is the value of d1 + return CalculationResult::Ok(d1.to_decimal()); + } + core::cmp::Ordering::Greater => {} + } + + // At this point we know that the dividend > divisor and that they are both non-zero. + let mut scale = d1.scale as i32 - d2_scale as i32; + if scale > 0 { + // Scale up the divisor + loop { + let power = if scale >= MAX_I32_SCALE { + POWERS_10[9] + } else { + POWERS_10[scale as usize] + } as u64; + + let mut tmp = d2.lo() as u64 * power; + d2.set_lo(tmp as u32); + tmp >>= 32; + tmp = tmp.wrapping_add((d2.mid() as u64 + ((d2.hi() as u64) << 32)) * power); + d2.set_mid(tmp as u32); + d2.set_hi((tmp >> 32) as u32); + + // Keep scaling if there is more to go + scale -= MAX_I32_SCALE; + if scale <= 0 { + break; + } + } + scale = 0; + } + + loop { + // If the dividend is smaller than the divisor then try to scale that up first + if scale < 0 { + let mut quotient = Buf12 { + data: [d1.lo(), d1.mid(), d1.hi], + }; + loop { + // Figure out how much we can scale by + let power_scale; + if let Some(u) = quotient.find_scale(MAX_PRECISION_I32 + scale) { + if u >= POWERS_10.len() { + power_scale = 9; + } else { + power_scale = u; + } + } else { + return CalculationResult::Overflow; + }; + if power_scale == 0 { + break; + } + let power = POWERS_10[power_scale] as u64; + scale += power_scale as i32; + + let mut tmp = quotient.data[0] as u64 * power; + quotient.data[0] = tmp as u32; + tmp >>= 32; + quotient.set_high64(tmp.wrapping_add(quotient.high64().wrapping_mul(power))); + if power_scale != 9 { + break; + } + if scale >= 0 { + break; + } + } + d1.low64 = quotient.low64(); + d1.hi = quotient.data[2]; + d1.scale = d2_scale; + } + + // if the high portion is empty then return the modulus of the bottom portion + if d1.hi == 0 { + d1.low64 %= d2.low64(); + return CalculationResult::Ok(d1.to_decimal()); + } else if (d2.mid() | d2.hi()) == 0 { + let mut tmp = d1.high64(); + tmp = ((tmp % d2.lo() as u64) << 32) | (d1.lo() as u64); + d1.low64 = tmp % d2.lo() as u64; + d1.hi = 0; + } else { + // Divisor is > 32 bits + return rem_full(&d1, &d2, scale); + } + + if scale >= 0 { + break; + } + } + + CalculationResult::Ok(d1.to_decimal()) +} + +fn rem_full(d1: &Dec64, d2: &Buf12, scale: i32) -> CalculationResult { + let mut scale = scale; + + // First normalize the divisor + let shift = if d2.hi() == 0 { + d2.mid().leading_zeros() + } else { + d2.hi().leading_zeros() + }; + + let mut buffer = Buf24::zero(); + let mut overflow = 0u32; + buffer.set_low64(d1.low64 << shift); + buffer.set_mid64(((d1.mid() as u64).wrapping_add((d1.hi as u64) << 32)) >> (32 - shift)); + let mut upper = 3; // We start at 3 due to bit shifting + + while scale < 0 { + let power = if -scale >= MAX_I32_SCALE { + POWERS_10[9] + } else { + POWERS_10[-scale as usize] + } as u64; + let mut tmp64 = buffer.data[0] as u64 * power; + buffer.data[0] = tmp64 as u32; + + for (index, part) in buffer.data.iter_mut().enumerate().skip(1) { + if index > upper { + break; + } + tmp64 >>= 32; + tmp64 = tmp64.wrapping_add((*part as u64).wrapping_mul(power)); + *part = tmp64 as u32; + } + // If we have overflow then also process that + if upper == 6 { + tmp64 >>= 32; + tmp64 = tmp64.wrapping_add((overflow as u64).wrapping_mul(power)); + overflow = tmp64 as u32; + } + + // Make sure the high bit is not set + if tmp64 > 0x7FFF_FFFF { + upper += 1; + if upper > 5 { + overflow = (tmp64 >> 32) as u32; + } else { + buffer.data[upper] = (tmp64 >> 32) as u32; + } + } + scale += MAX_I32_SCALE; + } + + // TODO: Optimize slice logic + + let mut tmp = Buf16::zero(); + let divisor = d2.low64() << shift; + if d2.hi() == 0 { + // Do some division + if upper == 6 { + upper -= 1; + + tmp.data = [buffer.data[4], buffer.data[5], overflow, 0]; + tmp.partial_divide_64(divisor); + buffer.data[4] = tmp.data[0]; + buffer.data[5] = tmp.data[1]; + } + if upper == 5 { + upper -= 1; + tmp.data = [buffer.data[3], buffer.data[4], buffer.data[5], 0]; + tmp.partial_divide_64(divisor); + buffer.data[3] = tmp.data[0]; + buffer.data[4] = tmp.data[1]; + buffer.data[5] = tmp.data[2]; + } + if upper == 4 { + tmp.data = [buffer.data[2], buffer.data[3], buffer.data[4], 0]; + tmp.partial_divide_64(divisor); + buffer.data[2] = tmp.data[0]; + buffer.data[3] = tmp.data[1]; + buffer.data[4] = tmp.data[2]; + } + + tmp.data = [buffer.data[1], buffer.data[2], buffer.data[3], 0]; + tmp.partial_divide_64(divisor); + buffer.data[1] = tmp.data[0]; + buffer.data[2] = tmp.data[1]; + buffer.data[3] = tmp.data[2]; + + tmp.data = [buffer.data[0], buffer.data[1], buffer.data[2], 0]; + tmp.partial_divide_64(divisor); + buffer.data[0] = tmp.data[0]; + buffer.data[1] = tmp.data[1]; + buffer.data[2] = tmp.data[2]; + + let low64 = buffer.low64() >> shift; + CalculationResult::Ok(Decimal::from_parts( + low64 as u32, + (low64 >> 32) as u32, + 0, + d1.negative, + d1.scale, + )) + } else { + let divisor_low64 = divisor; + let divisor = Buf12 { + data: [ + divisor_low64 as u32, + (divisor_low64 >> 32) as u32, + (((d2.mid() as u64) + ((d2.hi() as u64) << 32)) >> (32 - shift)) as u32, + ], + }; + + // Do some division + if upper == 6 { + upper -= 1; + tmp.data = [buffer.data[3], buffer.data[4], buffer.data[5], overflow]; + tmp.partial_divide_96(&divisor); + buffer.data[3] = tmp.data[0]; + buffer.data[4] = tmp.data[1]; + buffer.data[5] = tmp.data[2]; + } + if upper == 5 { + upper -= 1; + tmp.data = [buffer.data[2], buffer.data[3], buffer.data[4], buffer.data[5]]; + tmp.partial_divide_96(&divisor); + buffer.data[2] = tmp.data[0]; + buffer.data[3] = tmp.data[1]; + buffer.data[4] = tmp.data[2]; + buffer.data[5] = tmp.data[3]; + } + if upper == 4 { + tmp.data = [buffer.data[1], buffer.data[2], buffer.data[3], buffer.data[4]]; + tmp.partial_divide_96(&divisor); + buffer.data[1] = tmp.data[0]; + buffer.data[2] = tmp.data[1]; + buffer.data[3] = tmp.data[2]; + buffer.data[4] = tmp.data[3]; + } + + tmp.data = [buffer.data[0], buffer.data[1], buffer.data[2], buffer.data[3]]; + tmp.partial_divide_96(&divisor); + buffer.data[0] = tmp.data[0]; + buffer.data[1] = tmp.data[1]; + buffer.data[2] = tmp.data[2]; + buffer.data[3] = tmp.data[3]; + + let low64 = (buffer.low64() >> shift) + ((buffer.data[2] as u64) << (32 - shift) << 32); + CalculationResult::Ok(Decimal::from_parts( + low64 as u32, + (low64 >> 32) as u32, + buffer.data[2] >> shift, + d1.negative, + d1.scale, + )) + } +} diff --git a/third_party/rust/rust_decimal/src/postgres.rs b/third_party/rust/rust_decimal/src/postgres.rs new file mode 100644 index 0000000000..0930e592ab --- /dev/null +++ b/third_party/rust/rust_decimal/src/postgres.rs @@ -0,0 +1,8 @@ +// Shared +mod common; + +#[cfg(any(feature = "db-diesel1-postgres", feature = "db-diesel2-postgres"))] +mod diesel; + +#[cfg(any(feature = "db-postgres", feature = "db-tokio-postgres"))] +mod driver; diff --git a/third_party/rust/rust_decimal/src/postgres/common.rs b/third_party/rust/rust_decimal/src/postgres/common.rs new file mode 100644 index 0000000000..b821b1edaa --- /dev/null +++ b/third_party/rust/rust_decimal/src/postgres/common.rs @@ -0,0 +1,140 @@ +use crate::constants::MAX_PRECISION_U32; +use crate::{ + ops::array::{div_by_u32, is_all_zero, mul_by_u32}, + Decimal, +}; +use core::fmt; +use std::error; + +#[derive(Debug, Clone)] +pub struct InvalidDecimal { + inner: Option<String>, +} + +impl fmt::Display for InvalidDecimal { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + if let Some(ref msg) = self.inner { + fmt.write_fmt(format_args!("Invalid Decimal: {}", msg)) + } else { + fmt.write_str("Invalid Decimal") + } + } +} + +impl error::Error for InvalidDecimal {} + +pub(in crate::postgres) struct PostgresDecimal<D> { + pub neg: bool, + pub weight: i16, + pub scale: u16, + pub digits: D, +} + +impl Decimal { + pub(in crate::postgres) fn from_postgres<D: ExactSizeIterator<Item = u16>>( + PostgresDecimal { + neg, + scale, + digits, + weight, + }: PostgresDecimal<D>, + ) -> Self { + let mut digits = digits.into_iter().collect::<Vec<_>>(); + + let fractionals_part_count = digits.len() as i32 + (-weight as i32) - 1; + let integers_part_count = weight as i32 + 1; + + let mut result = Decimal::ZERO; + // adding integer part + if integers_part_count > 0 { + let (start_integers, last) = if integers_part_count > digits.len() as i32 { + (integers_part_count - digits.len() as i32, digits.len() as i32) + } else { + (0, integers_part_count) + }; + let integers: Vec<_> = digits.drain(..last as usize).collect(); + for digit in integers { + result *= Decimal::from_i128_with_scale(10i128.pow(4), 0); + result += Decimal::new(digit as i64, 0); + } + result *= Decimal::from_i128_with_scale(10i128.pow(4 * start_integers as u32), 0); + } + // adding fractional part + if fractionals_part_count > 0 { + let start_fractionals = if weight < 0 { (-weight as u32) - 1 } else { 0 }; + for (i, digit) in digits.into_iter().enumerate() { + let fract_pow = 4 * (i as u32 + 1 + start_fractionals); + if fract_pow <= MAX_PRECISION_U32 { + result += Decimal::new(digit as i64, 0) / Decimal::from_i128_with_scale(10i128.pow(fract_pow), 0); + } else if fract_pow == MAX_PRECISION_U32 + 4 { + // rounding last digit + if digit >= 5000 { + result += + Decimal::new(1_i64, 0) / Decimal::from_i128_with_scale(10i128.pow(MAX_PRECISION_U32), 0); + } + } + } + } + + result.set_sign_negative(neg); + // Rescale to the postgres value, automatically rounding as needed. + result.rescale(scale as u32); + result + } + + pub(in crate::postgres) fn to_postgres(self) -> PostgresDecimal<Vec<i16>> { + if self.is_zero() { + return PostgresDecimal { + neg: false, + weight: 0, + scale: 0, + digits: vec![0], + }; + } + let scale = self.scale() as u16; + + let groups_diff = scale & 0x3; // groups_diff = scale % 4 + + let mut mantissa = self.mantissa_array4(); + + if groups_diff > 0 { + let remainder = 4 - groups_diff; + let power = 10u32.pow(u32::from(remainder)); + mul_by_u32(&mut mantissa, power); + } + + // array to store max mantissa of Decimal in Postgres decimal format + const MAX_GROUP_COUNT: usize = 8; + let mut digits = Vec::with_capacity(MAX_GROUP_COUNT); + + while !is_all_zero(&mantissa) { + let digit = div_by_u32(&mut mantissa, 10000) as u16; + digits.push(digit.try_into().unwrap()); + } + digits.reverse(); + let digits_after_decimal = (scale + 3) as u16 / 4; + let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + + let unnecessary_zeroes = if weight >= 0 { + let index_of_decimal = (weight + 1) as usize; + digits + .get(index_of_decimal..) + .expect("enough digits exist") + .iter() + .rev() + .take_while(|i| **i == 0) + .count() + } else { + 0 + }; + let relevant_digits = digits.len() - unnecessary_zeroes; + digits.truncate(relevant_digits); + + PostgresDecimal { + neg: self.is_sign_negative(), + digits, + scale, + weight, + } + } +} diff --git a/third_party/rust/rust_decimal/src/postgres/diesel.rs b/third_party/rust/rust_decimal/src/postgres/diesel.rs new file mode 100644 index 0000000000..26cd3b33be --- /dev/null +++ b/third_party/rust/rust_decimal/src/postgres/diesel.rs @@ -0,0 +1,333 @@ +use crate::postgres::common::*; +use crate::Decimal; +use diesel::{ + deserialize::{self, FromSql}, + pg::data_types::PgNumeric, + pg::Pg, + serialize::{self, Output, ToSql}, + sql_types::Numeric, +}; +use std::error; + +impl<'a> TryFrom<&'a PgNumeric> for Decimal { + type Error = Box<dyn error::Error + Send + Sync>; + + fn try_from(numeric: &'a PgNumeric) -> deserialize::Result<Self> { + let (neg, weight, scale, digits) = match *numeric { + PgNumeric::Positive { + weight, + scale, + ref digits, + } => (false, weight, scale, digits), + PgNumeric::Negative { + weight, + scale, + ref digits, + } => (true, weight, scale, digits), + PgNumeric::NaN => return Err(Box::from("NaN is not supported in Decimal")), + }; + + Ok(Self::from_postgres(PostgresDecimal { + neg, + weight, + scale, + digits: digits.iter().copied().map(|v| v.try_into().unwrap()), + })) + } +} + +impl TryFrom<PgNumeric> for Decimal { + type Error = Box<dyn error::Error + Send + Sync>; + + fn try_from(numeric: PgNumeric) -> deserialize::Result<Self> { + (&numeric).try_into() + } +} + +impl<'a> From<&'a Decimal> for PgNumeric { + fn from(decimal: &'a Decimal) -> Self { + let PostgresDecimal { + neg, + weight, + scale, + digits, + } = decimal.to_postgres(); + + if neg { + PgNumeric::Negative { digits, scale, weight } + } else { + PgNumeric::Positive { digits, scale, weight } + } + } +} + +impl From<Decimal> for PgNumeric { + fn from(decimal: Decimal) -> Self { + (&decimal).into() + } +} + +#[cfg(all(feature = "diesel1", not(feature = "diesel2")))] +impl ToSql<Numeric, Pg> for Decimal { + fn to_sql<W: std::io::Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result { + let numeric = PgNumeric::from(self); + ToSql::<Numeric, Pg>::to_sql(&numeric, out) + } +} + +#[cfg(feature = "diesel2")] +impl ToSql<Numeric, Pg> for Decimal { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let numeric = PgNumeric::from(self); + ToSql::<Numeric, Pg>::to_sql(&numeric, &mut out.reborrow()) + } +} + +#[cfg(all(feature = "diesel1", not(feature = "diesel2")))] +impl FromSql<Numeric, Pg> for Decimal { + fn from_sql(numeric: Option<&[u8]>) -> deserialize::Result<Self> { + PgNumeric::from_sql(numeric)?.try_into() + } +} + +#[cfg(feature = "diesel2")] +impl FromSql<Numeric, Pg> for Decimal { + fn from_sql(numeric: diesel::pg::PgValue) -> deserialize::Result<Self> { + PgNumeric::from_sql(numeric)?.try_into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::str::FromStr; + + #[test] + fn test_unnecessary_zeroes() { + fn extract(value: &str) -> Decimal { + Decimal::from_str(value).unwrap() + } + + let tests = &[ + ("0.000001660"), + ("41.120255926293000"), + ("0.5538973300"), + ("08883.55986854293100"), + ("0.0000_0000_0016_6000_00"), + ("0.00000166650000"), + ("1666500000000"), + ("1666500000000.0000054500"), + ("8944.000000000000"), + ]; + + for &value in tests { + let value = extract(value); + let pg = PgNumeric::from(value); + let dec = Decimal::try_from(pg).unwrap(); + assert_eq!(dec, value); + } + } + + #[test] + fn decimal_to_pgnumeric_converts_digits_to_base_10000() { + let decimal = Decimal::from_str("1").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 0, + digits: vec![1], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("10").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 0, + digits: vec![10], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("10000").unwrap(); + let expected = PgNumeric::Positive { + weight: 1, + scale: 0, + digits: vec![1, 0], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("10001").unwrap(); + let expected = PgNumeric::Positive { + weight: 1, + scale: 0, + digits: vec![1, 1], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("100000000").unwrap(); + let expected = PgNumeric::Positive { + weight: 2, + scale: 0, + digits: vec![1, 0, 0], + }; + assert_eq!(expected, decimal.into()); + } + + #[test] + fn decimal_to_pg_numeric_properly_adjusts_scale() { + let decimal = Decimal::from_str("1").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 0, + digits: vec![1], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("1.0").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 1, + digits: vec![1], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("1.1").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 1, + digits: vec![1, 1000], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("1.10").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 2, + digits: vec![1, 1000], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("100000000.0001").unwrap(); + let expected = PgNumeric::Positive { + weight: 2, + scale: 4, + digits: vec![1, 0, 0, 1], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("0.1").unwrap(); + let expected = PgNumeric::Positive { + weight: -1, + scale: 1, + digits: vec![1000], + }; + assert_eq!(expected, decimal.into()); + } + + #[test] + fn decimal_to_pg_numeric_retains_sign() { + let decimal = Decimal::from_str("123.456").unwrap(); + let expected = PgNumeric::Positive { + weight: 0, + scale: 3, + digits: vec![123, 4560], + }; + assert_eq!(expected, decimal.into()); + + let decimal = Decimal::from_str("-123.456").unwrap(); + let expected = PgNumeric::Negative { + weight: 0, + scale: 3, + digits: vec![123, 4560], + }; + assert_eq!(expected, decimal.into()); + } + + #[test] + fn pg_numeric_to_decimal_works() { + let expected = Decimal::from_str("50").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 0, + digits: vec![50], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res, expected); + let expected = Decimal::from_str("123.456").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 3, + digits: vec![123, 4560], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res, expected); + + let expected = Decimal::from_str("-56.78").unwrap(); + let pg_numeric = PgNumeric::Negative { + weight: 0, + scale: 2, + digits: vec![56, 7800], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res, expected); + + // Verify no trailing zeroes are lost. + + let expected = Decimal::from_str("1.100").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 3, + digits: vec![1, 1000], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + + // To represent 5.00, Postgres can return either [5, 0] as the list of digits. + let expected = Decimal::from_str("5.00").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 2, + + digits: vec![5, 0], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + + // To represent 5.00, Postgres can return [5] as the list of digits. + let expected = Decimal::from_str("5.00").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 2, + digits: vec![5], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + + let expected = Decimal::from_str("3.1415926535897932384626433833").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 30, + digits: vec![3, 1415, 9265, 3589, 7932, 3846, 2643, 3832, 7950, 2800], + }; + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + + let expected = Decimal::from_str("3.1415926535897932384626433833").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 34, + digits: vec![3, 1415, 9265, 3589, 7932, 3846, 2643, 3832, 7950, 2800], + }; + + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + + let expected = Decimal::from_str("1.2345678901234567890123456790").unwrap(); + let pg_numeric = PgNumeric::Positive { + weight: 0, + scale: 34, + digits: vec![1, 2345, 6789, 0123, 4567, 8901, 2345, 6789, 5000, 0], + }; + + let res: Decimal = pg_numeric.try_into().unwrap(); + assert_eq!(res.to_string(), expected.to_string()); + } +} diff --git a/third_party/rust/rust_decimal/src/postgres/driver.rs b/third_party/rust/rust_decimal/src/postgres/driver.rs new file mode 100644 index 0000000000..7d185e3bab --- /dev/null +++ b/third_party/rust/rust_decimal/src/postgres/driver.rs @@ -0,0 +1,383 @@ +use crate::postgres::common::*; +use crate::Decimal; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; +use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; +use std::io::Cursor; + +impl<'a> FromSql<'a> for Decimal { + // Decimals are represented as follows: + // Header: + // u16 numGroups + // i16 weightFirstGroup (10000^weight) + // u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN) + // i16 dscale. Number of digits (in base 10) to print after decimal separator + // + // Pseudo code : + // const Decimals [ + // 0.0000000000000000000000000001, + // 0.000000000000000000000001, + // 0.00000000000000000001, + // 0.0000000000000001, + // 0.000000000001, + // 0.00000001, + // 0.0001, + // 1, + // 10000, + // 100000000, + // 1000000000000, + // 10000000000000000, + // 100000000000000000000, + // 1000000000000000000000000, + // 10000000000000000000000000000 + // ] + // overflow = false + // result = 0 + // for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight-- + // group = read.u16 + // if weight < 0 or weight > MaxNum + // overflow = true + // else + // result += Decimals[weight] * group + // sign == 0x4000 ? -result : result + + // So if we were to take the number: 3950.123456 + // + // Stored on Disk: + // 00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0 + // + // Number of groups: 00 03 + // Weight of first group: 00 00 + // Sign: 00 00 + // DScale: 00 06 + // + // 0F 6E = 3950 + // result = result + 3950 * 1; + // 04 D2 = 1234 + // result = result + 1234 * 0.0001; + // 15 E0 = 5600 + // result = result + 5600 * 0.00000001; + // + + fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> { + let mut raw = Cursor::new(raw); + let num_groups = raw.read_u16::<BigEndian>()?; + let weight = raw.read_i16::<BigEndian>()?; // 10000^weight + // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN + let sign = raw.read_u16::<BigEndian>()?; + // Number of digits (in base 10) to print after decimal separator + let scale = raw.read_u16::<BigEndian>()?; + + // Read all of the groups + let mut groups = Vec::new(); + for _ in 0..num_groups as usize { + groups.push(raw.read_u16::<BigEndian>()?); + } + + Ok(Self::from_postgres(PostgresDecimal { + neg: sign == 0x4000, + weight, + scale, + digits: groups.into_iter(), + })) + } + + fn accepts(ty: &Type) -> bool { + matches!(*ty, Type::NUMERIC) + } +} + +impl ToSql for Decimal { + fn to_sql( + &self, + _: &Type, + out: &mut BytesMut, + ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> { + let PostgresDecimal { + neg, + weight, + scale, + digits, + } = self.to_postgres(); + + let num_digits = digits.len(); + + // Reserve bytes + out.reserve(8 + num_digits * 2); + + // Number of groups + out.put_u16(num_digits.try_into().unwrap()); + // Weight of first group + out.put_i16(weight); + // Sign + out.put_u16(if neg { 0x4000 } else { 0x0000 }); + // DScale + out.put_u16(scale); + // Now process the number + for digit in digits[0..num_digits].iter() { + out.put_i16(*digit); + } + + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + matches!(*ty, Type::NUMERIC) + } + + to_sql_checked!(); +} + +#[cfg(test)] +mod test { + use super::*; + use ::postgres::{Client, NoTls}; + use core::str::FromStr; + + /// Gets the URL for connecting to PostgreSQL for testing. Set the POSTGRES_URL + /// environment variable to change from the default of "postgres://postgres@localhost". + fn get_postgres_url() -> String { + if let Ok(url) = std::env::var("POSTGRES_URL") { + return url; + } + "postgres://postgres@localhost".to_string() + } + + pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[ + // precision, scale, sent, expected + (35, 6, "3950.123456", "3950.123456"), + (35, 2, "3950.123456", "3950.12"), + (35, 2, "3950.1256", "3950.13"), + (10, 2, "3950.123456", "3950.12"), + (35, 6, "3950", "3950.000000"), + (4, 0, "3950", "3950"), + (35, 6, "0.1", "0.100000"), + (35, 6, "0.01", "0.010000"), + (35, 6, "0.001", "0.001000"), + (35, 6, "0.0001", "0.000100"), + (35, 6, "0.00001", "0.000010"), + (35, 6, "0.000001", "0.000001"), + (35, 6, "1", "1.000000"), + (35, 6, "-100", "-100.000000"), + (35, 6, "-123.456", "-123.456000"), + (35, 6, "119996.25", "119996.250000"), + (35, 6, "1000000", "1000000.000000"), + (35, 6, "9999999.99999", "9999999.999990"), + (35, 6, "12340.56789", "12340.567890"), + // Scale is only 28 since that is the maximum we can represent. + (65, 30, "1.2", "1.2000000000000000000000000000"), + // Pi - rounded at scale 28 + ( + 65, + 30, + "3.141592653589793238462643383279", + "3.1415926535897932384626433833", + ), + ( + 65, + 34, + "3.1415926535897932384626433832795028", + "3.1415926535897932384626433833", + ), + // Unrounded number + ( + 65, + 34, + "1.234567890123456789012345678950000", + "1.2345678901234567890123456790", + ), + ( + 65, + 34, // No rounding due to 49999 after significant digits + "1.234567890123456789012345678949999", + "1.2345678901234567890123456789", + ), + // 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF (96 bit) + (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"), + // 0x0FFF_FFFF_FFFF_FFFF_FFFF_FFFF (95 bit) + (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"), + // 0x1000_0000_0000_0000_0000_0000 + (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"), + (35, 6, "18446744073709551615", "18446744073709551615.000000"), + (35, 6, "-18446744073709551615", "-18446744073709551615.000000"), + (35, 6, "0.10001", "0.100010"), + (35, 6, "0.12345", "0.123450"), + ]; + + #[test] + fn test_null() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + + // Test NULL + let result: Option<Decimal> = match client.query("SELECT NULL::numeric", &[]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("{:#?}", err), + }; + assert_eq!(None, result); + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_test_null() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + let statement = client.prepare(&"SELECT NULL::numeric").await.unwrap(); + let rows = client.query(&statement, &[]).await.unwrap(); + let result: Option<Decimal> = rows.iter().next().unwrap().get(0); + + assert_eq!(None, result); + } + + #[test] + fn read_very_small_numeric_type() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + let result: Decimal = match client.query("SELECT 1e-130::NUMERIC(130, 0)", &[]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("error - {:#?}", err), + }; + // We compare this to zero since it is so small that it is effectively zero + assert_eq!(Decimal::ZERO, result); + } + + #[test] + fn read_numeric_type() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let result: Decimal = + match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err), + }; + assert_eq!( + expected, + result.to_string(), + "NUMERIC({}, {}) sent: {}", + precision, + scale, + sent + ); + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_read_numeric_type() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let statement = client + .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) + .await + .unwrap(); + let rows = client.query(&statement, &[]).await.unwrap(); + let result: Decimal = rows.iter().next().unwrap().get(0); + + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[test] + fn write_numeric_type() { + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let number = Decimal::from_str(sent).unwrap(); + let result: Decimal = + match client.query(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale), &[&number]) { + Ok(x) => x.iter().next().unwrap().get(0), + Err(err) => panic!("{:#?}", err), + }; + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_write_numeric_type() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() { + let statement = client + .prepare(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale)) + .await + .unwrap(); + let number = Decimal::from_str(sent).unwrap(); + let rows = client.query(&statement, &[&number]).await.unwrap(); + let result: Decimal = rows.iter().next().unwrap().get(0); + + assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale); + } + } + + #[test] + fn numeric_overflow() { + let tests = [(4, 4, "3950.1234")]; + let mut client = match Client::connect(&get_postgres_url(), NoTls) { + Ok(x) => x, + Err(err) => panic!("{:#?}", err), + }; + for &(precision, scale, sent) in tests.iter() { + match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) { + Ok(_) => panic!( + "Expected numeric overflow for {}::NUMERIC({}, {})", + sent, precision, scale + ), + Err(err) => { + assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"); + } + }; + } + } + + #[tokio::test] + #[cfg(feature = "tokio-pg")] + async fn async_numeric_overflow() { + use futures::future::FutureExt; + use tokio_postgres::connect; + + let tests = [(4, 4, "3950.1234")]; + let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + for &(precision, scale, sent) in tests.iter() { + let statement = client + .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale)) + .await + .unwrap(); + + match client.query(&statement, &[]).await { + Ok(_) => panic!( + "Expected numeric overflow for {}::NUMERIC({}, {})", + sent, precision, scale + ), + Err(err) => assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"), + } + } + } +} diff --git a/third_party/rust/rust_decimal/src/rand.rs b/third_party/rust/rust_decimal/src/rand.rs new file mode 100644 index 0000000000..afd8a64361 --- /dev/null +++ b/third_party/rust/rust_decimal/src/rand.rs @@ -0,0 +1,176 @@ +use crate::Decimal; +use rand::{ + distributions::{ + uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler}, + Distribution, Standard, + }, + Rng, +}; + +impl Distribution<Decimal> for Standard { + fn sample<R>(&self, rng: &mut R) -> Decimal + where + R: Rng + ?Sized, + { + Decimal::from_parts( + rng.next_u32(), + rng.next_u32(), + rng.next_u32(), + rng.gen(), + rng.next_u32(), + ) + } +} + +impl SampleUniform for Decimal { + type Sampler = DecimalSampler; +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct DecimalSampler { + mantissa_sampler: UniformInt<i128>, + scale: u32, +} + +impl UniformSampler for DecimalSampler { + type X = Decimal; + + /// Creates a new sampler that will yield random decimal objects between `low` and `high`. + /// + /// The sampler will always provide decimals at the same scale as the inputs; if the inputs + /// have different scales, the higher scale is used. + /// + /// # Example + /// + /// ``` + /// # use rand::Rng; + /// # use rust_decimal_macros::dec; + /// let mut rng = rand::rngs::OsRng; + /// let random = rng.gen_range(dec!(1.00)..dec!(2.00)); + /// assert!(random >= dec!(1.00)); + /// assert!(random < dec!(2.00)); + /// assert_eq!(random.scale(), 2); + /// ``` + #[inline] + fn new<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let (low, high) = sync_scales(*low.borrow(), *high.borrow()); + let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale()); + UniformSampler::new_inclusive(low, high) + } + + /// Creates a new sampler that will yield random decimal objects between `low` and `high`. + /// + /// The sampler will always provide decimals at the same scale as the inputs; if the inputs + /// have different scales, the higher scale is used. + /// + /// # Example + /// + /// ``` + /// # use rand::Rng; + /// # use rust_decimal_macros::dec; + /// let mut rng = rand::rngs::OsRng; + /// let random = rng.gen_range(dec!(1.00)..=dec!(2.00)); + /// assert!(random >= dec!(1.00)); + /// assert!(random <= dec!(2.00)); + /// assert_eq!(random.scale(), 2); + /// ``` + #[inline] + fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self + where + B1: SampleBorrow<Self::X> + Sized, + B2: SampleBorrow<Self::X> + Sized, + { + let (low, high) = sync_scales(*low.borrow(), *high.borrow()); + + // Return our sampler, which contains an underlying i128 sampler so we + // outsource the actual randomness implementation. + Self { + mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa()), + scale: low.scale(), + } + } + + #[inline] + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { + let mantissa = self.mantissa_sampler.sample(rng); + Decimal::from_i128_with_scale(mantissa, self.scale) + } +} + +/// Return equivalent Decimal objects with the same scale as one another. +#[inline] +fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) { + if a.scale() == b.scale() { + return (a, b); + } + + // Set scales to match one another, because we are relying on mantissas' + // being comparable in order outsource the actual sampling implementation. + a.rescale(a.scale().max(b.scale())); + b.rescale(a.scale().max(b.scale())); + + // Edge case: If the values have _wildly_ different scales, the values may not have rescaled far enough to match one another. + // + // In this case, we accept some precision loss because the randomization approach we are using assumes that the scales will necessarily match. + if a.scale() != b.scale() { + a.rescale(a.scale().min(b.scale())); + b.rescale(a.scale().min(b.scale())); + } + + (a, b) +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + macro_rules! dec { + ($e:expr) => { + Decimal::from_str_exact(stringify!($e)).unwrap() + }; + } + + #[test] + fn has_random_decimal_instances() { + let mut rng = rand::rngs::OsRng; + let random: [Decimal; 32] = rng.gen(); + assert!(random.windows(2).any(|slice| { slice[0] != slice[1] })); + } + + #[test] + fn generates_within_range() { + let mut rng = rand::rngs::OsRng; + for _ in 0..128 { + let random = rng.gen_range(dec!(1.00)..dec!(1.05)); + assert!(random < dec!(1.05)); + assert!(random >= dec!(1.00)); + } + } + + #[test] + fn generates_within_inclusive_range() { + let mut rng = rand::rngs::OsRng; + let mut values: HashSet<Decimal> = HashSet::new(); + for _ in 0..256 { + let random = rng.gen_range(dec!(1.00)..=dec!(1.01)); + // The scale is 2, so 1.00 and 1.01 are the only two valid choices. + assert!(random == dec!(1.00) || random == dec!(1.01)); + values.insert(random); + } + // Somewhat flaky, will fail 1 out of every 2^255 times this is run. + // Probably acceptable in the real world. + assert_eq!(values.len(), 2); + } + + #[test] + fn test_edge_case_scales_match() { + let (low, high) = sync_scales(dec!(1.000_000_000_000_000_000_01), dec!(100_000_000_000_000_000_001)); + assert_eq!(low.scale(), high.scale()); + } +} diff --git a/third_party/rust/rust_decimal/src/rocket.rs b/third_party/rust/rust_decimal/src/rocket.rs new file mode 100644 index 0000000000..6c8938ebec --- /dev/null +++ b/third_party/rust/rust_decimal/src/rocket.rs @@ -0,0 +1,12 @@ +use crate::Decimal; +use rocket::form::{self, FromFormField, ValueField}; +use std::str::FromStr; + +impl<'v> FromFormField<'v> for Decimal { + fn default() -> Option<Self> { + None + } + fn from_value(field: ValueField<'v>) -> form::Result<'v, Self> { + Decimal::from_str(field.value).map_err(|_| form::Error::validation("not a valid number").into()) + } +} diff --git a/third_party/rust/rust_decimal/src/serde.rs b/third_party/rust/rust_decimal/src/serde.rs new file mode 100644 index 0000000000..ce876309f9 --- /dev/null +++ b/third_party/rust/rust_decimal/src/serde.rs @@ -0,0 +1,899 @@ +use crate::Decimal; +use alloc::string::ToString; +use core::{fmt, str::FromStr}; +use num_traits::FromPrimitive; +use serde::{self, de::Unexpected}; + +/// Serialize/deserialize Decimals as arbitrary precision numbers in JSON using the `arbitrary_precision` feature within `serde_json`. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct ArbitraryExample { +/// #[serde(with = "rust_decimal::serde::arbitrary_precision")] +/// value: Decimal, +/// } +/// +/// let value = ArbitraryExample { value: Decimal::from_str("123.400").unwrap() }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":123.400}"# +/// ); +/// ``` +#[cfg(feature = "serde-with-arbitrary-precision")] +pub mod arbitrary_precision { + use super::*; + use serde::Serialize; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_any(DecimalVisitor) + } + + pub fn serialize<S>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serde_json::Number::from_str(&value.to_string()) + .map_err(serde::ser::Error::custom)? + .serialize(serializer) + } +} + +/// Serialize/deserialize optional Decimals as arbitrary precision numbers in JSON using the `arbitrary_precision` feature within `serde_json`. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct ArbitraryExample { +/// #[serde(with = "rust_decimal::serde::arbitrary_precision_option")] +/// value: Option<Decimal>, +/// } +/// +/// let value = ArbitraryExample { value: Some(Decimal::from_str("123.400").unwrap()) }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":123.400}"# +/// ); +/// +/// let value = ArbitraryExample { value: None }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":null}"# +/// ); +/// ``` +#[cfg(feature = "serde-with-arbitrary-precision")] +pub mod arbitrary_precision_option { + use super::*; + use serde::Serialize; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_option(OptionDecimalVisitor) + } + + pub fn serialize<S>(value: &Option<Decimal>, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + match *value { + Some(ref decimal) => serde_json::Number::from_str(&decimal.to_string()) + .map_err(serde::ser::Error::custom)? + .serialize(serializer), + None => serializer.serialize_none(), + } + } +} + +/// Serialize/deserialize Decimals as floats. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct FloatExample { +/// #[serde(with = "rust_decimal::serde::float")] +/// value: Decimal, +/// } +/// +/// let value = FloatExample { value: Decimal::from_str("123.400").unwrap() }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":123.4}"# +/// ); +/// ``` +#[cfg(feature = "serde-with-float")] +pub mod float { + use super::*; + use serde::Serialize; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_any(DecimalVisitor) + } + + pub fn serialize<S>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use num_traits::ToPrimitive; + value.to_f64().unwrap().serialize(serializer) + } +} + +/// Serialize/deserialize optional Decimals as floats. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct FloatExample { +/// #[serde(with = "rust_decimal::serde::float_option")] +/// value: Option<Decimal>, +/// } +/// +/// let value = FloatExample { value: Some(Decimal::from_str("123.400").unwrap()) }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":123.4}"# +/// ); +/// +/// let value = FloatExample { value: None }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":null}"# +/// ); +/// ``` +#[cfg(feature = "serde-with-float")] +pub mod float_option { + use super::*; + use serde::Serialize; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_option(OptionDecimalVisitor) + } + + pub fn serialize<S>(value: &Option<Decimal>, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + match *value { + Some(ref decimal) => { + use num_traits::ToPrimitive; + decimal.to_f64().unwrap().serialize(serializer) + } + None => serializer.serialize_none(), + } + } +} + +/// Serialize/deserialize Decimals as strings. This is particularly useful when using binary encoding formats. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct StringExample { +/// #[serde(with = "rust_decimal::serde::str")] +/// value: Decimal, +/// } +/// +/// let value = StringExample { value: Decimal::from_str("123.400").unwrap() }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":"123.400"}"# +/// ); +/// +/// ``` +#[cfg(feature = "serde-with-str")] +pub mod str { + use super::*; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_str(DecimalVisitor) + } + + pub fn serialize<S>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let value = crate::str::to_str_internal(value, true, None); + serializer.serialize_str(value.0.as_ref()) + } +} + +/// Serialize/deserialize optional Decimals as strings. This is particularly useful when using binary encoding formats. +/// +/// ``` +/// # use serde::{Serialize, Deserialize}; +/// # use rust_decimal::Decimal; +/// # use std::str::FromStr; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct StringExample { +/// #[serde(with = "rust_decimal::serde::str_option")] +/// value: Option<Decimal>, +/// } +/// +/// let value = StringExample { value: Some(Decimal::from_str("123.400").unwrap()) }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":"123.400"}"# +/// ); +/// +/// let value = StringExample { value: None }; +/// assert_eq!( +/// &serde_json::to_string(&value).unwrap(), +/// r#"{"value":null}"# +/// ); +/// ``` +#[cfg(feature = "serde-with-str")] +pub mod str_option { + use super::*; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_option(OptionDecimalStrVisitor) + } + + pub fn serialize<S>(value: &Option<Decimal>, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + match *value { + Some(ref decimal) => { + let decimal = crate::str::to_str_internal(decimal, true, None); + serializer.serialize_some(decimal.0.as_ref()) + } + None => serializer.serialize_none(), + } + } +} + +#[cfg(not(feature = "serde-str"))] +impl<'de> serde::Deserialize<'de> for Decimal { + fn deserialize<D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_any(DecimalVisitor) + } +} + +#[cfg(all(feature = "serde-str", not(feature = "serde-float")))] +impl<'de> serde::Deserialize<'de> for Decimal { + fn deserialize<D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_str(DecimalVisitor) + } +} + +#[cfg(all(feature = "serde-str", feature = "serde-float"))] +impl<'de> serde::Deserialize<'de> for Decimal { + fn deserialize<D>(deserializer: D) -> Result<Decimal, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_f64(DecimalVisitor) + } +} + +// It's a shame this needs to be redefined for this feature and not able to be referenced directly +#[cfg(feature = "serde-with-arbitrary-precision")] +const DECIMAL_KEY_TOKEN: &str = "$serde_json::private::Number"; + +struct DecimalVisitor; + +impl<'de> serde::de::Visitor<'de> for DecimalVisitor { + type Value = Decimal; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a Decimal type representing a fixed-point number") + } + + fn visit_i64<E>(self, value: i64) -> Result<Decimal, E> + where + E: serde::de::Error, + { + match Decimal::from_i64(value) { + Some(s) => Ok(s), + None => Err(E::invalid_value(Unexpected::Signed(value), &self)), + } + } + + fn visit_u64<E>(self, value: u64) -> Result<Decimal, E> + where + E: serde::de::Error, + { + match Decimal::from_u64(value) { + Some(s) => Ok(s), + None => Err(E::invalid_value(Unexpected::Unsigned(value), &self)), + } + } + + fn visit_f64<E>(self, value: f64) -> Result<Decimal, E> + where + E: serde::de::Error, + { + Decimal::from_str(&value.to_string()).map_err(|_| E::invalid_value(Unexpected::Float(value), &self)) + } + + fn visit_str<E>(self, value: &str) -> Result<Decimal, E> + where + E: serde::de::Error, + { + Decimal::from_str(value) + .or_else(|_| Decimal::from_scientific(value)) + .map_err(|_| E::invalid_value(Unexpected::Str(value), &self)) + } + + #[cfg(feature = "serde-with-arbitrary-precision")] + fn visit_map<A>(self, map: A) -> Result<Decimal, A::Error> + where + A: serde::de::MapAccess<'de>, + { + let mut map = map; + let value = map.next_key::<DecimalKey>()?; + if value.is_none() { + return Err(serde::de::Error::invalid_type(Unexpected::Map, &self)); + } + let v: DecimalFromString = map.next_value()?; + Ok(v.value) + } +} + +struct OptionDecimalVisitor; + +impl<'de> serde::de::Visitor<'de> for OptionDecimalVisitor { + type Value = Option<Decimal>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Decimal type representing a fixed-point number") + } + + fn visit_none<E>(self) -> Result<Option<Decimal>, E> + where + E: serde::de::Error, + { + Ok(None) + } + + #[cfg(all(feature = "serde-str", feature = "serde-float"))] + fn visit_some<D>(self, d: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + // We've got multiple types that we may see so we need to use any + d.deserialize_any(DecimalVisitor).map(Some) + } + + #[cfg(not(all(feature = "serde-str", feature = "serde-float")))] + fn visit_some<D>(self, d: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + <Decimal as serde::Deserialize>::deserialize(d).map(Some) + } +} + +#[cfg(feature = "serde-with-str")] +struct OptionDecimalStrVisitor; + +#[cfg(feature = "serde-with-str")] +impl<'de> serde::de::Visitor<'de> for OptionDecimalStrVisitor { + type Value = Option<Decimal>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Decimal type representing a fixed-point number") + } + + fn visit_none<E>(self) -> Result<Option<Decimal>, E> + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_some<D>(self, d: D) -> Result<Option<Decimal>, D::Error> + where + D: serde::de::Deserializer<'de>, + { + d.deserialize_str(DecimalVisitor).map(Some) + } +} + +#[cfg(feature = "serde-with-arbitrary-precision")] +struct DecimalKey; + +#[cfg(feature = "serde-with-arbitrary-precision")] +impl<'de> serde::de::Deserialize<'de> for DecimalKey { + fn deserialize<D>(deserializer: D) -> Result<DecimalKey, D::Error> + where + D: serde::de::Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> serde::de::Visitor<'de> for FieldVisitor { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a valid decimal field") + } + + fn visit_str<E>(self, s: &str) -> Result<(), E> + where + E: serde::de::Error, + { + if s == DECIMAL_KEY_TOKEN { + Ok(()) + } else { + Err(serde::de::Error::custom("expected field with custom name")) + } + } + } + + deserializer.deserialize_identifier(FieldVisitor)?; + Ok(DecimalKey) + } +} + +#[cfg(feature = "serde-with-arbitrary-precision")] +pub struct DecimalFromString { + pub value: Decimal, +} + +#[cfg(feature = "serde-with-arbitrary-precision")] +impl<'de> serde::de::Deserialize<'de> for DecimalFromString { + fn deserialize<D>(deserializer: D) -> Result<DecimalFromString, D::Error> + where + D: serde::de::Deserializer<'de>, + { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = DecimalFromString; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string containing a decimal") + } + + fn visit_str<E>(self, value: &str) -> Result<DecimalFromString, E> + where + E: serde::de::Error, + { + let d = Decimal::from_str(value) + .or_else(|_| Decimal::from_scientific(value)) + .map_err(serde::de::Error::custom)?; + Ok(DecimalFromString { value: d }) + } + } + + deserializer.deserialize_str(Visitor) + } +} + +#[cfg(not(feature = "serde-float"))] +impl serde::Serialize for Decimal { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let value = crate::str::to_str_internal(self, true, None); + serializer.serialize_str(value.0.as_ref()) + } +} + +#[cfg(all(feature = "serde-float", not(feature = "serde-arbitrary-precision")))] +impl serde::Serialize for Decimal { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use num_traits::ToPrimitive; + serializer.serialize_f64(self.to_f64().unwrap()) + } +} + +#[cfg(all(feature = "serde-float", feature = "serde-arbitrary-precision"))] +impl serde::Serialize for Decimal { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serde_json::Number::from_str(&self.to_string()) + .map_err(serde::ser::Error::custom)? + .serialize(serializer) + } +} + +#[cfg(test)] +mod test { + use super::*; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug)] + struct Record { + amount: Decimal, + } + + #[test] + #[cfg(not(feature = "serde-str"))] + fn deserialize_valid_decimal() { + let data = [ + ("{\"amount\":\"1.234\"}", "1.234"), + ("{\"amount\":1234}", "1234"), + ("{\"amount\":1234.56}", "1234.56"), + ("{\"amount\":\"1.23456e3\"}", "1234.56"), + ]; + for &(serialized, value) in data.iter() { + let result = serde_json::from_str(serialized); + assert_eq!( + true, + result.is_ok(), + "expected successful deserialization for {}. Error: {:?}", + serialized, + result.err().unwrap() + ); + let record: Record = result.unwrap(); + assert_eq!( + value, + record.amount.to_string(), + "expected: {}, actual: {}", + value, + record.amount.to_string() + ); + } + } + + #[test] + #[cfg(feature = "serde-arbitrary-precision")] + fn deserialize_basic_decimal() { + let d: Decimal = serde_json::from_str("1.1234127836128763").unwrap(); + // Typically, this would not work without this feature enabled due to rounding + assert_eq!(d.to_string(), "1.1234127836128763"); + } + + #[test] + #[should_panic] + fn deserialize_invalid_decimal() { + let serialized = "{\"amount\":\"foo\"}"; + let _: Record = serde_json::from_str(serialized).unwrap(); + } + + #[test] + #[cfg(not(feature = "serde-float"))] + fn serialize_decimal() { + let record = Record { + amount: Decimal::new(1234, 3), + }; + let serialized = serde_json::to_string(&record).unwrap(); + assert_eq!("{\"amount\":\"1.234\"}", serialized); + } + + #[test] + #[cfg(not(feature = "serde-float"))] + fn serialize_negative_zero() { + let record = Record { amount: -Decimal::ZERO }; + let serialized = serde_json::to_string(&record).unwrap(); + assert_eq!("{\"amount\":\"-0\"}", serialized); + } + + #[test] + #[cfg(feature = "serde-float")] + fn serialize_decimal() { + let record = Record { + amount: Decimal::new(1234, 3), + }; + let serialized = serde_json::to_string(&record).unwrap(); + assert_eq!("{\"amount\":1.234}", serialized); + } + + #[test] + #[cfg(all(feature = "serde-float", feature = "serde-arbitrary-precision"))] + fn serialize_decimal_roundtrip() { + let record = Record { + // 4.81 is intentionally chosen as it is unrepresentable as a floating point number, meaning this test + // would fail if the `serde-arbitrary-precision` was not activated. + amount: Decimal::new(481, 2), + }; + let serialized = serde_json::to_string(&record).unwrap(); + assert_eq!("{\"amount\":4.81}", serialized); + let deserialized: Record = serde_json::from_str(&serialized).unwrap(); + assert_eq!(record.amount, deserialized.amount); + } + + #[test] + #[cfg(all(feature = "serde-str", not(feature = "serde-float")))] + fn serialize_decimal_roundtrip() { + let record = Record { + amount: Decimal::new(481, 2), + }; + let serialized = serde_json::to_string(&record).unwrap(); + assert_eq!("{\"amount\":\"4.81\"}", serialized); + let deserialized: Record = serde_json::from_str(&serialized).unwrap(); + assert_eq!(record.amount, deserialized.amount); + } + + #[test] + #[cfg(all(feature = "serde-str", not(feature = "serde-float")))] + fn bincode_serialization() { + use bincode::{deserialize, serialize}; + + let data = [ + "0", + "0.00", + "3.14159", + "-3.14159", + "1234567890123.4567890", + "-1234567890123.4567890", + "5233.9008808150288439427720175", + "-5233.9008808150288439427720175", + ]; + for &raw in data.iter() { + let value = Decimal::from_str(raw).unwrap(); + let encoded = serialize(&value).unwrap(); + let decoded: Decimal = deserialize(&encoded[..]).unwrap(); + assert_eq!(value, decoded); + assert_eq!(8usize + raw.len(), encoded.len()); + } + } + + #[test] + #[cfg(all(feature = "serde-str", feature = "serde-float"))] + fn bincode_serialization() { + use bincode::{deserialize, serialize}; + + let data = [ + ("0", "0"), + ("0.00", "0.00"), + ("3.14159", "3.14159"), + ("-3.14159", "-3.14159"), + ("1234567890123.4567890", "1234567890123.4568"), + ("-1234567890123.4567890", "-1234567890123.4568"), + ]; + for &(value, expected) in data.iter() { + let value = Decimal::from_str(value).unwrap(); + let expected = Decimal::from_str(expected).unwrap(); + let encoded = serialize(&value).unwrap(); + let decoded: Decimal = deserialize(&encoded[..]).unwrap(); + assert_eq!(expected, decoded); + assert_eq!(8usize, encoded.len()); + } + } + + #[test] + #[cfg(all(feature = "serde-str", not(feature = "serde-float")))] + fn bincode_nested_serialization() { + // Issue #361 + #[derive(Deserialize, Serialize, Debug)] + pub struct Foo { + value: Decimal, + } + + let s = Foo { + value: Decimal::new(-1, 3).round_dp(0), + }; + let ser = bincode::serialize(&s).unwrap(); + let des: Foo = bincode::deserialize(&ser).unwrap(); + assert_eq!(des.value, s.value); + } + + #[test] + #[cfg(feature = "serde-with-arbitrary-precision")] + fn with_arbitrary_precision() { + #[derive(Serialize, Deserialize)] + pub struct ArbitraryExample { + #[serde(with = "crate::serde::arbitrary_precision")] + value: Decimal, + } + + let value = ArbitraryExample { + value: Decimal::from_str("123.400").unwrap(), + }; + assert_eq!(&serde_json::to_string(&value).unwrap(), r#"{"value":123.400}"#); + } + + #[test] + #[cfg(feature = "serde-with-arbitrary-precision")] + fn with_arbitrary_precision_from_string() { + #[derive(Serialize, Deserialize)] + pub struct ArbitraryExample { + #[serde(with = "crate::serde::arbitrary_precision")] + value: Decimal, + } + + let value: ArbitraryExample = serde_json::from_str(r#"{"value":"1.1234127836128763"}"#).unwrap(); + assert_eq!(value.value.to_string(), "1.1234127836128763"); + } + + #[test] + #[cfg(feature = "serde-with-float")] + fn with_float() { + #[derive(Serialize, Deserialize)] + pub struct FloatExample { + #[serde(with = "crate::serde::float")] + value: Decimal, + } + + let value = FloatExample { + value: Decimal::from_str("123.400").unwrap(), + }; + assert_eq!(&serde_json::to_string(&value).unwrap(), r#"{"value":123.4}"#); + } + + #[test] + #[cfg(feature = "serde-with-str")] + fn with_str() { + #[derive(Serialize, Deserialize)] + pub struct StringExample { + #[serde(with = "crate::serde::str")] + value: Decimal, + } + + let value = StringExample { + value: Decimal::from_str("123.400").unwrap(), + }; + assert_eq!(&serde_json::to_string(&value).unwrap(), r#"{"value":"123.400"}"#); + } + + #[test] + #[cfg(feature = "serde-with-str")] + fn with_str_bincode() { + use bincode::{deserialize, serialize}; + + #[derive(Serialize, Deserialize)] + struct BincodeExample { + #[serde(with = "crate::serde::str")] + value: Decimal, + } + + let data = [ + ("0", "0"), + ("0.00", "0.00"), + ("1.234", "1.234"), + ("3.14159", "3.14159"), + ("-3.14159", "-3.14159"), + ("1234567890123.4567890", "1234567890123.4567890"), + ("-1234567890123.4567890", "-1234567890123.4567890"), + ]; + for &(value, expected) in data.iter() { + let value = Decimal::from_str(value).unwrap(); + let expected = Decimal::from_str(expected).unwrap(); + let input = BincodeExample { value }; + + let encoded = serialize(&input).unwrap(); + let decoded: BincodeExample = deserialize(&encoded[..]).unwrap(); + assert_eq!(expected, decoded.value); + } + } + + #[test] + #[cfg(feature = "serde-with-str")] + fn with_str_bincode_optional() { + use bincode::{deserialize, serialize}; + + #[derive(Serialize, Deserialize)] + struct BincodeExample { + #[serde(with = "crate::serde::str_option")] + value: Option<Decimal>, + } + + // Some(value) + let value = Some(Decimal::new(1234, 3)); + let input = BincodeExample { value }; + let encoded = serialize(&input).unwrap(); + let decoded: BincodeExample = deserialize(&encoded[..]).unwrap(); + assert_eq!(value, decoded.value, "Some(value)"); + + // None + let input = BincodeExample { value: None }; + let encoded = serialize(&input).unwrap(); + let decoded: BincodeExample = deserialize(&encoded[..]).unwrap(); + assert_eq!(None, decoded.value, "None"); + } + + #[test] + #[cfg(feature = "serde-with-str")] + fn with_str_optional() { + #[derive(Serialize, Deserialize)] + pub struct StringExample { + #[serde(with = "crate::serde::str_option")] + value: Option<Decimal>, + } + + let original = StringExample { + value: Some(Decimal::from_str("123.400").unwrap()), + }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":"123.400"}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":"123.400"}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_some()); + assert_eq!(deserialized.value.unwrap().unpack(), original.value.unwrap().unpack()); + + // Null tests + let original = StringExample { value: None }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":null}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":null}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_none()); + } + + #[test] + #[cfg(feature = "serde-with-float")] + fn with_float_optional() { + #[derive(Serialize, Deserialize)] + pub struct StringExample { + #[serde(with = "crate::serde::float_option")] + value: Option<Decimal>, + } + + let original = StringExample { + value: Some(Decimal::from_str("123.400").unwrap()), + }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":123.4}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":123.4}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_some()); // Scale is different! + + // Null tests + let original = StringExample { value: None }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":null}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":null}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_none()); + } + + #[test] + #[cfg(feature = "serde-with-arbitrary-precision")] + fn with_arbitrary_precision_optional() { + #[derive(Serialize, Deserialize)] + pub struct StringExample { + #[serde(with = "crate::serde::arbitrary_precision_option")] + value: Option<Decimal>, + } + + let original = StringExample { + value: Some(Decimal::from_str("123.400").unwrap()), + }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":123.400}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":123.400}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_some()); + assert_eq!(deserialized.value.unwrap().unpack(), original.value.unwrap().unpack()); + + // Null tests + let original = StringExample { value: None }; + assert_eq!(&serde_json::to_string(&original).unwrap(), r#"{"value":null}"#); + let deserialized: StringExample = serde_json::from_str(r#"{"value":null}"#).unwrap(); + assert_eq!(deserialized.value, original.value); + assert!(deserialized.value.is_none()); + } +} diff --git a/third_party/rust/rust_decimal/src/str.rs b/third_party/rust/rust_decimal/src/str.rs new file mode 100644 index 0000000000..f3b89d31e0 --- /dev/null +++ b/third_party/rust/rust_decimal/src/str.rs @@ -0,0 +1,993 @@ +use crate::{ + constants::{BYTES_TO_OVERFLOW_U64, MAX_PRECISION, MAX_STR_BUFFER_SIZE, OVERFLOW_U96, WILL_OVERFLOW_U64}, + error::{tail_error, Error}, + ops::array::{add_by_internal_flattened, add_one_internal, div_by_u32, is_all_zero, mul_by_u32}, + Decimal, +}; + +use arrayvec::{ArrayString, ArrayVec}; + +use alloc::{string::String, vec::Vec}; +use core::fmt; + +// impl that doesn't allocate for serialization purposes. +pub(crate) fn to_str_internal( + value: &Decimal, + append_sign: bool, + precision: Option<usize>, +) -> (ArrayString<MAX_STR_BUFFER_SIZE>, Option<usize>) { + // Get the scale - where we need to put the decimal point + let scale = value.scale() as usize; + + // Convert to a string and manipulate that (neg at front, inject decimal) + let mut chars = ArrayVec::<_, MAX_STR_BUFFER_SIZE>::new(); + let mut working = value.mantissa_array3(); + while !is_all_zero(&working) { + let remainder = div_by_u32(&mut working, 10u32); + chars.push(char::from(b'0' + remainder as u8)); + } + while scale > chars.len() { + chars.push('0'); + } + + let (prec, additional) = match precision { + Some(prec) => { + let max: usize = MAX_PRECISION.into(); + if prec > max { + (max, Some(prec - max)) + } else { + (prec, None) + } + } + None => (scale, None), + }; + + let len = chars.len(); + let whole_len = len - scale; + let mut rep = ArrayString::new(); + // Append the negative sign if necessary while also keeping track of the length of an "empty" string representation + let empty_len = if append_sign && value.is_sign_negative() { + rep.push('-'); + 1 + } else { + 0 + }; + for i in 0..whole_len + prec { + if i == len - scale { + if i == 0 { + rep.push('0'); + } + rep.push('.'); + } + + if i >= len { + rep.push('0'); + } else { + let c = chars[len - i - 1]; + rep.push(c); + } + } + + // corner case for when we truncated everything in a low fractional + if rep.len() == empty_len { + rep.push('0'); + } + + (rep, additional) +} + +pub(crate) fn fmt_scientific_notation( + value: &Decimal, + exponent_symbol: &str, + f: &mut fmt::Formatter<'_>, +) -> fmt::Result { + #[cfg(not(feature = "std"))] + use alloc::string::ToString; + + // Get the scale - this is the e value. With multiples of 10 this may get bigger. + let mut exponent = -(value.scale() as isize); + + // Convert the integral to a string + let mut chars = Vec::new(); + let mut working = value.mantissa_array3(); + while !is_all_zero(&working) { + let remainder = div_by_u32(&mut working, 10u32); + chars.push(char::from(b'0' + remainder as u8)); + } + + // First of all, apply scientific notation rules. That is: + // 1. If non-zero digit comes first, move decimal point left so that e is a positive integer + // 2. If decimal point comes first, move decimal point right until after the first non-zero digit + // Since decimal notation naturally lends itself this way, we just need to inject the decimal + // point in the right place and adjust the exponent accordingly. + + let len = chars.len(); + let mut rep; + // We either are operating with a precision specified, or on defaults. Defaults will perform "smart" + // reduction of precision. + if let Some(precision) = f.precision() { + if len > 1 { + // If we're zero precision AND it's trailing zeros then strip them + if precision == 0 && chars.iter().take(len - 1).all(|c| *c == '0') { + rep = chars.iter().skip(len - 1).collect::<String>(); + } else { + // We may still be zero precision, however we aren't trailing zeros + if precision > 0 { + chars.insert(len - 1, '.'); + } + rep = chars + .iter() + .rev() + // Add on extra zeros according to the precision. At least one, since we added a decimal place. + .chain(core::iter::repeat(&'0')) + .take(if precision == 0 { 1 } else { 2 + precision }) + .collect::<String>(); + } + exponent += (len - 1) as isize; + } else if precision > 0 { + // We have precision that we want to add + chars.push('.'); + rep = chars + .iter() + .chain(core::iter::repeat(&'0')) + .take(2 + precision) + .collect::<String>(); + } else { + rep = chars.iter().collect::<String>(); + } + } else if len > 1 { + // If the number is just trailing zeros then we treat it like 0 precision + if chars.iter().take(len - 1).all(|c| *c == '0') { + rep = chars.iter().skip(len - 1).collect::<String>(); + } else { + // Otherwise, we need to insert a decimal place and make it a scientific number + chars.insert(len - 1, '.'); + rep = chars.iter().rev().collect::<String>(); + } + exponent += (len - 1) as isize; + } else { + rep = chars.iter().collect::<String>(); + } + + rep.push_str(exponent_symbol); + rep.push_str(&exponent.to_string()); + f.pad_integral(value.is_sign_positive(), "", &rep) +} + +// dedicated implementation for the most common case. +#[inline] +pub(crate) fn parse_str_radix_10(str: &str) -> Result<Decimal, Error> { + let bytes = str.as_bytes(); + if bytes.len() < BYTES_TO_OVERFLOW_U64 { + parse_str_radix_10_dispatch::<false, true>(bytes) + } else { + parse_str_radix_10_dispatch::<true, true>(bytes) + } +} + +#[inline] +pub(crate) fn parse_str_radix_10_exact(str: &str) -> Result<Decimal, Error> { + let bytes = str.as_bytes(); + if bytes.len() < BYTES_TO_OVERFLOW_U64 { + parse_str_radix_10_dispatch::<false, false>(bytes) + } else { + parse_str_radix_10_dispatch::<true, false>(bytes) + } +} + +#[inline] +fn parse_str_radix_10_dispatch<const BIG: bool, const ROUND: bool>(bytes: &[u8]) -> Result<Decimal, Error> { + match bytes { + [b, rest @ ..] => byte_dispatch_u64::<false, false, false, BIG, true, ROUND>(rest, 0, 0, *b), + [] => tail_error("Invalid decimal: empty"), + } +} + +#[inline] +fn overflow_64(val: u64) -> bool { + val >= WILL_OVERFLOW_U64 +} + +#[inline] +pub fn overflow_128(val: u128) -> bool { + val >= OVERFLOW_U96 +} + +/// Dispatch the next byte: +/// +/// * POINT - a decimal point has been seen +/// * NEG - we've encountered a `-` and the number is negative +/// * HAS - a digit has been encountered (when HAS is false it's invalid) +/// * BIG - a number that uses 96 bits instead of only 64 bits +/// * FIRST - true if it is the first byte in the string +#[inline] +fn dispatch_next<const POINT: bool, const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + if let Some((next, bytes)) = bytes.split_first() { + byte_dispatch_u64::<POINT, NEG, HAS, BIG, false, ROUND>(bytes, data64, scale, *next) + } else { + handle_data::<NEG, HAS>(data64 as u128, scale) + } +} + +#[inline(never)] +fn non_digit_dispatch_u64< + const POINT: bool, + const NEG: bool, + const HAS: bool, + const BIG: bool, + const FIRST: bool, + const ROUND: bool, +>( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result<Decimal, Error> { + match b { + b'-' if FIRST && !HAS => dispatch_next::<false, true, false, BIG, ROUND>(bytes, data64, scale), + b'+' if FIRST && !HAS => dispatch_next::<false, false, false, BIG, ROUND>(bytes, data64, scale), + b'_' if HAS => handle_separator::<POINT, NEG, BIG, ROUND>(bytes, data64, scale), + b => tail_invalid_digit(b), + } +} + +#[inline] +fn byte_dispatch_u64< + const POINT: bool, + const NEG: bool, + const HAS: bool, + const BIG: bool, + const FIRST: bool, + const ROUND: bool, +>( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result<Decimal, Error> { + match b { + b'0'..=b'9' => handle_digit_64::<POINT, NEG, BIG, ROUND>(bytes, data64, scale, b - b'0'), + b'.' if !POINT => handle_point::<NEG, HAS, BIG, ROUND>(bytes, data64, scale), + b => non_digit_dispatch_u64::<POINT, NEG, HAS, BIG, FIRST, ROUND>(bytes, data64, scale, b), + } +} + +#[inline(never)] +fn handle_digit_64<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, + digit: u8, +) -> Result<Decimal, Error> { + // we have already validated that we cannot overflow + let data64 = data64 * 10 + digit as u64; + let scale = if POINT { scale + 1 } else { 0 }; + + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && BIG && scale >= 28 { + if ROUND { + maybe_round(data64 as u128, next, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else if BIG && overflow_64(data64) { + handle_full_128::<POINT, NEG, ROUND>(data64 as u128, bytes, scale, next) + } else { + byte_dispatch_u64::<POINT, NEG, true, BIG, false, ROUND>(bytes, data64, scale, next) + } + } else { + let data: u128 = data64 as u128; + + handle_data::<NEG, true>(data, scale) + } +} + +#[inline(never)] +fn handle_point<const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + dispatch_next::<true, NEG, HAS, BIG, ROUND>(bytes, data64, scale) +} + +#[inline(never)] +fn handle_separator<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result<Decimal, Error> { + dispatch_next::<POINT, NEG, true, BIG, ROUND>(bytes, data64, scale) +} + +#[inline(never)] +#[cold] +fn tail_invalid_digit(digit: u8) -> Result<Decimal, Error> { + match digit { + b'.' => tail_error("Invalid decimal: two decimal points"), + b'_' => tail_error("Invalid decimal: must start lead with a number"), + _ => tail_error("Invalid decimal: unknown character"), + } +} + +#[inline(never)] +#[cold] +fn handle_full_128<const POINT: bool, const NEG: bool, const ROUND: bool>( + mut data: u128, + bytes: &[u8], + scale: u8, + next_byte: u8, +) -> Result<Decimal, Error> { + let b = next_byte; + match b { + b'0'..=b'9' => { + let digit = u32::from(b - b'0'); + + // If the data is going to overflow then we should go into recovery mode + let next = (data * 10) + digit as u128; + if overflow_128(next) { + if !POINT { + return tail_error("Invalid decimal: overflow from too many digits"); + } + + if ROUND { + maybe_round(data, next_byte, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else { + data = next; + let scale = scale + POINT as u8; + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && scale >= 28 { + if ROUND { + maybe_round(data, next, scale, POINT, NEG) + } else { + Err(Error::Underflow) + } + } else { + handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, next) + } + } else { + handle_data::<NEG, true>(data, scale) + } + } + } + b'.' if !POINT => { + // This call won't tail? + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::<true, NEG, ROUND>(data, bytes, scale, *next) + } else { + handle_data::<NEG, true>(data, scale) + } + } + b'_' => { + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, *next) + } else { + handle_data::<NEG, true>(data, scale) + } + } + b => tail_invalid_digit(b), + } +} + +#[inline(never)] +#[cold] +fn maybe_round( + mut data: u128, + next_byte: u8, + mut scale: u8, + point: bool, + negative: bool, +) -> Result<Decimal, crate::Error> { + let digit = match next_byte { + b'0'..=b'9' => u32::from(next_byte - b'0'), + b'_' => 0, // this should be an invalid string? + b'.' if point => 0, + b => return tail_invalid_digit(b), + }; + + // Round at midpoint + if digit >= 5 { + data += 1; + + // If the mantissa is now overflowing, round to the next + // next least significant digit and discard precision + if overflow_128(data) { + if scale == 0 { + return tail_error("Invalid decimal: overflow from mantissa after rounding"); + } + data += 4; + data /= 10; + scale -= 1; + } + } + + if negative { + handle_data::<true, true>(data, scale) + } else { + handle_data::<false, true>(data, scale) + } +} + +#[inline(never)] +fn tail_no_has() -> Result<Decimal, Error> { + tail_error("Invalid decimal: no digits found") +} + +#[inline] +fn handle_data<const NEG: bool, const HAS: bool>(data: u128, scale: u8) -> Result<Decimal, Error> { + debug_assert_eq!(data >> 96, 0); + if !HAS { + tail_no_has() + } else { + Ok(Decimal::from_parts( + data as u32, + (data >> 32) as u32, + (data >> 64) as u32, + NEG, + scale as u32, + )) + } +} + +pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result<Decimal, Error> { + if str.is_empty() { + return Err(Error::from("Invalid decimal: empty")); + } + if radix < 2 { + return Err(Error::from("Unsupported radix < 2")); + } + if radix > 36 { + // As per trait documentation + return Err(Error::from("Unsupported radix > 36")); + } + + let mut offset = 0; + let mut len = str.len(); + let bytes = str.as_bytes(); + let mut negative = false; // assume positive + + // handle the sign + if bytes[offset] == b'-' { + negative = true; // leading minus means negative + offset += 1; + len -= 1; + } else if bytes[offset] == b'+' { + // leading + allowed + offset += 1; + len -= 1; + } + + // should now be at numeric part of the significand + let mut digits_before_dot: i32 = -1; // digits before '.', -1 if no '.' + let mut coeff = ArrayVec::<_, 96>::new(); // integer significand array + + // Supporting different radix + let (max_n, max_alpha_lower, max_alpha_upper) = if radix <= 10 { + (b'0' + (radix - 1) as u8, 0, 0) + } else { + let adj = (radix - 11) as u8; + (b'9', adj + b'a', adj + b'A') + }; + + // Estimate the max precision. All in all, it needs to fit into 96 bits. + // Rather than try to estimate, I've included the constants directly in here. We could, + // perhaps, replace this with a formula if it's faster - though it does appear to be log2. + let estimated_max_precision = match radix { + 2 => 96, + 3 => 61, + 4 => 48, + 5 => 42, + 6 => 38, + 7 => 35, + 8 => 32, + 9 => 31, + 10 => 28, + 11 => 28, + 12 => 27, + 13 => 26, + 14 => 26, + 15 => 25, + 16 => 24, + 17 => 24, + 18 => 24, + 19 => 23, + 20 => 23, + 21 => 22, + 22 => 22, + 23 => 22, + 24 => 21, + 25 => 21, + 26 => 21, + 27 => 21, + 28 => 20, + 29 => 20, + 30 => 20, + 31 => 20, + 32 => 20, + 33 => 20, + 34 => 19, + 35 => 19, + 36 => 19, + _ => return Err(Error::from("Unsupported radix")), + }; + + let mut maybe_round = false; + while len > 0 { + let b = bytes[offset]; + match b { + b'0'..=b'9' => { + if b > max_n { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'0')); + offset += 1; + len -= 1; + + // If the coefficient is longer than the max, exit early + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'a'..=b'z' => { + if b > max_alpha_lower { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'a') + 10); + offset += 1; + len -= 1; + + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'A'..=b'Z' => { + if b > max_alpha_upper { + return Err(Error::from("Invalid decimal: invalid character")); + } + coeff.push(u32::from(b - b'A') + 10); + offset += 1; + len -= 1; + + if coeff.len() as u32 > estimated_max_precision { + maybe_round = true; + break; + } + } + b'.' => { + if digits_before_dot >= 0 { + return Err(Error::from("Invalid decimal: two decimal points")); + } + digits_before_dot = coeff.len() as i32; + offset += 1; + len -= 1; + } + b'_' => { + // Must start with a number... + if coeff.is_empty() { + return Err(Error::from("Invalid decimal: must start lead with a number")); + } + offset += 1; + len -= 1; + } + _ => return Err(Error::from("Invalid decimal: unknown character")), + } + } + + // If we exited before the end of the string then do some rounding if necessary + if maybe_round && offset < bytes.len() { + let next_byte = bytes[offset]; + let digit = match next_byte { + b'0'..=b'9' => { + if next_byte > max_n { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'0') + } + b'a'..=b'z' => { + if next_byte > max_alpha_lower { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'a') + 10 + } + b'A'..=b'Z' => { + if next_byte > max_alpha_upper { + return Err(Error::from("Invalid decimal: invalid character")); + } + u32::from(next_byte - b'A') + 10 + } + b'_' => 0, + b'.' => { + // Still an error if we have a second dp + if digits_before_dot >= 0 { + return Err(Error::from("Invalid decimal: two decimal points")); + } + 0 + } + _ => return Err(Error::from("Invalid decimal: unknown character")), + }; + + // Round at midpoint + let midpoint = if radix & 0x1 == 1 { radix / 2 } else { (radix + 1) / 2 }; + if digit >= midpoint { + let mut index = coeff.len() - 1; + loop { + let new_digit = coeff[index] + 1; + if new_digit <= 9 { + coeff[index] = new_digit; + break; + } else { + coeff[index] = 0; + if index == 0 { + coeff.insert(0, 1u32); + digits_before_dot += 1; + coeff.pop(); + break; + } + } + index -= 1; + } + } + } + + // here when no characters left + if coeff.is_empty() { + return Err(Error::from("Invalid decimal: no digits found")); + } + + let mut scale = if digits_before_dot >= 0 { + // we had a decimal place so set the scale + (coeff.len() as u32) - (digits_before_dot as u32) + } else { + 0 + }; + + // Parse this using specified radix + let mut data = [0u32, 0u32, 0u32]; + let mut tmp = [0u32, 0u32, 0u32]; + let len = coeff.len(); + for (i, digit) in coeff.iter().enumerate() { + // If the data is going to overflow then we should go into recovery mode + tmp[0] = data[0]; + tmp[1] = data[1]; + tmp[2] = data[2]; + let overflow = mul_by_u32(&mut tmp, radix); + if overflow > 0 { + // This means that we have more data to process, that we're not sure what to do with. + // This may or may not be an issue - depending on whether we're past a decimal point + // or not. + if (i as i32) < digits_before_dot && i + 1 < len { + return Err(Error::from("Invalid decimal: overflow from too many digits")); + } + + if *digit >= 5 { + let carry = add_one_internal(&mut data); + if carry > 0 { + // Highly unlikely scenario which is more indicative of a bug + return Err(Error::from("Invalid decimal: overflow when rounding")); + } + } + // We're also one less digit so reduce the scale + let diff = (len - i) as u32; + if diff > scale { + return Err(Error::from("Invalid decimal: overflow from scale mismatch")); + } + scale -= diff; + break; + } else { + data[0] = tmp[0]; + data[1] = tmp[1]; + data[2] = tmp[2]; + let carry = add_by_internal_flattened(&mut data, *digit); + if carry > 0 { + // Highly unlikely scenario which is more indicative of a bug + return Err(Error::from("Invalid decimal: overflow from carry")); + } + } + } + + Ok(Decimal::from_parts(data[0], data[1], data[2], negative, scale)) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::Decimal; + use arrayvec::ArrayString; + use core::{fmt::Write, str::FromStr}; + + #[test] + fn display_does_not_overflow_max_capacity() { + let num = Decimal::from_str("1.2").unwrap(); + let mut buffer = ArrayString::<64>::new(); + let _ = buffer.write_fmt(format_args!("{:.31}", num)).unwrap(); + assert_eq!("1.2000000000000000000000000000000", buffer.as_str()); + } + + #[test] + fn from_str_rounding_0() { + assert_eq!( + parse_str_radix_10("1.234").unwrap().unpack(), + Decimal::new(1234, 3).unpack() + ); + } + + #[test] + fn from_str_rounding_1() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11111") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_111, 14).unpack() + ); + } + + #[test] + fn from_str_rounding_2() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11115") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_112, 14).unpack() + ); + } + + #[test] + fn from_str_rounding_3() { + assert_eq!( + parse_str_radix_10("11111_11111_11111.11111_11111_11195") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_1120, 14).unpack() // was Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_112, 13) + ); + } + + #[test] + fn from_str_rounding_4() { + assert_eq!( + parse_str_radix_10("99999_99999_99999.99999_99999_99995") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 13).unpack() // was Decimal::from_i128_with_scale(1_000_000_000_000_000_000_000_000_000, 12) + ); + } + + #[test] + fn from_str_no_rounding_0() { + assert_eq!( + parse_str_radix_10_exact("1.234").unwrap().unpack(), + Decimal::new(1234, 3).unpack() + ); + } + + #[test] + fn from_str_no_rounding_1() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11111"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_2() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11115"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_3() { + assert_eq!( + parse_str_radix_10_exact("11111_11111_11111.11111_11111_11195"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_no_rounding_4() { + assert_eq!( + parse_str_radix_10_exact("99999_99999_99999.99999_99999_99995"), + Err(Error::Underflow) + ); + } + + #[test] + fn from_str_many_pointless_chars() { + assert_eq!( + parse_str_radix_10("00________________________________________________________________001.1") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(11, 1).unpack() + ); + } + + #[test] + fn from_str_leading_0s_1() { + assert_eq!( + parse_str_radix_10("00001.1").unwrap().unpack(), + Decimal::from_i128_with_scale(11, 1).unpack() + ); + } + + #[test] + fn from_str_leading_0s_2() { + assert_eq!( + parse_str_radix_10("00000_00000_00000_00000_00001.00001") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(100001, 5).unpack() + ); + } + + #[test] + fn from_str_leading_0s_3() { + assert_eq!( + parse_str_radix_10("0.00000_00000_00000_00000_00000_00100") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(1, 28).unpack() + ); + } + + #[test] + fn from_str_trailing_0s_1() { + assert_eq!( + parse_str_radix_10("0.00001_00000_00000").unwrap().unpack(), + Decimal::from_i128_with_scale(10_000_000_000, 15).unpack() + ); + } + + #[test] + fn from_str_trailing_0s_2() { + assert_eq!( + parse_str_radix_10("0.00001_00000_00000_00000_00000_00000") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(100_000_000_000_000_000_000_000, 28).unpack() + ); + } + + #[test] + fn from_str_overflow_1() { + assert_eq!( + parse_str_radix_10("99999_99999_99999_99999_99999_99999.99999"), + // The original implementation returned + // Ok(10000_00000_00000_00000_00000_0000) + // Which is a bug! + Err(Error::from("Invalid decimal: overflow from too many digits")) + ); + } + + #[test] + fn from_str_overflow_2() { + assert!( + parse_str_radix_10("99999_99999_99999_99999_99999_11111.11111").is_err(), + // The original implementation is 'overflow from scale mismatch' + // but we got rid of that now + ); + } + + #[test] + fn from_str_overflow_3() { + assert!( + parse_str_radix_10("99999_99999_99999_99999_99999_99994").is_err() // We could not get into 'overflow when rounding' or 'overflow from carry' + // in the original implementation because the rounding logic before prevented it + ); + } + + #[test] + fn from_str_overflow_4() { + assert_eq!( + // This does not overflow, moving the decimal point 1 more step would result in + // 'overflow from too many digits' + parse_str_radix_10("99999_99999_99999_99999_99999_999.99") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 0).unpack() + ); + } + + #[test] + fn from_str_mantissa_overflow_1() { + // reminder: + assert_eq!(OVERFLOW_U96, 79_228_162_514_264_337_593_543_950_336); + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_33.56") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 0).unpack() + ); + // This is a mantissa of OVERFLOW_U96 - 1 just before reaching the last digit. + // Previously, this would return Err("overflow from mantissa after rounding") + // instead of successfully rounding. + } + + #[test] + fn from_str_mantissa_overflow_2() { + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.6"), + Err(Error::from("Invalid decimal: overflow from mantissa after rounding")) + ); + // this case wants to round to 79_228_162_514_264_337_593_543_950_340. + // (79_228_162_514_264_337_593_543_950_336 is OVERFLOW_U96 and too large + // to fit in 96 bits) which is also too large for the mantissa so fails. + } + + #[test] + fn from_str_mantissa_overflow_3() { + // this hits the other avoidable overflow case in maybe_round + assert_eq!( + parse_str_radix_10("7.92281625142643375935439503356").unwrap().unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack() + ); + } + + #[ignore] + #[test] + fn from_str_mantissa_overflow_4() { + // Same test as above, however with underscores. This causes issues. + assert_eq!( + parse_str_radix_10("7.9_228_162_514_264_337_593_543_950_335_6") + .unwrap() + .unpack(), + Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack() + ); + } + + #[test] + fn from_str_edge_cases_1() { + assert_eq!(parse_str_radix_10(""), Err(Error::from("Invalid decimal: empty"))); + } + + #[test] + fn from_str_edge_cases_2() { + assert_eq!( + parse_str_radix_10("0.1."), + Err(Error::from("Invalid decimal: two decimal points")) + ); + } + + #[test] + fn from_str_edge_cases_3() { + assert_eq!( + parse_str_radix_10("_"), + Err(Error::from("Invalid decimal: must start lead with a number")) + ); + } + + #[test] + fn from_str_edge_cases_4() { + assert_eq!( + parse_str_radix_10("1?2"), + Err(Error::from("Invalid decimal: unknown character")) + ); + } + + #[test] + fn from_str_edge_cases_5() { + assert_eq!( + parse_str_radix_10("."), + Err(Error::from("Invalid decimal: no digits found")) + ); + } + + #[test] + fn from_str_edge_cases_6() { + // Decimal::MAX + 0.99999 + assert_eq!( + parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.99999"), + Err(Error::from("Invalid decimal: overflow from mantissa after rounding")) + ); + } +} |