summaryrefslogtreecommitdiffstats
path: root/third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs')
-rw-r--r--third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs193
1 files changed, 193 insertions, 0 deletions
diff --git a/third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs b/third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs
new file mode 100644
index 0000000000..e99e6cb5d7
--- /dev/null
+++ b/third_party/rust/packed_simd/src/api/reductions/integer_arithmetic.rs
@@ -0,0 +1,193 @@
+//! Implements portable horizontal integer vector arithmetic reductions.
+
+macro_rules! impl_reduction_integer_arithmetic {
+ ([$elem_ty:ident; $elem_count:expr]: $id:ident | $ielem_ty:ident
+ | $test_tt:tt) => {
+ impl $id {
+ /// Horizontal wrapping sum of the vector elements.
+ ///
+ /// The intrinsic performs a tree-reduction of the vector elements.
+ /// That is, for an 8 element vector:
+ ///
+ /// > ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + (x6 + x7))
+ ///
+ /// If an operation overflows it returns the mathematical result
+ /// modulo `2^n` where `n` is the number of times it overflows.
+ #[inline]
+ pub fn wrapping_sum(self) -> $elem_ty {
+ #[cfg(not(target_arch = "aarch64"))]
+ {
+ use crate::llvm::simd_reduce_add_ordered;
+ let v: $ielem_ty = unsafe { simd_reduce_add_ordered(self.0, 0 as $ielem_ty) };
+ v as $elem_ty
+ }
+ #[cfg(target_arch = "aarch64")]
+ {
+ // FIXME: broken on AArch64
+ // https://github.com/rust-lang-nursery/packed_simd/issues/15
+ let mut x = self.extract(0) as $elem_ty;
+ for i in 1..$id::lanes() {
+ x = x.wrapping_add(self.extract(i) as $elem_ty);
+ }
+ x
+ }
+ }
+
+ /// Horizontal wrapping product of the vector elements.
+ ///
+ /// The intrinsic performs a tree-reduction of the vector elements.
+ /// That is, for an 8 element vector:
+ ///
+ /// > ((x0 * x1) * (x2 * x3)) * ((x4 * x5) * (x6 * x7))
+ ///
+ /// If an operation overflows it returns the mathematical result
+ /// modulo `2^n` where `n` is the number of times it overflows.
+ #[inline]
+ pub fn wrapping_product(self) -> $elem_ty {
+ #[cfg(not(target_arch = "aarch64"))]
+ {
+ use crate::llvm::simd_reduce_mul_ordered;
+ let v: $ielem_ty = unsafe { simd_reduce_mul_ordered(self.0, 1 as $ielem_ty) };
+ v as $elem_ty
+ }
+ #[cfg(target_arch = "aarch64")]
+ {
+ // FIXME: broken on AArch64
+ // https://github.com/rust-lang-nursery/packed_simd/issues/15
+ let mut x = self.extract(0) as $elem_ty;
+ for i in 1..$id::lanes() {
+ x = x.wrapping_mul(self.extract(i) as $elem_ty);
+ }
+ x
+ }
+ }
+ }
+
+ impl crate::iter::Sum for $id {
+ #[inline]
+ fn sum<I: Iterator<Item = $id>>(iter: I) -> $id {
+ iter.fold($id::splat(0), crate::ops::Add::add)
+ }
+ }
+
+ impl crate::iter::Product for $id {
+ #[inline]
+ fn product<I: Iterator<Item = $id>>(iter: I) -> $id {
+ iter.fold($id::splat(1), crate::ops::Mul::mul)
+ }
+ }
+
+ impl<'a> crate::iter::Sum<&'a $id> for $id {
+ #[inline]
+ fn sum<I: Iterator<Item = &'a $id>>(iter: I) -> $id {
+ iter.fold($id::splat(0), |a, b| crate::ops::Add::add(a, *b))
+ }
+ }
+
+ impl<'a> crate::iter::Product<&'a $id> for $id {
+ #[inline]
+ fn product<I: Iterator<Item = &'a $id>>(iter: I) -> $id {
+ iter.fold($id::splat(1), |a, b| crate::ops::Mul::mul(a, *b))
+ }
+ }
+
+ test_if! {
+ $test_tt:
+ paste::item! {
+ pub mod [<$id _reduction_int_arith>] {
+ use super::*;
+
+ fn alternating(x: usize) -> $id {
+ let mut v = $id::splat(1 as $elem_ty);
+ for i in 0..$id::lanes() {
+ if i % x == 0 {
+ v = v.replace(i, 2 as $elem_ty);
+ }
+ }
+ v
+ }
+
+ #[cfg_attr(not(target_arch = "wasm32"), test)]
+ #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
+ fn wrapping_sum() {
+ let v = $id::splat(0 as $elem_ty);
+ assert_eq!(v.wrapping_sum(), 0 as $elem_ty);
+ let v = $id::splat(1 as $elem_ty);
+ assert_eq!(v.wrapping_sum(), $id::lanes() as $elem_ty);
+ let v = alternating(2);
+ if $id::lanes() > 1 {
+ assert_eq!(
+ v.wrapping_sum(),
+ ($id::lanes() / 2 + $id::lanes()) as $elem_ty
+ );
+ } else {
+ assert_eq!(
+ v.wrapping_sum(),
+ 2 as $elem_ty
+ );
+ }
+ }
+ #[cfg_attr(not(target_arch = "wasm32"), test)]
+ #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
+ fn wrapping_sum_overflow() {
+ let start = $elem_ty::max_value()
+ - ($id::lanes() as $elem_ty / 2);
+
+ let v = $id::splat(start as $elem_ty);
+ let vwrapping_sum = v.wrapping_sum();
+
+ let mut wrapping_sum = start;
+ for _ in 1..$id::lanes() {
+ wrapping_sum = wrapping_sum.wrapping_add(start);
+ }
+ assert_eq!(wrapping_sum, vwrapping_sum, "v = {:?}", v);
+ }
+
+ #[cfg_attr(not(target_arch = "wasm32"), test)]
+ #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
+ fn wrapping_product() {
+ let v = $id::splat(0 as $elem_ty);
+ assert_eq!(v.wrapping_product(), 0 as $elem_ty);
+ let v = $id::splat(1 as $elem_ty);
+ assert_eq!(v.wrapping_product(), 1 as $elem_ty);
+ let f = match $id::lanes() {
+ 64 => 16,
+ 32 => 8,
+ 16 => 4,
+ _ => 2,
+ };
+ let v = alternating(f);
+ if $id::lanes() > 1 {
+ assert_eq!(
+ v.wrapping_product(),
+ (2_usize.pow(($id::lanes() / f) as u32)
+ as $elem_ty)
+ );
+ } else {
+ assert_eq!(
+ v.wrapping_product(),
+ 2 as $elem_ty
+ );
+ }
+ }
+
+ #[cfg_attr(not(target_arch = "wasm32"), test)]
+ #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
+ fn wrapping_product_overflow() {
+ let start = $elem_ty::max_value()
+ - ($id::lanes() as $elem_ty / 2);
+
+ let v = $id::splat(start as $elem_ty);
+ let vmul = v.wrapping_product();
+
+ let mut mul = start;
+ for _ in 1..$id::lanes() {
+ mul = mul.wrapping_mul(start);
+ }
+ assert_eq!(mul, vmul, "v = {:?}", v);
+ }
+ }
+ }
+ }
+ };
+}