summaryrefslogtreecommitdiffstats
path: root/third_party/rust/num-integer/src/roots.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/num-integer/src/roots.rs')
-rw-r--r--third_party/rust/num-integer/src/roots.rs391
1 files changed, 391 insertions, 0 deletions
diff --git a/third_party/rust/num-integer/src/roots.rs b/third_party/rust/num-integer/src/roots.rs
new file mode 100644
index 0000000000..a9eec1a93c
--- /dev/null
+++ b/third_party/rust/num-integer/src/roots.rs
@@ -0,0 +1,391 @@
+use core;
+use core::mem;
+use traits::checked_pow;
+use traits::PrimInt;
+use Integer;
+
+/// Provides methods to compute an integer's square root, cube root,
+/// and arbitrary `n`th root.
+pub trait Roots: Integer {
+ /// Returns the truncated principal `n`th root of an integer
+ /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }`
+ ///
+ /// This is solving for `r` in `rⁿ = x`, rounding toward zero.
+ /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`.
+ /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `n` is zero:
+ ///
+ /// ```should_panic
+ /// # use num_integer::Roots;
+ /// println!("can't compute ⁰√x : {}", 123.nth_root(0));
+ /// ```
+ ///
+ /// or if `n` is even and `self` is negative:
+ ///
+ /// ```should_panic
+ /// # use num_integer::Roots;
+ /// println!("no imaginary numbers... {}", (-1).nth_root(10));
+ /// ```
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use num_integer::Roots;
+ ///
+ /// let x: i32 = 12345;
+ /// assert_eq!(x.nth_root(1), x);
+ /// assert_eq!(x.nth_root(2), x.sqrt());
+ /// assert_eq!(x.nth_root(3), x.cbrt());
+ /// assert_eq!(x.nth_root(4), 10);
+ /// assert_eq!(x.nth_root(13), 2);
+ /// assert_eq!(x.nth_root(14), 1);
+ /// assert_eq!(x.nth_root(std::u32::MAX), 1);
+ ///
+ /// assert_eq!(std::i32::MAX.nth_root(30), 2);
+ /// assert_eq!(std::i32::MAX.nth_root(31), 1);
+ /// assert_eq!(std::i32::MIN.nth_root(31), -2);
+ /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1);
+ ///
+ /// assert_eq!(std::u32::MAX.nth_root(31), 2);
+ /// assert_eq!(std::u32::MAX.nth_root(32), 1);
+ /// ```
+ fn nth_root(&self, n: u32) -> Self;
+
+ /// Returns the truncated principal square root of an integer -- `⌊√x⌋`
+ ///
+ /// This is solving for `r` in `r² = x`, rounding toward zero.
+ /// The result will satisfy `r² ≤ x < (r+1)²`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `self` is less than zero:
+ ///
+ /// ```should_panic
+ /// # use num_integer::Roots;
+ /// println!("no imaginary numbers... {}", (-1).sqrt());
+ /// ```
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use num_integer::Roots;
+ ///
+ /// let x: i32 = 12345;
+ /// assert_eq!((x * x).sqrt(), x);
+ /// assert_eq!((x * x + 1).sqrt(), x);
+ /// assert_eq!((x * x - 1).sqrt(), x - 1);
+ /// ```
+ #[inline]
+ fn sqrt(&self) -> Self {
+ self.nth_root(2)
+ }
+
+ /// Returns the truncated principal cube root of an integer --
+ /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }`
+ ///
+ /// This is solving for `r` in `r³ = x`, rounding toward zero.
+ /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`.
+ /// If `x` is negative, then `(r-1)³ < x ≤ r³`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use num_integer::Roots;
+ ///
+ /// let x: i32 = 1234;
+ /// assert_eq!((x * x * x).cbrt(), x);
+ /// assert_eq!((x * x * x + 1).cbrt(), x);
+ /// assert_eq!((x * x * x - 1).cbrt(), x - 1);
+ ///
+ /// assert_eq!((-(x * x * x)).cbrt(), -x);
+ /// assert_eq!((-(x * x * x + 1)).cbrt(), -x);
+ /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1));
+ /// ```
+ #[inline]
+ fn cbrt(&self) -> Self {
+ self.nth_root(3)
+ }
+}
+
+/// Returns the truncated principal square root of an integer --
+/// see [Roots::sqrt](trait.Roots.html#method.sqrt).
+#[inline]
+pub fn sqrt<T: Roots>(x: T) -> T {
+ x.sqrt()
+}
+
+/// Returns the truncated principal cube root of an integer --
+/// see [Roots::cbrt](trait.Roots.html#method.cbrt).
+#[inline]
+pub fn cbrt<T: Roots>(x: T) -> T {
+ x.cbrt()
+}
+
+/// Returns the truncated principal `n`th root of an integer --
+/// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root).
+#[inline]
+pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
+ x.nth_root(n)
+}
+
+macro_rules! signed_roots {
+ ($T:ty, $U:ty) => {
+ impl Roots for $T {
+ #[inline]
+ fn nth_root(&self, n: u32) -> Self {
+ if *self >= 0 {
+ (*self as $U).nth_root(n) as Self
+ } else {
+ assert!(n.is_odd(), "even roots of a negative are imaginary");
+ -((self.wrapping_neg() as $U).nth_root(n) as Self)
+ }
+ }
+
+ #[inline]
+ fn sqrt(&self) -> Self {
+ assert!(*self >= 0, "the square root of a negative is imaginary");
+ (*self as $U).sqrt() as Self
+ }
+
+ #[inline]
+ fn cbrt(&self) -> Self {
+ if *self >= 0 {
+ (*self as $U).cbrt() as Self
+ } else {
+ -((self.wrapping_neg() as $U).cbrt() as Self)
+ }
+ }
+ }
+ };
+}
+
+signed_roots!(i8, u8);
+signed_roots!(i16, u16);
+signed_roots!(i32, u32);
+signed_roots!(i64, u64);
+#[cfg(has_i128)]
+signed_roots!(i128, u128);
+signed_roots!(isize, usize);
+
+#[inline]
+fn fixpoint<T, F>(mut x: T, f: F) -> T
+where
+ T: Integer + Copy,
+ F: Fn(T) -> T,
+{
+ let mut xn = f(x);
+ while x < xn {
+ x = xn;
+ xn = f(x);
+ }
+ while x > xn {
+ x = xn;
+ xn = f(x);
+ }
+ x
+}
+
+#[inline]
+fn bits<T>() -> u32 {
+ 8 * mem::size_of::<T>() as u32
+}
+
+#[inline]
+fn log2<T: PrimInt>(x: T) -> u32 {
+ debug_assert!(x > T::zero());
+ bits::<T>() - 1 - x.leading_zeros()
+}
+
+macro_rules! unsigned_roots {
+ ($T:ident) => {
+ impl Roots for $T {
+ #[inline]
+ fn nth_root(&self, n: u32) -> Self {
+ fn go(a: $T, n: u32) -> $T {
+ // Specialize small roots
+ match n {
+ 0 => panic!("can't find a root of degree 0!"),
+ 1 => return a,
+ 2 => return a.sqrt(),
+ 3 => return a.cbrt(),
+ _ => (),
+ }
+
+ // The root of values less than 2ⁿ can only be 0 or 1.
+ if bits::<$T>() <= n || a < (1 << n) {
+ return (a > 0) as $T;
+ }
+
+ if bits::<$T>() > 64 {
+ // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
+ return if a <= core::u64::MAX as $T {
+ (a as u64).nth_root(n) as $T
+ } else {
+ let lo = (a >> n).nth_root(n) << 1;
+ let hi = lo + 1;
+ // 128-bit `checked_mul` also involves division, but we can't always
+ // compute `hiⁿ` without risking overflow. Try to avoid it though...
+ if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
+ match checked_pow(hi, n as usize) {
+ Some(x) if x <= a => hi,
+ _ => lo,
+ }
+ } else {
+ if hi.pow(n) <= a {
+ hi
+ } else {
+ lo
+ }
+ }
+ };
+ }
+
+ #[cfg(feature = "std")]
+ #[inline]
+ fn guess(x: $T, n: u32) -> $T {
+ // for smaller inputs, `f64` doesn't justify its cost.
+ if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
+ 1 << ((log2(x) + n - 1) / n)
+ } else {
+ ((x as f64).ln() / f64::from(n)).exp() as $T
+ }
+ }
+
+ #[cfg(not(feature = "std"))]
+ #[inline]
+ fn guess(x: $T, n: u32) -> $T {
+ 1 << ((log2(x) + n - 1) / n)
+ }
+
+ // https://en.wikipedia.org/wiki/Nth_root_algorithm
+ let n1 = n - 1;
+ let next = |x: $T| {
+ let y = match checked_pow(x, n1 as usize) {
+ Some(ax) => a / ax,
+ None => 0,
+ };
+ (y + x * n1 as $T) / n as $T
+ };
+ fixpoint(guess(a, n), next)
+ }
+ go(*self, n)
+ }
+
+ #[inline]
+ fn sqrt(&self) -> Self {
+ fn go(a: $T) -> $T {
+ if bits::<$T>() > 64 {
+ // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
+ return if a <= core::u64::MAX as $T {
+ (a as u64).sqrt() as $T
+ } else {
+ let lo = (a >> 2u32).sqrt() << 1;
+ let hi = lo + 1;
+ if hi * hi <= a {
+ hi
+ } else {
+ lo
+ }
+ };
+ }
+
+ if a < 4 {
+ return (a > 0) as $T;
+ }
+
+ #[cfg(feature = "std")]
+ #[inline]
+ fn guess(x: $T) -> $T {
+ (x as f64).sqrt() as $T
+ }
+
+ #[cfg(not(feature = "std"))]
+ #[inline]
+ fn guess(x: $T) -> $T {
+ 1 << ((log2(x) + 1) / 2)
+ }
+
+ // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
+ let next = |x: $T| (a / x + x) >> 1;
+ fixpoint(guess(a), next)
+ }
+ go(*self)
+ }
+
+ #[inline]
+ fn cbrt(&self) -> Self {
+ fn go(a: $T) -> $T {
+ if bits::<$T>() > 64 {
+ // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
+ return if a <= core::u64::MAX as $T {
+ (a as u64).cbrt() as $T
+ } else {
+ let lo = (a >> 3u32).cbrt() << 1;
+ let hi = lo + 1;
+ if hi * hi * hi <= a {
+ hi
+ } else {
+ lo
+ }
+ };
+ }
+
+ if bits::<$T>() <= 32 {
+ // Implementation based on Hacker's Delight `icbrt2`
+ let mut x = a;
+ let mut y2 = 0;
+ let mut y = 0;
+ let smax = bits::<$T>() / 3;
+ for s in (0..smax + 1).rev() {
+ let s = s * 3;
+ y2 *= 4;
+ y *= 2;
+ let b = 3 * (y2 + y) + 1;
+ if x >> s >= b {
+ x -= b << s;
+ y2 += 2 * y + 1;
+ y += 1;
+ }
+ }
+ return y;
+ }
+
+ if a < 8 {
+ return (a > 0) as $T;
+ }
+ if a <= core::u32::MAX as $T {
+ return (a as u32).cbrt() as $T;
+ }
+
+ #[cfg(feature = "std")]
+ #[inline]
+ fn guess(x: $T) -> $T {
+ (x as f64).cbrt() as $T
+ }
+
+ #[cfg(not(feature = "std"))]
+ #[inline]
+ fn guess(x: $T) -> $T {
+ 1 << ((log2(x) + 2) / 3)
+ }
+
+ // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
+ let next = |x: $T| (a / (x * x) + x * 2) / 3;
+ fixpoint(guess(a), next)
+ }
+ go(*self)
+ }
+ }
+ };
+}
+
+unsigned_roots!(u8);
+unsigned_roots!(u16);
+unsigned_roots!(u32);
+unsigned_roots!(u64);
+#[cfg(has_i128)]
+unsigned_roots!(u128);
+unsigned_roots!(usize);