diff options
Diffstat (limited to 'third_party/rust/num-integer/src/roots.rs')
-rw-r--r-- | third_party/rust/num-integer/src/roots.rs | 391 |
1 files changed, 391 insertions, 0 deletions
diff --git a/third_party/rust/num-integer/src/roots.rs b/third_party/rust/num-integer/src/roots.rs new file mode 100644 index 0000000000..a9eec1a93c --- /dev/null +++ b/third_party/rust/num-integer/src/roots.rs @@ -0,0 +1,391 @@ +use core; +use core::mem; +use traits::checked_pow; +use traits::PrimInt; +use Integer; + +/// Provides methods to compute an integer's square root, cube root, +/// and arbitrary `n`th root. +pub trait Roots: Integer { + /// Returns the truncated principal `n`th root of an integer + /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }` + /// + /// This is solving for `r` in `rⁿ = x`, rounding toward zero. + /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`. + /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`. + /// + /// # Panics + /// + /// Panics if `n` is zero: + /// + /// ```should_panic + /// # use num_integer::Roots; + /// println!("can't compute ⁰√x : {}", 123.nth_root(0)); + /// ``` + /// + /// or if `n` is even and `self` is negative: + /// + /// ```should_panic + /// # use num_integer::Roots; + /// println!("no imaginary numbers... {}", (-1).nth_root(10)); + /// ``` + /// + /// # Examples + /// + /// ``` + /// use num_integer::Roots; + /// + /// let x: i32 = 12345; + /// assert_eq!(x.nth_root(1), x); + /// assert_eq!(x.nth_root(2), x.sqrt()); + /// assert_eq!(x.nth_root(3), x.cbrt()); + /// assert_eq!(x.nth_root(4), 10); + /// assert_eq!(x.nth_root(13), 2); + /// assert_eq!(x.nth_root(14), 1); + /// assert_eq!(x.nth_root(std::u32::MAX), 1); + /// + /// assert_eq!(std::i32::MAX.nth_root(30), 2); + /// assert_eq!(std::i32::MAX.nth_root(31), 1); + /// assert_eq!(std::i32::MIN.nth_root(31), -2); + /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1); + /// + /// assert_eq!(std::u32::MAX.nth_root(31), 2); + /// assert_eq!(std::u32::MAX.nth_root(32), 1); + /// ``` + fn nth_root(&self, n: u32) -> Self; + + /// Returns the truncated principal square root of an integer -- `⌊√x⌋` + /// + /// This is solving for `r` in `r² = x`, rounding toward zero. + /// The result will satisfy `r² ≤ x < (r+1)²`. + /// + /// # Panics + /// + /// Panics if `self` is less than zero: + /// + /// ```should_panic + /// # use num_integer::Roots; + /// println!("no imaginary numbers... {}", (-1).sqrt()); + /// ``` + /// + /// # Examples + /// + /// ``` + /// use num_integer::Roots; + /// + /// let x: i32 = 12345; + /// assert_eq!((x * x).sqrt(), x); + /// assert_eq!((x * x + 1).sqrt(), x); + /// assert_eq!((x * x - 1).sqrt(), x - 1); + /// ``` + #[inline] + fn sqrt(&self) -> Self { + self.nth_root(2) + } + + /// Returns the truncated principal cube root of an integer -- + /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }` + /// + /// This is solving for `r` in `r³ = x`, rounding toward zero. + /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`. + /// If `x` is negative, then `(r-1)³ < x ≤ r³`. + /// + /// # Examples + /// + /// ``` + /// use num_integer::Roots; + /// + /// let x: i32 = 1234; + /// assert_eq!((x * x * x).cbrt(), x); + /// assert_eq!((x * x * x + 1).cbrt(), x); + /// assert_eq!((x * x * x - 1).cbrt(), x - 1); + /// + /// assert_eq!((-(x * x * x)).cbrt(), -x); + /// assert_eq!((-(x * x * x + 1)).cbrt(), -x); + /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1)); + /// ``` + #[inline] + fn cbrt(&self) -> Self { + self.nth_root(3) + } +} + +/// Returns the truncated principal square root of an integer -- +/// see [Roots::sqrt](trait.Roots.html#method.sqrt). +#[inline] +pub fn sqrt<T: Roots>(x: T) -> T { + x.sqrt() +} + +/// Returns the truncated principal cube root of an integer -- +/// see [Roots::cbrt](trait.Roots.html#method.cbrt). +#[inline] +pub fn cbrt<T: Roots>(x: T) -> T { + x.cbrt() +} + +/// Returns the truncated principal `n`th root of an integer -- +/// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root). +#[inline] +pub fn nth_root<T: Roots>(x: T, n: u32) -> T { + x.nth_root(n) +} + +macro_rules! signed_roots { + ($T:ty, $U:ty) => { + impl Roots for $T { + #[inline] + fn nth_root(&self, n: u32) -> Self { + if *self >= 0 { + (*self as $U).nth_root(n) as Self + } else { + assert!(n.is_odd(), "even roots of a negative are imaginary"); + -((self.wrapping_neg() as $U).nth_root(n) as Self) + } + } + + #[inline] + fn sqrt(&self) -> Self { + assert!(*self >= 0, "the square root of a negative is imaginary"); + (*self as $U).sqrt() as Self + } + + #[inline] + fn cbrt(&self) -> Self { + if *self >= 0 { + (*self as $U).cbrt() as Self + } else { + -((self.wrapping_neg() as $U).cbrt() as Self) + } + } + } + }; +} + +signed_roots!(i8, u8); +signed_roots!(i16, u16); +signed_roots!(i32, u32); +signed_roots!(i64, u64); +#[cfg(has_i128)] +signed_roots!(i128, u128); +signed_roots!(isize, usize); + +#[inline] +fn fixpoint<T, F>(mut x: T, f: F) -> T +where + T: Integer + Copy, + F: Fn(T) -> T, +{ + let mut xn = f(x); + while x < xn { + x = xn; + xn = f(x); + } + while x > xn { + x = xn; + xn = f(x); + } + x +} + +#[inline] +fn bits<T>() -> u32 { + 8 * mem::size_of::<T>() as u32 +} + +#[inline] +fn log2<T: PrimInt>(x: T) -> u32 { + debug_assert!(x > T::zero()); + bits::<T>() - 1 - x.leading_zeros() +} + +macro_rules! unsigned_roots { + ($T:ident) => { + impl Roots for $T { + #[inline] + fn nth_root(&self, n: u32) -> Self { + fn go(a: $T, n: u32) -> $T { + // Specialize small roots + match n { + 0 => panic!("can't find a root of degree 0!"), + 1 => return a, + 2 => return a.sqrt(), + 3 => return a.cbrt(), + _ => (), + } + + // The root of values less than 2ⁿ can only be 0 or 1. + if bits::<$T>() <= n || a < (1 << n) { + return (a > 0) as $T; + } + + if bits::<$T>() > 64 { + // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough. + return if a <= core::u64::MAX as $T { + (a as u64).nth_root(n) as $T + } else { + let lo = (a >> n).nth_root(n) << 1; + let hi = lo + 1; + // 128-bit `checked_mul` also involves division, but we can't always + // compute `hiⁿ` without risking overflow. Try to avoid it though... + if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() { + match checked_pow(hi, n as usize) { + Some(x) if x <= a => hi, + _ => lo, + } + } else { + if hi.pow(n) <= a { + hi + } else { + lo + } + } + }; + } + + #[cfg(feature = "std")] + #[inline] + fn guess(x: $T, n: u32) -> $T { + // for smaller inputs, `f64` doesn't justify its cost. + if bits::<$T>() <= 32 || x <= core::u32::MAX as $T { + 1 << ((log2(x) + n - 1) / n) + } else { + ((x as f64).ln() / f64::from(n)).exp() as $T + } + } + + #[cfg(not(feature = "std"))] + #[inline] + fn guess(x: $T, n: u32) -> $T { + 1 << ((log2(x) + n - 1) / n) + } + + // https://en.wikipedia.org/wiki/Nth_root_algorithm + let n1 = n - 1; + let next = |x: $T| { + let y = match checked_pow(x, n1 as usize) { + Some(ax) => a / ax, + None => 0, + }; + (y + x * n1 as $T) / n as $T + }; + fixpoint(guess(a, n), next) + } + go(*self, n) + } + + #[inline] + fn sqrt(&self) -> Self { + fn go(a: $T) -> $T { + if bits::<$T>() > 64 { + // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough. + return if a <= core::u64::MAX as $T { + (a as u64).sqrt() as $T + } else { + let lo = (a >> 2u32).sqrt() << 1; + let hi = lo + 1; + if hi * hi <= a { + hi + } else { + lo + } + }; + } + + if a < 4 { + return (a > 0) as $T; + } + + #[cfg(feature = "std")] + #[inline] + fn guess(x: $T) -> $T { + (x as f64).sqrt() as $T + } + + #[cfg(not(feature = "std"))] + #[inline] + fn guess(x: $T) -> $T { + 1 << ((log2(x) + 1) / 2) + } + + // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method + let next = |x: $T| (a / x + x) >> 1; + fixpoint(guess(a), next) + } + go(*self) + } + + #[inline] + fn cbrt(&self) -> Self { + fn go(a: $T) -> $T { + if bits::<$T>() > 64 { + // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough. + return if a <= core::u64::MAX as $T { + (a as u64).cbrt() as $T + } else { + let lo = (a >> 3u32).cbrt() << 1; + let hi = lo + 1; + if hi * hi * hi <= a { + hi + } else { + lo + } + }; + } + + if bits::<$T>() <= 32 { + // Implementation based on Hacker's Delight `icbrt2` + let mut x = a; + let mut y2 = 0; + let mut y = 0; + let smax = bits::<$T>() / 3; + for s in (0..smax + 1).rev() { + let s = s * 3; + y2 *= 4; + y *= 2; + let b = 3 * (y2 + y) + 1; + if x >> s >= b { + x -= b << s; + y2 += 2 * y + 1; + y += 1; + } + } + return y; + } + + if a < 8 { + return (a > 0) as $T; + } + if a <= core::u32::MAX as $T { + return (a as u32).cbrt() as $T; + } + + #[cfg(feature = "std")] + #[inline] + fn guess(x: $T) -> $T { + (x as f64).cbrt() as $T + } + + #[cfg(not(feature = "std"))] + #[inline] + fn guess(x: $T) -> $T { + 1 << ((log2(x) + 2) / 3) + } + + // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods + let next = |x: $T| (a / (x * x) + x * 2) / 3; + fixpoint(guess(a), next) + } + go(*self) + } + } + }; +} + +unsigned_roots!(u8); +unsigned_roots!(u16); +unsigned_roots!(u32); +unsigned_roots!(u64); +#[cfg(has_i128)] +unsigned_roots!(u128); +unsigned_roots!(usize); |