diff options
Diffstat (limited to 'library/portable-simd/crates/core_simd/src/masks/full_masks.rs')
-rw-r--r-- | library/portable-simd/crates/core_simd/src/masks/full_masks.rs | 185 |
1 files changed, 105 insertions, 80 deletions
diff --git a/library/portable-simd/crates/core_simd/src/masks/full_masks.rs b/library/portable-simd/crates/core_simd/src/masks/full_masks.rs index 1d13c45b8..63964f455 100644 --- a/library/portable-simd/crates/core_simd/src/masks/full_masks.rs +++ b/library/portable-simd/crates/core_simd/src/masks/full_masks.rs @@ -1,29 +1,25 @@ //! Masks that take up full SIMD vector registers. -use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; - -#[cfg(feature = "generic_const_exprs")] -use crate::simd::ToBitMaskArray; +use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount}; #[repr(transparent)] -pub struct Mask<T, const LANES: usize>(Simd<T, LANES>) +pub struct Mask<T, const N: usize>(Simd<T, N>) where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount; + LaneCount<N>: SupportedLaneCount; -impl<T, const LANES: usize> Copy for Mask<T, LANES> +impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { } -impl<T, const LANES: usize> Clone for Mask<T, LANES> +impl<T, const N: usize> Clone for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] @@ -32,10 +28,10 @@ where } } -impl<T, const LANES: usize> PartialEq for Mask<T, LANES> +impl<T, const N: usize> PartialEq for Mask<T, N> where T: MaskElement + PartialEq, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] fn eq(&self, other: &Self) -> bool { @@ -43,10 +39,10 @@ where } } -impl<T, const LANES: usize> PartialOrd for Mask<T, LANES> +impl<T, const N: usize> PartialOrd for Mask<T, N> where T: MaskElement + PartialOrd, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { @@ -54,17 +50,17 @@ where } } -impl<T, const LANES: usize> Eq for Mask<T, LANES> +impl<T, const N: usize> Eq for Mask<T, N> where T: MaskElement + Eq, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { } -impl<T, const LANES: usize> Ord for Mask<T, LANES> +impl<T, const N: usize> Ord for Mask<T, N> where T: MaskElement + Ord, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] fn cmp(&self, other: &Self) -> core::cmp::Ordering { @@ -101,10 +97,10 @@ macro_rules! impl_reverse_bits { impl_reverse_bits! { u8, u16, u32, u64 } -impl<T, const LANES: usize> Mask<T, LANES> +impl<T, const N: usize> Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] @@ -125,19 +121,19 @@ where #[inline] #[must_use = "method returns a new vector and does not mutate the original value"] - pub fn to_int(self) -> Simd<T, LANES> { + pub fn to_int(self) -> Simd<T, N> { self.0 } #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self { + pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self { Self(value) } #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn convert<U>(self) -> Mask<U, LANES> + pub fn convert<U>(self) -> Mask<U, N> where U: MaskElement, { @@ -145,62 +141,50 @@ where unsafe { Mask(intrinsics::simd_cast(self.0)) } } - #[cfg(feature = "generic_const_exprs")] #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask_array<const N: usize>(self) -> [u8; N] - where - super::Mask<T, LANES>: ToBitMaskArray, - [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized, - { - assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N); + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn to_bitmask_vector(self) -> Simd<u8, N> { + let mut bitmask = Simd::splat(0); - // Safety: N is the correct bitmask size + // Safety: Bytes is the right size array unsafe { // Compute the bitmask - let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] = + let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask = intrinsics::simd_bitmask(self.0); - // Transmute to the return type, previously asserted to be the same size - let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask); - // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); + for x in bytes.as_mut() { + *x = x.reverse_bits() } - }; + } - bitmask + bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref()); } + + bitmask } - #[cfg(feature = "generic_const_exprs")] #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self - where - super::Mask<T, LANES>: ToBitMaskArray, - [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized, - { - assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N); + pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self { + let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default(); - // Safety: N is the correct bitmask size + // Safety: Bytes is the right size array unsafe { + let len = bytes.as_ref().len(); + bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]); + // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { + for x in bytes.as_mut() { *x = x.reverse_bits(); } } - // Transmute to the bitmask type, previously asserted to be the same size - let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] = - core::mem::transmute_copy(&bitmask); - // Compute the regular mask Self::from_int_unchecked(intrinsics::simd_select_bitmask( - bitmask, + bytes, Self::splat(true).to_int(), Self::splat(false).to_int(), )) @@ -208,40 +192,81 @@ where } #[inline] - pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U + unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U where - super::Mask<T, LANES>: ToBitMask<BitMask = U>, + LaneCount<M>: SupportedLaneCount, { - // Safety: U is required to be the appropriate bitmask type - let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) }; + let resized = self.to_int().resize::<M>(T::FALSE); + + // Safety: `resized` is an integer vector with length M, which must match T + let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) }; // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - bitmask.reverse_bits(LANES) + bitmask.reverse_bits(M) } else { bitmask } } #[inline] - pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self + unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self where - super::Mask<T, LANES>: ToBitMask<BitMask = U>, + LaneCount<M>: SupportedLaneCount, { // LLVM assumes bit order should match endianness let bitmask = if cfg!(target_endian = "big") { - bitmask.reverse_bits(LANES) + bitmask.reverse_bits(M) } else { bitmask }; - // Safety: U is required to be the appropriate bitmask type - unsafe { - Self::from_int_unchecked(intrinsics::simd_select_bitmask( + // SAFETY: `mask` is the correct bitmask type for a u64 bitmask + let mask: Simd<T, M> = unsafe { + intrinsics::simd_select_bitmask( bitmask, - Self::splat(true).to_int(), - Self::splat(false).to_int(), - )) + Simd::<T, M>::splat(T::TRUE), + Simd::<T, M>::splat(T::FALSE), + ) + }; + + // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` + unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) } + } + + #[inline] + pub(crate) fn to_bitmask_integer(self) -> u64 { + // TODO modify simd_bitmask to zero-extend output, making this unnecessary + if N <= 8 { + // Safety: bitmask matches length + unsafe { self.to_bitmask_impl::<u8, 8>() as u64 } + } else if N <= 16 { + // Safety: bitmask matches length + unsafe { self.to_bitmask_impl::<u16, 16>() as u64 } + } else if N <= 32 { + // Safety: bitmask matches length + unsafe { self.to_bitmask_impl::<u32, 32>() as u64 } + } else { + // Safety: bitmask matches length + unsafe { self.to_bitmask_impl::<u64, 64>() } + } + } + + #[inline] + pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { + // TODO modify simd_bitmask_select to truncate input, making this unnecessary + if N <= 8 { + // Safety: bitmask matches length + unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) } + } else if N <= 16 { + // Safety: bitmask matches length + unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) } + } else if N <= 32 { + // Safety: bitmask matches length + unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) } + } else { + // Safety: bitmask matches length + unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) } } } @@ -260,21 +285,21 @@ where } } -impl<T, const LANES: usize> From<Mask<T, LANES>> for Simd<T, LANES> +impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { #[inline] - fn from(value: Mask<T, LANES>) -> Self { + fn from(value: Mask<T, N>) -> Self { value.0 } } -impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES> +impl<T, const N: usize> core::ops::BitAnd for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { type Output = Self; #[inline] @@ -285,10 +310,10 @@ where } } -impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES> +impl<T, const N: usize> core::ops::BitOr for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { type Output = Self; #[inline] @@ -299,10 +324,10 @@ where } } -impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES> +impl<T, const N: usize> core::ops::BitXor for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { type Output = Self; #[inline] @@ -313,10 +338,10 @@ where } } -impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES> +impl<T, const N: usize> core::ops::Not for Mask<T, N> where T: MaskElement, - LaneCount<LANES>: SupportedLaneCount, + LaneCount<N>: SupportedLaneCount, { type Output = Self; #[inline] |