summaryrefslogtreecommitdiffstats
path: root/third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs')
-rw-r--r--third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs216
1 files changed, 216 insertions, 0 deletions
diff --git a/third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs b/third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs
new file mode 100644
index 0000000000..374482ac31
--- /dev/null
+++ b/third_party/rust/packed_simd/src/api/ptr/gather_scatter.rs
@@ -0,0 +1,216 @@
+//! Implements masked gather and scatters for vectors of pointers
+
+macro_rules! impl_ptr_read {
+ ([$elem_ty:ty; $elem_count:expr]: $id:ident, $mask_ty:ident
+ | $test_tt:tt) => {
+ impl<T> $id<T>
+ where
+ [T; $elem_count]: sealed::SimdArray,
+ {
+ /// Reads selected vector elements from memory.
+ ///
+ /// Instantiates a new vector by reading the values from `self` for
+ /// those lanes whose `mask` is `true`, and using the elements of
+ /// `value` otherwise.
+ ///
+ /// No memory is accessed for those lanes of `self` whose `mask` is
+ /// `false`.
+ ///
+ /// # Safety
+ ///
+ /// This method is unsafe because it dereferences raw pointers. The
+ /// pointers must be aligned to `mem::align_of::<T>()`.
+ #[inline]
+ pub unsafe fn read<M>(
+ self,
+ mask: Simd<[M; $elem_count]>,
+ value: Simd<[T; $elem_count]>,
+ ) -> Simd<[T; $elem_count]>
+ where
+ M: sealed::Mask,
+ [M; $elem_count]: sealed::SimdArray,
+ {
+ use crate::llvm::simd_gather;
+ Simd(simd_gather(value.0, self.0, mask.0))
+ }
+ }
+
+ test_if! {
+ $test_tt:
+ paste::item! {
+ mod [<$id _read>] {
+ use super::*;
+ #[test]
+ fn read() {
+ let mut v = [0_i32; $elem_count];
+ for i in 0..$elem_count {
+ v[i] = i as i32;
+ }
+
+ let mut ptr = $id::<i32>::null();
+
+ for i in 0..$elem_count {
+ ptr = ptr.replace(i,
+ &v[i] as *const i32 as *mut i32
+ );
+ }
+
+ // all mask elements are true:
+ let mask = $mask_ty::splat(true);
+ let def = Simd::<[i32; $elem_count]>::splat(42_i32);
+ let r: Simd<[i32; $elem_count]> = unsafe {
+ ptr.read(mask, def)
+ };
+ assert_eq!(
+ r,
+ Simd::<[i32; $elem_count]>::from_slice_unaligned(
+ &v
+ )
+ );
+
+ let mut mask = mask;
+ for i in 0..$elem_count {
+ if i % 2 != 0 {
+ mask = mask.replace(i, false);
+ }
+ }
+
+ // even mask elements are true, odd ones are false:
+ let r: Simd<[i32; $elem_count]> = unsafe {
+ ptr.read(mask, def)
+ };
+ let mut e = v;
+ for i in 0..$elem_count {
+ if i % 2 != 0 {
+ e[i] = 42;
+ }
+ }
+ assert_eq!(
+ r,
+ Simd::<[i32; $elem_count]>::from_slice_unaligned(
+ &e
+ )
+ );
+
+ // all mask elements are false:
+ let mask = $mask_ty::splat(false);
+ let def = Simd::<[i32; $elem_count]>::splat(42_i32);
+ let r: Simd<[i32; $elem_count]> = unsafe {
+ ptr.read(mask, def) }
+ ;
+ assert_eq!(r, def);
+ }
+ }
+ }
+ }
+ };
+}
+
+macro_rules! impl_ptr_write {
+ ([$elem_ty:ty; $elem_count:expr]: $id:ident, $mask_ty:ident
+ | $test_tt:tt) => {
+ impl<T> $id<T>
+ where
+ [T; $elem_count]: sealed::SimdArray,
+ {
+ /// Writes selected vector elements to memory.
+ ///
+ /// Writes the lanes of `values` for which the mask is `true` to
+ /// their corresponding memory addresses in `self`.
+ ///
+ /// No memory is accessed for those lanes of `self` whose `mask` is
+ /// `false`.
+ ///
+ /// Overlapping memory addresses of `self` are written to in order
+ /// from the lest-significant to the most-significant element.
+ ///
+ /// # Safety
+ ///
+ /// This method is unsafe because it dereferences raw pointers. The
+ /// pointers must be aligned to `mem::align_of::<T>()`.
+ #[inline]
+ pub unsafe fn write<M>(self, mask: Simd<[M; $elem_count]>, value: Simd<[T; $elem_count]>)
+ where
+ M: sealed::Mask,
+ [M; $elem_count]: sealed::SimdArray,
+ {
+ use crate::llvm::simd_scatter;
+ simd_scatter(value.0, self.0, mask.0)
+ }
+ }
+
+ test_if! {
+ $test_tt:
+ paste::item! {
+ mod [<$id _write>] {
+ use super::*;
+ #[test]
+ fn write() {
+ // forty_two = [42, 42, 42, ...]
+ let forty_two
+ = Simd::<[i32; $elem_count]>::splat(42_i32);
+
+ // This test will write to this array
+ let mut arr = [0_i32; $elem_count];
+ for i in 0..$elem_count {
+ arr[i] = i as i32;
+ }
+ // arr = [0, 1, 2, ...]
+
+ let mut ptr = $id::<i32>::null();
+ for i in 0..$elem_count {
+ ptr = ptr.replace(i, unsafe {
+ arr.as_ptr().add(i) as *mut i32
+ });
+ }
+ // ptr = [&arr[0], &arr[1], ...]
+
+ // write `forty_two` to all elements of `v`
+ {
+ let backup = arr;
+ unsafe {
+ ptr.write($mask_ty::splat(true), forty_two)
+ };
+ assert_eq!(arr, [42_i32; $elem_count]);
+ arr = backup; // arr = [0, 1, 2, ...]
+ }
+
+ // write 42 to even elements of arr:
+ {
+ // set odd elements of the mask to false
+ let mut mask = $mask_ty::splat(true);
+ for i in 0..$elem_count {
+ if i % 2 != 0 {
+ mask = mask.replace(i, false);
+ }
+ }
+ // mask = [true, false, true, false, ...]
+
+ // expected result r = [42, 1, 42, 3, 42, 5, ...]
+ let mut r = arr;
+ for i in 0..$elem_count {
+ if i % 2 == 0 {
+ r[i] = 42;
+ }
+ }
+
+ let backup = arr;
+ unsafe { ptr.write(mask, forty_two) };
+ assert_eq!(arr, r);
+ arr = backup; // arr = [0, 1, 2, 3, ...]
+ }
+
+ // write 42 to no elements of arr
+ {
+ let backup = arr;
+ unsafe {
+ ptr.write($mask_ty::splat(false), forty_two)
+ };
+ assert_eq!(arr, backup);
+ }
+ }
+ }
+ }
+ }
+ };
+}