diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-17 12:02:58 +0000 |
commit | 698f8c2f01ea549d77d7dc3338a12e04c11057b9 (patch) | |
tree | 173a775858bd501c378080a10dca74132f05bc50 /vendor/bytecount/src | |
parent | Initial commit. (diff) | |
download | rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.tar.xz rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.zip |
Adding upstream version 1.64.0+dfsg1.upstream/1.64.0+dfsg1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'vendor/bytecount/src')
-rw-r--r-- | vendor/bytecount/src/integer_simd.rs | 111 | ||||
-rw-r--r-- | vendor/bytecount/src/lib.rs | 138 | ||||
-rw-r--r-- | vendor/bytecount/src/naive.rs | 42 | ||||
-rw-r--r-- | vendor/bytecount/src/simd/generic.rs | 137 | ||||
-rw-r--r-- | vendor/bytecount/src/simd/mod.rs | 17 | ||||
-rw-r--r-- | vendor/bytecount/src/simd/x86_avx2.rs | 161 | ||||
-rw-r--r-- | vendor/bytecount/src/simd/x86_sse2.rs | 171 |
7 files changed, 777 insertions, 0 deletions
diff --git a/vendor/bytecount/src/integer_simd.rs b/vendor/bytecount/src/integer_simd.rs new file mode 100644 index 000000000..48f2ee8d9 --- /dev/null +++ b/vendor/bytecount/src/integer_simd.rs @@ -0,0 +1,111 @@ +#[cfg(not(feature = "runtime-dispatch-simd"))] +use core::{mem, ptr, usize}; +#[cfg(feature = "runtime-dispatch-simd")] +use std::{mem, ptr, usize}; + +fn splat(byte: u8) -> usize { + let lo = usize::MAX / 0xFF; + lo * byte as usize +} + +unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize { + let mut output = 0; + ptr::copy_nonoverlapping( + bytes.as_ptr().add(offset), + &mut output as *mut usize as *mut u8, + mem::size_of::<usize>() + ); + output +} + +fn bytewise_equal(lhs: usize, rhs: usize) -> usize { + let lo = usize::MAX / 0xFF; + let hi = lo << 7; + + let x = lhs ^ rhs; + !((((x & !hi) + !hi) | x) >> 7) & lo +} + +fn sum_usize(values: usize) -> usize { + let every_other_byte_lo = usize::MAX / 0xFFFF; + let every_other_byte = every_other_byte_lo * 0xFF; + + // Pairwise reduction to avoid overflow on next step. + let pair_sum: usize = (values & every_other_byte) + ((values >> 8) & every_other_byte); + + // Multiplication results in top two bytes holding sum. + pair_sum.wrapping_mul(every_other_byte_lo) >> ((mem::size_of::<usize>() - 2) * 8) +} + +fn is_leading_utf8_byte(values: usize) -> usize { + // a leading UTF-8 byte is one which does not start with the bits 10. + ((!values >> 7) | (values >> 6)) & splat(1) +} + +pub fn chunk_count(haystack: &[u8], needle: u8) -> usize { + let chunksize = mem::size_of::<usize>(); + assert!(haystack.len() >= chunksize); + + unsafe { + let mut offset = 0; + let mut count = 0; + + let needles = splat(needle); + + // 2040 + while haystack.len() >= offset + chunksize * 255 { + let mut counts = 0; + for _ in 0..255 { + counts += bytewise_equal(usize_load_unchecked(haystack, offset), needles); + offset += chunksize; + } + count += sum_usize(counts); + } + + // 8 + let mut counts = 0; + for i in 0..(haystack.len() - offset) / chunksize { + counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles); + } + if haystack.len() % 8 != 0 { + let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8))); + counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask; + } + count += sum_usize(counts); + + count + } +} + +pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize { + let chunksize = mem::size_of::<usize>(); + assert!(utf8_chars.len() >= chunksize); + + unsafe { + let mut offset = 0; + let mut count = 0; + + // 2040 + while utf8_chars.len() >= offset + chunksize * 255 { + let mut counts = 0; + for _ in 0..255 { + counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset)); + offset += chunksize; + } + count += sum_usize(counts); + } + + // 8 + let mut counts = 0; + for i in 0..(utf8_chars.len() - offset) / chunksize { + counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize)); + } + if utf8_chars.len() % 8 != 0 { + let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8))); + counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask; + } + count += sum_usize(counts); + + count + } +} diff --git a/vendor/bytecount/src/lib.rs b/vendor/bytecount/src/lib.rs new file mode 100644 index 000000000..ef4235c26 --- /dev/null +++ b/vendor/bytecount/src/lib.rs @@ -0,0 +1,138 @@ +//! count occurrences of a given byte, or the number of UTF-8 code points, in a +//! byte slice, fast. +//! +//! This crate has the [`count`](fn.count.html) method to count byte +//! occurrences (for example newlines) in a larger `&[u8]` slice. +//! +//! For example: +//! +//! ```rust +//! assert_eq!(5, bytecount::count(b"Hello, this is the bytecount crate!", b' ')); +//! ``` +//! +//! Also there is a [`num_chars`](fn.num_chars.html) method to count +//! the number of UTF8 characters in a slice. It will work the same as +//! `str::chars().count()` for byte slices of correct UTF-8 character +//! sequences. The result will likely be off for invalid sequences, +//! although the result is guaranteed to be between `0` and +//! `[_]::len()`, inclusive. +//! +//! Example: +//! +//! ```rust +//! let sequence = "Wenn ich ein Vöglein wär, flög ich zu Dir!"; +//! assert_eq!(sequence.chars().count(), +//! bytecount::num_chars(sequence.as_bytes())); +//! ``` +//! +//! For completeness and easy comparison, the "naive" versions of both +//! count and num_chars are provided. Those are also faster if used on +//! predominantly small strings. The +//! [`naive_count_32`](fn.naive_count_32.html) method can be faster +//! still on small strings. + +#![deny(missing_docs)] + +#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)] + +#[cfg(not(feature = "runtime-dispatch-simd"))] +use core::mem; +#[cfg(feature = "runtime-dispatch-simd")] +use std::mem; + +mod naive; +pub use naive::*; +mod integer_simd; + +#[cfg(any( + all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")), + feature = "generic-simd" +))] +mod simd; + +/// Count occurrences of a byte in a slice of bytes, fast +/// +/// # Examples +/// +/// ``` +/// let s = b"This is a Text with spaces"; +/// let number_of_spaces = bytecount::count(s, b' '); +/// assert_eq!(number_of_spaces, 5); +/// ``` +pub fn count(haystack: &[u8], needle: u8) -> usize { + if haystack.len() >= 32 { + #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("avx2") { + unsafe { return simd::x86_avx2::chunk_count(haystack, needle); } + } + } + + #[cfg(feature = "generic-simd")] + return simd::generic::chunk_count(haystack, needle); + } + + if haystack.len() >= 16 { + #[cfg(all( + feature = "runtime-dispatch-simd", + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "generic-simd") + ))] + { + if is_x86_feature_detected!("sse2") { + unsafe { return simd::x86_sse2::chunk_count(haystack, needle); } + } + } + } + + if haystack.len() >= mem::size_of::<usize>() { + return integer_simd::chunk_count(haystack, needle); + } + + naive_count(haystack, needle) +} + +/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, fast +/// +/// This function is safe to use on any byte array, valid UTF-8 or not, +/// but the output is only meaningful for well-formed UTF-8. +/// +/// # Example +/// +/// ``` +/// let swordfish = "メカジキ"; +/// let char_count = bytecount::num_chars(swordfish.as_bytes()); +/// assert_eq!(char_count, 4); +/// ``` +pub fn num_chars(utf8_chars: &[u8]) -> usize { + if utf8_chars.len() >= 32 { + #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("avx2") { + unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); } + } + } + + #[cfg(feature = "generic-simd")] + return simd::generic::chunk_num_chars(utf8_chars); + } + + if utf8_chars.len() >= 16 { + #[cfg(all( + feature = "runtime-dispatch-simd", + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "generic-simd") + ))] + { + if is_x86_feature_detected!("sse2") { + unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); } + } + } + } + + if utf8_chars.len() >= mem::size_of::<usize>() { + return integer_simd::chunk_num_chars(utf8_chars); + } + + naive_num_chars(utf8_chars) +} diff --git a/vendor/bytecount/src/naive.rs b/vendor/bytecount/src/naive.rs new file mode 100644 index 000000000..315c4b675 --- /dev/null +++ b/vendor/bytecount/src/naive.rs @@ -0,0 +1,42 @@ +/// Count up to `(2^32)-1` occurrences of a byte in a slice +/// of bytes, simple +/// +/// # Example +/// +/// ``` +/// let s = b"This is yet another Text with spaces"; +/// let number_of_spaces = bytecount::naive_count_32(s, b' '); +/// assert_eq!(number_of_spaces, 6); +/// ``` +pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize { + haystack.iter().fold(0, |n, c| n + (*c == needle) as u32) as usize +} + +/// Count occurrences of a byte in a slice of bytes, simple +/// +/// # Example +/// +/// ``` +/// let s = b"This is yet another Text with spaces"; +/// let number_of_spaces = bytecount::naive_count(s, b' '); +/// assert_eq!(number_of_spaces, 6); +/// ``` +pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize { + utf8_chars.iter().fold(0, |n, c| n + (*c == needle) as usize) +} + +/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple +/// +/// This function is safe to use on any byte array, valid UTF-8 or not, +/// but the output is only meaningful for well-formed UTF-8. +/// +/// # Example +/// +/// ``` +/// let swordfish = "メカジキ"; +/// let char_count = bytecount::naive_num_chars(swordfish.as_bytes()); +/// assert_eq!(char_count, 4); +/// ``` +pub fn naive_num_chars(utf8_chars: &[u8]) -> usize { + utf8_chars.iter().filter(|&&byte| (byte >> 6) != 0b10).count() +} diff --git a/vendor/bytecount/src/simd/generic.rs b/vendor/bytecount/src/simd/generic.rs new file mode 100644 index 000000000..2031e730e --- /dev/null +++ b/vendor/bytecount/src/simd/generic.rs @@ -0,0 +1,137 @@ +extern crate packed_simd; + +#[cfg(not(feature = "runtime-dispatch-simd"))] +use core::mem; +#[cfg(feature = "runtime-dispatch-simd")] +use std::mem; + +use self::packed_simd::{u8x32, u8x64, FromCast}; + +const MASK: [u8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +]; + +unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 { + u8x64::from_slice_unaligned_unchecked(slice.get_unchecked(offset..)) +} +unsafe fn u8x32_from_offset(slice: &[u8], offset: usize) -> u8x32 { + u8x32::from_slice_unaligned_unchecked(slice.get_unchecked(offset..)) +} + +fn sum_x64(u8s: &u8x64) -> usize { + let mut store = [0; mem::size_of::<u8x64>()]; + u8s.write_to_slice_unaligned(&mut store); + store.iter().map(|&e| e as usize).sum() +} +fn sum_x32(u8s: &u8x32) -> usize { + let mut store = [0; mem::size_of::<u8x32>()]; + u8s.write_to_slice_unaligned(&mut store); + store.iter().map(|&e| e as usize).sum() +} + +pub fn chunk_count(haystack: &[u8], needle: u8) -> usize { + assert!(haystack.len() >= 32); + + unsafe { + let mut offset = 0; + let mut count = 0; + + let needles_x64 = u8x64::splat(needle); + + // 16320 + while haystack.len() >= offset + 64 * 255 { + let mut counts = u8x64::splat(0); + for _ in 0..255 { + counts -= u8x64::from_cast(u8x64_from_offset(haystack, offset).eq(needles_x64)); + offset += 64; + } + count += sum_x64(&counts); + } + + // 8192 + if haystack.len() >= offset + 64 * 128 { + let mut counts = u8x64::splat(0); + for _ in 0..128 { + counts -= u8x64::from_cast(u8x64_from_offset(haystack, offset).eq(needles_x64)); + offset += 64; + } + count += sum_x64(&counts); + } + + let needles_x32 = u8x32::splat(needle); + + // 32 + let mut counts = u8x32::splat(0); + for i in 0..(haystack.len() - offset) / 32 { + counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32)); + } + count += sum_x32(&counts); + + // Straggler; need to reset counts because prior loop can run 255 times + counts = u8x32::splat(0); + if haystack.len() % 32 != 0 { + counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) & + u8x32_from_offset(&MASK, haystack.len() % 32); + } + count += sum_x32(&counts); + + count + } +} + +fn is_leading_utf8_byte_x64(u8s: u8x64) -> u8x64 { + u8x64::from_cast((u8s & u8x64::splat(0b1100_0000)).ne(u8x64::splat(0b1000_0000))) +} + +fn is_leading_utf8_byte_x32(u8s: u8x32) -> u8x32 { + u8x32::from_cast((u8s & u8x32::splat(0b1100_0000)).ne(u8x32::splat(0b1000_0000))) +} + +pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize { + assert!(utf8_chars.len() >= 32); + + unsafe { + let mut offset = 0; + let mut count = 0; + + // 16320 + while utf8_chars.len() >= offset + 64 * 255 { + let mut counts = u8x64::splat(0); + for _ in 0..255 { + counts -= is_leading_utf8_byte_x64(u8x64_from_offset(utf8_chars, offset)); + offset += 64; + } + count += sum_x64(&counts); + } + + // 8192 + if utf8_chars.len() >= offset + 64 * 128 { + let mut counts = u8x64::splat(0); + for _ in 0..128 { + counts -= is_leading_utf8_byte_x64(u8x64_from_offset(utf8_chars, offset)); + offset += 64; + } + count += sum_x64(&counts); + } + + // 32 + let mut counts = u8x32::splat(0); + for i in 0..(utf8_chars.len() - offset) / 32 { + counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, offset + i * 32)); + } + count += sum_x32(&counts); + + // Straggler; need to reset counts because prior loop can run 255 times + counts = u8x32::splat(0); + if utf8_chars.len() % 32 != 0 { + counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) & + u8x32_from_offset(&MASK, utf8_chars.len() % 32); + } + count += sum_x32(&counts); + + count + } +} diff --git a/vendor/bytecount/src/simd/mod.rs b/vendor/bytecount/src/simd/mod.rs new file mode 100644 index 000000000..d144e1847 --- /dev/null +++ b/vendor/bytecount/src/simd/mod.rs @@ -0,0 +1,17 @@ +#[cfg(feature = "generic-simd")] +pub mod generic; + +// This is like generic, but written explicitly +// because generic SIMD requires nightly. +#[cfg(all( + feature = "runtime-dispatch-simd", + any(target_arch = "x86", target_arch = "x86_64"), + not(feature = "generic-simd") +))] +pub mod x86_sse2; + +// Modern x86 machines can do lots of fun stuff; +// this is where the *real* optimizations go. +// Runtime feature detection is not available with no_std. +#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] +pub mod x86_avx2; diff --git a/vendor/bytecount/src/simd/x86_avx2.rs b/vendor/bytecount/src/simd/x86_avx2.rs new file mode 100644 index 000000000..90a55c0fb --- /dev/null +++ b/vendor/bytecount/src/simd/x86_avx2.rs @@ -0,0 +1,161 @@ +use std::arch::x86_64::{ + __m256i, + _mm256_and_si256, + _mm256_cmpeq_epi8, + _mm256_extract_epi64, + _mm256_loadu_si256, + _mm256_sad_epu8, + _mm256_set1_epi8, + _mm256_setzero_si256, + _mm256_sub_epi8, + _mm256_xor_si256, +}; + +#[target_feature(enable = "avx2")] +pub unsafe fn _mm256_set1_epu8(a: u8) -> __m256i { + _mm256_set1_epi8(a as i8) +} + +#[target_feature(enable = "avx2")] +pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i { + _mm256_xor_si256(_mm256_cmpeq_epi8(a, b), _mm256_set1_epi8(-1)) +} + +const MASK: [u8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +]; + +#[target_feature(enable = "avx2")] +unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i { + _mm256_loadu_si256(slice.as_ptr().add(offset) as *const _) +} + +#[target_feature(enable = "avx2")] +unsafe fn sum(u8s: &__m256i) -> usize { + let sums = _mm256_sad_epu8(*u8s, _mm256_setzero_si256()); + ( + _mm256_extract_epi64(sums, 0) + _mm256_extract_epi64(sums, 1) + + _mm256_extract_epi64(sums, 2) + _mm256_extract_epi64(sums, 3) + ) as usize +} + +#[target_feature(enable = "avx2")] +pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { + assert!(haystack.len() >= 32); + + let mut offset = 0; + let mut count = 0; + + let needles = _mm256_set1_epu8(needle); + + // 8160 + while haystack.len() >= offset + 32 * 255 { + let mut counts = _mm256_setzero_si256(); + for _ in 0..255 { + counts = _mm256_sub_epi8( + counts, + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) + ); + offset += 32; + } + count += sum(&counts); + } + + // 4096 + if haystack.len() >= offset + 32 * 128 { + let mut counts = _mm256_setzero_si256(); + for _ in 0..128 { + counts = _mm256_sub_epi8( + counts, + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) + ); + offset += 32; + } + count += sum(&counts); + } + + // 32 + let mut counts = _mm256_setzero_si256(); + for i in 0..(haystack.len() - offset) / 32 { + counts = _mm256_sub_epi8( + counts, + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles) + ); + } + if haystack.len() % 32 != 0 { + counts = _mm256_sub_epi8( + counts, + _mm256_and_si256( + _mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles), + mm256_from_offset(&MASK, haystack.len() % 32) + ) + ); + } + count += sum(&counts); + + count +} + +#[target_feature(enable = "avx2")] +unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i { + mm256_cmpneq_epi8(_mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)), _mm256_set1_epu8(0b1000_0000)) +} + +#[target_feature(enable = "avx2")] +pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { + assert!(utf8_chars.len() >= 32); + + let mut offset = 0; + let mut count = 0; + + // 8160 + while utf8_chars.len() >= offset + 32 * 255 { + let mut counts = _mm256_setzero_si256(); + + for _ in 0..255 { + counts = _mm256_sub_epi8( + counts, + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) + ); + offset += 32; + } + count += sum(&counts); + } + + // 4096 + if utf8_chars.len() >= offset + 32 * 128 { + let mut counts = _mm256_setzero_si256(); + for _ in 0..128 { + counts = _mm256_sub_epi8( + counts, + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) + ); + offset += 32; + } + count += sum(&counts); + } + + // 32 + let mut counts = _mm256_setzero_si256(); + for i in 0..(utf8_chars.len() - offset) / 32 { + counts = _mm256_sub_epi8( + counts, + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)) + ); + } + if utf8_chars.len() % 32 != 0 { + counts = _mm256_sub_epi8( + counts, + _mm256_and_si256( + is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)), + mm256_from_offset(&MASK, utf8_chars.len() % 32) + ) + ); + } + count += sum(&counts); + + count +} diff --git a/vendor/bytecount/src/simd/x86_sse2.rs b/vendor/bytecount/src/simd/x86_sse2.rs new file mode 100644 index 000000000..63d295eae --- /dev/null +++ b/vendor/bytecount/src/simd/x86_sse2.rs @@ -0,0 +1,171 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::{ + __m128i, + _mm_and_si128, + _mm_cmpeq_epi8, + _mm_extract_epi32, + _mm_loadu_si128, + _mm_sad_epu8, + _mm_set1_epi8, + _mm_setzero_si128, + _mm_sub_epi8, + _mm_xor_si128, +}; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::{ + __m128i, + _mm_and_si128, + _mm_cmpeq_epi8, + _mm_extract_epi32, + _mm_loadu_si128, + _mm_sad_epu8, + _mm_set1_epi8, + _mm_setzero_si128, + _mm_sub_epi8, + _mm_xor_si128, +}; + +#[target_feature(enable = "sse2")] +pub unsafe fn _mm_set1_epu8(a: u8) -> __m128i { + _mm_set1_epi8(a as i8) +} + +#[target_feature(enable = "sse2")] +pub unsafe fn mm_cmpneq_epi8(a: __m128i, b: __m128i) -> __m128i { + _mm_xor_si128(_mm_cmpeq_epi8(a, b), _mm_set1_epi8(-1)) +} + +const MASK: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +]; + +#[target_feature(enable = "sse2")] +unsafe fn mm_from_offset(slice: &[u8], offset: usize) -> __m128i { + _mm_loadu_si128(slice.as_ptr().offset(offset as isize) as *const _) +} + +#[target_feature(enable = "sse2")] +unsafe fn sum(u8s: &__m128i) -> usize { + let sums = _mm_sad_epu8(*u8s, _mm_setzero_si128()); + (_mm_extract_epi32(sums, 0) + _mm_extract_epi32(sums, 2)) as usize +} + +#[target_feature(enable = "sse2")] +pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { + assert!(haystack.len() >= 16); + + let mut offset = 0; + let mut count = 0; + + let needles = _mm_set1_epu8(needle); + + // 4080 + while haystack.len() >= offset + 16 * 255 { + let mut counts = _mm_setzero_si128(); + for _ in 0..255 { + counts = _mm_sub_epi8( + counts, + _mm_cmpeq_epi8(mm_from_offset(haystack, offset), needles) + ); + offset += 16; + } + count += sum(&counts); + } + + // 2048 + if haystack.len() >= offset + 16 * 128 { + let mut counts = _mm_setzero_si128(); + for _ in 0..128 { + counts = _mm_sub_epi8( + counts, + _mm_cmpeq_epi8(mm_from_offset(haystack, offset), needles) + ); + offset += 16; + } + count += sum(&counts); + } + + // 16 + let mut counts = _mm_setzero_si128(); + for i in 0..(haystack.len() - offset) / 16 { + counts = _mm_sub_epi8( + counts, + _mm_cmpeq_epi8(mm_from_offset(haystack, offset + i * 16), needles) + ); + } + if haystack.len() % 16 != 0 { + counts = _mm_sub_epi8( + counts, + _mm_and_si128( + _mm_cmpeq_epi8(mm_from_offset(haystack, haystack.len() - 16), needles), + mm_from_offset(&MASK, haystack.len() % 16) + ) + ); + } + count += sum(&counts); + + count +} + +#[target_feature(enable = "sse2")] +unsafe fn is_leading_utf8_byte(u8s: __m128i) -> __m128i { + mm_cmpneq_epi8(_mm_and_si128(u8s, _mm_set1_epu8(0b1100_0000)), _mm_set1_epu8(0b1000_0000)) +} + +#[target_feature(enable = "sse2")] +pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { + assert!(utf8_chars.len() >= 16); + + let mut offset = 0; + let mut count = 0; + + // 4080 + while utf8_chars.len() >= offset + 16 * 255 { + let mut counts = _mm_setzero_si128(); + + for _ in 0..255 { + counts = _mm_sub_epi8( + counts, + is_leading_utf8_byte(mm_from_offset(utf8_chars, offset)) + ); + offset += 16; + } + count += sum(&counts); + } + + // 2048 + if utf8_chars.len() >= offset + 16 * 128 { + let mut counts = _mm_setzero_si128(); + for _ in 0..128 { + counts = _mm_sub_epi8( + counts, + is_leading_utf8_byte(mm_from_offset(utf8_chars, offset)) + ); + offset += 16; + } + count += sum(&counts); + } + + // 16 + let mut counts = _mm_setzero_si128(); + for i in 0..(utf8_chars.len() - offset) / 16 { + counts = _mm_sub_epi8( + counts, + is_leading_utf8_byte(mm_from_offset(utf8_chars, offset + i * 16)) + ); + } + if utf8_chars.len() % 16 != 0 { + counts = _mm_sub_epi8( + counts, + _mm_and_si128( + is_leading_utf8_byte(mm_from_offset(utf8_chars, utf8_chars.len() - 16)), + mm_from_offset(&MASK, utf8_chars.len() % 16) + ) + ); + } + count += sum(&counts); + + count +} |