Diffstat (limited to 'vendor/bstr/src/ascii.rs')
-rw-r--r-- | vendor/bstr/src/ascii.rs | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/vendor/bstr/src/ascii.rs b/vendor/bstr/src/ascii.rs
new file mode 100644
index 0000000..5bbbdad
--- /dev/null
+++ b/vendor/bstr/src/ascii.rs
@@ -0,0 +1,337 @@
+// The following ~400 lines of code exist for exactly one purpose, which is
+// to optimize this code:
+//
+// byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
+//
+// Yes... Overengineered is a word that comes to mind, but this is effectively
+// a very similar problem to memchr, and virtually nobody has been able to
+// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
+// folks). In particular, this routine makes a very common case (ASCII) very
+// fast, which seems worth it. We do stop short of adding AVX variants of the
+// code below in order to retain our sanity and also to avoid needing to deal
+// with runtime target feature detection. RESIST!
+//
+// In order to understand the SIMD version below, it would be good to read this
+// comment describing how my memchr routine works:
+// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
+//
+// The primary difference with memchr is that for ASCII, we can do a bit less
+// work. In particular, we don't need to detect the presence of a specific
+// byte, but rather, whether any byte has its most significant bit set. That
+// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
+// _mm_movemask_epi8.
+
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+const USIZE_BYTES: usize = core::mem::size_of::<usize>();
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;
+
+// This is a mask where the most significant bit of each byte in the usize
+// is set. We test this bit to determine whether a character is ASCII or not.
+// Namely, a single byte is regarded as an ASCII codepoint if and only if its
+// most significant bit is not set.
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+const ASCII_MASK_U64: u64 = 0x8080808080808080;
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
+
+/// Returns the index of the first non-ASCII byte in the given slice.
+///
+/// If the slice only contains ASCII bytes, then the length of the slice is
+/// returned.
+pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
+    #[cfg(any(miri, not(target_arch = "x86_64")))]
+    {
+        first_non_ascii_byte_fallback(slice)
+    }
+
+    #[cfg(all(not(miri), target_arch = "x86_64"))]
+    {
+        first_non_ascii_byte_sse2(slice)
+    }
+}
+
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
+    let align = USIZE_BYTES - 1;
+    let start_ptr = slice.as_ptr();
+    let end_ptr = slice[slice.len()..].as_ptr();
+    let mut ptr = start_ptr;
+
+    unsafe {
+        if slice.len() < USIZE_BYTES {
+            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
+        }
+
+        let chunk = read_unaligned_usize(ptr);
+        let mask = chunk & ASCII_MASK;
+        if mask != 0 {
+            return first_non_ascii_byte_mask(mask);
+        }
+
+        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
+        debug_assert!(ptr > start_ptr);
+        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
+        if slice.len() >= FALLBACK_LOOP_SIZE {
+            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
+                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);
+
+                let a = *(ptr as *const usize);
+                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
+                if (a | b) & ASCII_MASK != 0 {
+                    // What a kludge. We wrap the position finding code into
+                    // a non-inlineable function, which makes the codegen in
+                    // the tight loop above a bit better by avoiding a
+                    // couple extra movs. We pay for it by two additional
+                    // stores, but only in the case of finding a non-ASCII
+                    // byte.
+                    #[inline(never)]
+                    unsafe fn findpos(
+                        start_ptr: *const u8,
+                        ptr: *const u8,
+                    ) -> usize {
+                        let a = *(ptr as *const usize);
+                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
+
+                        let mut at = sub(ptr, start_ptr);
+                        let maska = a & ASCII_MASK;
+                        if maska != 0 {
+                            return at + first_non_ascii_byte_mask(maska);
+                        }
+
+                        at += USIZE_BYTES;
+                        let maskb = b & ASCII_MASK;
+                        debug_assert!(maskb != 0);
+                        return at + first_non_ascii_byte_mask(maskb);
+                    }
+                    return findpos(start_ptr, ptr);
+                }
+                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
+            }
+        }
+        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
+    }
+}
+
+#[cfg(all(not(miri), target_arch = "x86_64"))]
+fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
+    use core::arch::x86_64::*;
+
+    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
+    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
+    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;
+
+    let start_ptr = slice.as_ptr();
+    let end_ptr = slice[slice.len()..].as_ptr();
+    let mut ptr = start_ptr;
+
+    unsafe {
+        if slice.len() < VECTOR_SIZE {
+            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
+        }
+
+        let chunk = _mm_loadu_si128(ptr as *const __m128i);
+        let mask = _mm_movemask_epi8(chunk);
+        if mask != 0 {
+            return mask.trailing_zeros() as usize;
+        }
+
+        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
+        debug_assert!(ptr > start_ptr);
+        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
+        if slice.len() >= VECTOR_LOOP_SIZE {
+            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
+                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);
+
+                let a = _mm_load_si128(ptr as *const __m128i);
+                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
+                let c =
+                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
+                let d =
+                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);
+
+                let or1 = _mm_or_si128(a, b);
+                let or2 = _mm_or_si128(c, d);
+                let or3 = _mm_or_si128(or1, or2);
+                if _mm_movemask_epi8(or3) != 0 {
+                    let mut at = sub(ptr, start_ptr);
+                    let mask = _mm_movemask_epi8(a);
+                    if mask != 0 {
+                        return at + mask.trailing_zeros() as usize;
+                    }
+
+                    at += VECTOR_SIZE;
+                    let mask = _mm_movemask_epi8(b);
+                    if mask != 0 {
+                        return at + mask.trailing_zeros() as usize;
+                    }
+
+                    at += VECTOR_SIZE;
+                    let mask = _mm_movemask_epi8(c);
+                    if mask != 0 {
+                        return at + mask.trailing_zeros() as usize;
+                    }
+
+                    at += VECTOR_SIZE;
+                    let mask = _mm_movemask_epi8(d);
+                    debug_assert!(mask != 0);
+                    return at + mask.trailing_zeros() as usize;
+                }
+                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
+            }
+        }
+        while ptr <= end_ptr.sub(VECTOR_SIZE) {
+            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);
+
+            let chunk = _mm_loadu_si128(ptr as *const __m128i);
+            let mask = _mm_movemask_epi8(chunk);
+            if mask != 0 {
+                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
+            }
+            ptr = ptr.add(VECTOR_SIZE);
+        }
+        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
+    }
+}
+
+#[inline(always)]
+unsafe fn first_non_ascii_byte_slow(
+    start_ptr: *const u8,
+    end_ptr: *const u8,
+    mut ptr: *const u8,
+) -> usize {
+    debug_assert!(start_ptr <= ptr);
+    debug_assert!(ptr <= end_ptr);
+
+    while ptr < end_ptr {
+        if *ptr > 0x7F {
+            return sub(ptr, start_ptr);
+        }
+        ptr = ptr.offset(1);
+    }
+    sub(end_ptr, start_ptr)
+}
+
+/// Compute the position of the first non-ASCII byte in the given mask.
+///
+/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
+/// 8 contiguous bytes of the slice being checked where *at least* one of those
+/// bytes is not an ASCII byte.
+///
+/// The position returned is always in the inclusive range [0, 7].
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+fn first_non_ascii_byte_mask(mask: usize) -> usize {
+    #[cfg(target_endian = "little")]
+    {
+        mask.trailing_zeros() as usize / 8
+    }
+    #[cfg(target_endian = "big")]
+    {
+        mask.leading_zeros() as usize / 8
+    }
+}
+
+/// Increment the given pointer by the given amount.
+unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
+    debug_assert!(amt < ::core::isize::MAX as usize);
+    ptr.offset(amt as isize)
+}
+
+/// Decrement the given pointer by the given amount.
+unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
+    debug_assert!(amt < ::core::isize::MAX as usize);
+    ptr.offset((amt as isize).wrapping_neg())
+}
+
+#[cfg(any(test, miri, not(target_arch = "x86_64")))]
+unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
+    use core::ptr;
+
+    let mut n: usize = 0;
+    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
+    n
+}
+
+/// Subtract `b` from `a` and return the difference. `a` should be greater than
+/// or equal to `b`.
+fn sub(a: *const u8, b: *const u8) -> usize {
+    debug_assert!(a >= b);
+    (a as usize) - (b as usize)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Our testing approach here is to try and exhaustively test every case.
+    // This includes the position at which a non-ASCII byte occurs in addition
+    // to the alignment of the slice that we're searching.
+
+    #[test]
+    fn positive_fallback_forward() {
+        for i in 0..517 {
+            let s = "a".repeat(i);
+            assert_eq!(
+                i,
+                first_non_ascii_byte_fallback(s.as_bytes()),
+                "i: {:?}, len: {:?}, s: {:?}",
+                i,
+                s.len(),
+                s
+            );
+        }
+    }
+
+    #[test]
+    #[cfg(target_arch = "x86_64")]
+    #[cfg(not(miri))]
+    fn positive_sse2_forward() {
+        for i in 0..517 {
+            let b = "a".repeat(i).into_bytes();
+            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
+        }
+    }
+
+    #[test]
+    #[cfg(not(miri))]
+    fn negative_fallback_forward() {
+        for i in 0..517 {
+            for align in 0..65 {
+                let mut s = "a".repeat(i);
+                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
+                let s = s.get(align..).unwrap_or("");
+                assert_eq!(
+                    i.saturating_sub(align),
+                    first_non_ascii_byte_fallback(s.as_bytes()),
+                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
+                    i,
+                    align,
+                    s.len(),
+                    s
+                );
+            }
+        }
+    }
+
+    #[test]
+    #[cfg(target_arch = "x86_64")]
+    #[cfg(not(miri))]
+    fn negative_sse2_forward() {
+        for i in 0..517 {
+            for align in 0..65 {
+                let mut s = "a".repeat(i);
+                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
+                let s = s.get(align..).unwrap_or("");
+                assert_eq!(
+                    i.saturating_sub(align),
+                    first_non_ascii_byte_sse2(s.as_bytes()),
+                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
+                    i,
+                    align,
+                    s.len(),
+                    s
+                );
+            }
+        }
+    }
+}
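
For context, the header comment in this file says the whole module exists to optimize the scalar scan `byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())`. Below is a minimal standalone sketch of that equivalent behavior (not part of the vendored file; the `first_non_ascii_byte_naive` name is made up here for illustration), showing that the routine returns the slice length when every byte is ASCII:

    // Naive equivalent of first_non_ascii_byte, per the header comment above:
    // the index of the first byte with its most significant bit set, or the
    // slice length if every byte is ASCII.
    fn first_non_ascii_byte_naive(bytes: &[u8]) -> usize {
        bytes.iter().position(|&b| b > 0x7F).unwrap_or(bytes.len())
    }

    fn main() {
        assert_eq!(first_non_ascii_byte_naive(b"abc"), 3); // all ASCII: slice length
        assert_eq!(first_non_ascii_byte_naive("a☃".as_bytes()), 1); // '☃' starts at byte 1
        assert_eq!(first_non_ascii_byte_naive(b""), 0); // empty slice
    }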