// avx2 decode modified from https://github.com/zbjornson/fast-hex/blob/master/src/hex.cc #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; use crate::error::Error; const NIL: u8 = u8::MAX; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] const T_MASK: i32 = 65535; const fn init_unhex_array(check_case: CheckCase) -> [u8; 256] { let mut arr = [0; 256]; let mut i = 0; while i < 256 { arr[i] = match i as u8 { b'0'..=b'9' => i as u8 - b'0', b'a'..=b'f' => match check_case { CheckCase::Lower | CheckCase::None => i as u8 - b'a' + 10, _ => NIL, }, b'A'..=b'F' => match check_case { CheckCase::Upper | CheckCase::None => i as u8 - b'A' + 10, _ => NIL, }, _ => NIL, }; i += 1; } arr } const fn init_unhex4_array(check_case: CheckCase) -> [u8; 256] { let unhex_arr = init_unhex_array(check_case); let mut unhex4_arr = [NIL; 256]; let mut i = 0; while i < 256 { if unhex_arr[i] != NIL { unhex4_arr[i] = unhex_arr[i] << 4; } i += 1; } unhex4_arr } // ASCII -> hex pub(crate) static UNHEX: [u8; 256] = init_unhex_array(CheckCase::None); // ASCII -> hex, lower case pub(crate) static UNHEX_LOWER: [u8; 256] = init_unhex_array(CheckCase::Lower); // ASCII -> hex, upper case pub(crate) static UNHEX_UPPER: [u8; 256] = init_unhex_array(CheckCase::Upper); // ASCII -> hex << 4 pub(crate) static UNHEX4: [u8; 256] = init_unhex4_array(CheckCase::None); const _0213: i32 = 0b11011000; // lower nibble #[inline] fn unhex_b(x: usize) -> u8 { UNHEX[x] } // upper nibble, logically equivalent to unhex_b(x) << 4 #[inline] fn unhex_a(x: usize) -> u8 { UNHEX4[x] } #[inline] #[target_feature(enable = "avx2")] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] unsafe fn unhex_avx2(value: __m256i) -> __m256i { let sr6 = _mm256_srai_epi16(value, 6); let and15 = _mm256_and_si256(value, _mm256_set1_epi16(0xf)); let mul = _mm256_maddubs_epi16(sr6, _mm256_set1_epi16(9)); _mm256_add_epi16(mul, and15) } // (a << 4) | b; #[inline] #[target_feature(enable = "avx2")] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] unsafe fn nib2byte_avx2(a1: __m256i, b1: __m256i, a2: __m256i, b2: __m256i) -> __m256i { let a4_1 = _mm256_slli_epi16(a1, 4); let a4_2 = _mm256_slli_epi16(a2, 4); let a4orb_1 = _mm256_or_si256(a4_1, b1); let a4orb_2 = _mm256_or_si256(a4_2, b2); let pck1 = _mm256_packus_epi16(a4orb_1, a4orb_2); _mm256_permute4x64_epi64(pck1, _0213) } /// Check if the input is valid hex bytes slice pub fn hex_check(src: &[u8]) -> bool { hex_check_with_case(src, CheckCase::None) } /// Check if the input is valid hex bytes slice with case check pub fn hex_check_with_case(src: &[u8], check_case: CheckCase) -> bool { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { match crate::vectorization_support() { crate::Vectorization::AVX2 | crate::Vectorization::SSE41 => unsafe { hex_check_sse_with_case(src, check_case) }, crate::Vectorization::None => hex_check_fallback_with_case(src, check_case), } } #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] hex_check_fallback_with_case(src, check_case) } /// Check if the input is valid hex bytes slice pub fn hex_check_fallback(src: &[u8]) -> bool { hex_check_fallback_with_case(src, CheckCase::None) } /// Check if the input is valid hex bytes slice with case check pub fn hex_check_fallback_with_case(src: &[u8], check_case: CheckCase) -> bool { match check_case { CheckCase::None => src.iter().all(|&x| UNHEX[x as usize] != NIL), CheckCase::Lower => src.iter().all(|&x| UNHEX_LOWER[x as usize] != NIL), CheckCase::Upper => src.iter().all(|&x| UNHEX_UPPER[x as usize] != NIL), } } /// # Safety /// Check if a byte slice is valid. #[target_feature(enable = "sse4.1")] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] pub unsafe fn hex_check_sse(src: &[u8]) -> bool { hex_check_sse_with_case(src, CheckCase::None) } #[derive(Eq, PartialEq)] pub enum CheckCase { None, Lower, Upper, } /// # Safety /// Check if a byte slice is valid on given check_case. #[target_feature(enable = "sse4.1")] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] pub unsafe fn hex_check_sse_with_case(mut src: &[u8], check_case: CheckCase) -> bool { let ascii_zero = _mm_set1_epi8((b'0' - 1) as i8); let ascii_nine = _mm_set1_epi8((b'9' + 1) as i8); let ascii_ua = _mm_set1_epi8((b'A' - 1) as i8); let ascii_uf = _mm_set1_epi8((b'F' + 1) as i8); let ascii_la = _mm_set1_epi8((b'a' - 1) as i8); let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8); while src.len() >= 16 { let unchecked = _mm_loadu_si128(src.as_ptr() as *const _); let gt0 = _mm_cmpgt_epi8(unchecked, ascii_zero); let lt9 = _mm_cmplt_epi8(unchecked, ascii_nine); let valid_digit = _mm_and_si128(gt0, lt9); let (valid_la_lf, valid_ua_uf) = match check_case { CheckCase::None => { let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua); let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf); let gtla = _mm_cmpgt_epi8(unchecked, ascii_la); let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf); ( Some(_mm_and_si128(gtla, ltlf)), Some(_mm_and_si128(gtua, ltuf)), ) } CheckCase::Lower => { let gtla = _mm_cmpgt_epi8(unchecked, ascii_la); let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf); (Some(_mm_and_si128(gtla, ltlf)), None) } CheckCase::Upper => { let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua); let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf); (None, Some(_mm_and_si128(gtua, ltuf))) } }; let valid_letter = match (valid_la_lf, valid_ua_uf) { (Some(valid_lower), Some(valid_upper)) => _mm_or_si128(valid_lower, valid_upper), (Some(valid_lower), None) => valid_lower, (None, Some(valid_upper)) => valid_upper, _ => unreachable!(), }; let ret = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter)); if ret != T_MASK { return false; } src = &src[16..]; } hex_check_fallback_with_case(src, check_case) } /// Hex decode src into dst. /// The length of src must be even and not zero. /// The length of dst must be at least src.len() / 2. pub fn hex_decode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { hex_decode_with_case(src, dst, CheckCase::None) } /// Hex decode src into dst. /// The length of src must be even, and it's allowed to decode a zero length src. /// The length of dst must be at least src.len() / 2. /// when check_case is CheckCase::Lower, the hex string must be lower case. /// when check_case is CheckCase::Upper, the hex string must be upper case. /// when check_case is CheckCase::None, the hex string can be lower case or upper case. pub fn hex_decode_with_case( src: &[u8], dst: &mut [u8], check_case: CheckCase, ) -> Result<(), Error> { if src.len() & 1 != 0 { return Err(Error::InvalidLength(src.len())); } let expect_dst_len = src.len().checked_div(2).unwrap(); if dst.len() < expect_dst_len { return Err(Error::InvalidLength(dst.len())); } if !hex_check_with_case(src, check_case) { return Err(Error::InvalidChar); } hex_decode_unchecked(src, dst); Ok(()) } pub fn hex_decode_unchecked(src: &[u8], dst: &mut [u8]) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { match crate::vectorization_support() { crate::Vectorization::AVX2 => unsafe { hex_decode_avx2(src, dst) }, crate::Vectorization::None | crate::Vectorization::SSE41 => { hex_decode_fallback(src, dst) } } } #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] hex_decode_fallback(src, dst); } #[target_feature(enable = "avx2")] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] unsafe fn hex_decode_avx2(mut src: &[u8], mut dst: &mut [u8]) { // 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, // 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1 let mask_a = _mm256_setr_epi8( 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, ); // 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, // 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1 let mask_b = _mm256_setr_epi8( 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, ); while dst.len() >= 32 { let av1 = _mm256_loadu_si256(src.as_ptr() as *const _); let av2 = _mm256_loadu_si256(src[32..].as_ptr() as *const _); let mut a1 = _mm256_shuffle_epi8(av1, mask_a); let mut b1 = _mm256_shuffle_epi8(av1, mask_b); let mut a2 = _mm256_shuffle_epi8(av2, mask_a); let mut b2 = _mm256_shuffle_epi8(av2, mask_b); a1 = unhex_avx2(a1); a2 = unhex_avx2(a2); b1 = unhex_avx2(b1); b2 = unhex_avx2(b2); let bytes = nib2byte_avx2(a1, b1, a2, b2); //dst does not need to be aligned on any particular boundary _mm256_storeu_si256(dst.as_mut_ptr() as *mut _, bytes); dst = &mut dst[32..]; src = &src[64..]; } hex_decode_fallback(src, dst) } pub fn hex_decode_fallback(src: &[u8], dst: &mut [u8]) { for (slot, bytes) in dst.iter_mut().zip(src.chunks_exact(2)) { let a = unhex_a(bytes[0] as usize); let b = unhex_b(bytes[1] as usize); *slot = a | b; } } #[cfg(test)] mod tests { use crate::decode::NIL; use crate::{ decode::{ hex_check_fallback, hex_check_fallback_with_case, hex_decode_fallback, CheckCase, }, encode::hex_string, }; use proptest::proptest; fn _test_decode_fallback(s: &String) { let len = s.as_bytes().len(); let mut dst = Vec::with_capacity(len); dst.resize(len, 0); let hex_string = hex_string(s.as_bytes()); hex_decode_fallback(hex_string.as_bytes(), &mut dst); assert_eq!(&dst[..], s.as_bytes()); } proptest! { #[test] fn test_decode_fallback(ref s in ".+") { _test_decode_fallback(s); } } fn _test_check_fallback_true(s: &String) { assert!(hex_check_fallback(s.as_bytes())); match ( s.contains(char::is_lowercase), s.contains(char::is_uppercase), ) { (true, true) => { assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Lower )); assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Upper )); } (true, false) => { assert!(hex_check_fallback_with_case(s.as_bytes(), CheckCase::Lower)); assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Upper )); } (false, true) => { assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Lower )); assert!(hex_check_fallback_with_case(s.as_bytes(), CheckCase::Upper)); } (false, false) => { assert!(hex_check_fallback_with_case(s.as_bytes(), CheckCase::Lower)); assert!(hex_check_fallback_with_case(s.as_bytes(), CheckCase::Upper)); } } } proptest! { #[test] fn test_check_fallback_true(ref s in "[0-9a-fA-F]+") { _test_check_fallback_true(s); } } fn _test_check_fallback_false(s: &String) { assert!(!hex_check_fallback(s.as_bytes())); assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Upper )); assert!(!hex_check_fallback_with_case( s.as_bytes(), CheckCase::Lower )); } proptest! { #[test] fn test_check_fallback_false(ref s in ".{16}[^0-9a-fA-F]+") { _test_check_fallback_false(s); } } #[test] fn test_init_static_array_is_right() { static OLD_UNHEX: [u8; 256] = [ NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 10, 11, 12, 13, 14, 15, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 10, 11, 12, 13, 14, 15, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, ]; static OLD_UNHEX4: [u8; 256] = [ NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 160, 176, 192, 208, 224, 240, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 160, 176, 192, 208, 224, 240, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, ]; assert_eq!(OLD_UNHEX, crate::decode::UNHEX); assert_eq!(OLD_UNHEX4, crate::decode::UNHEX4); } } #[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))] mod test_sse { use crate::decode::{ hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_sse, hex_check_sse_with_case, hex_check_with_case, hex_decode, hex_decode_unchecked, hex_decode_with_case, CheckCase, }; use proptest::proptest; fn _test_check_sse_with_case(s: &String, check_case: CheckCase, expect_result: bool) { if is_x86_feature_detected!("sse4.1") { assert_eq!( unsafe { hex_check_sse_with_case(s.as_bytes(), check_case) }, expect_result ) } } fn _test_check_sse_true(s: &String) { if is_x86_feature_detected!("sse4.1") { assert!(unsafe { hex_check_sse(s.as_bytes()) }); } } proptest! { #[test] fn test_check_sse_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { _test_check_sse_true(s); _test_check_sse_with_case(s, CheckCase::None, true); match (s.contains(char::is_lowercase), s.contains(char::is_uppercase)){ (true, true) => { _test_check_sse_with_case(s, CheckCase::Lower, false); _test_check_sse_with_case(s, CheckCase::Upper, false); }, (true, false) => { _test_check_sse_with_case(s, CheckCase::Lower, true); _test_check_sse_with_case(s, CheckCase::Upper, false); }, (false, true) => { _test_check_sse_with_case(s, CheckCase::Lower, false); _test_check_sse_with_case(s, CheckCase::Upper, true); }, (false, false) => { _test_check_sse_with_case(s, CheckCase::Lower, true); _test_check_sse_with_case(s, CheckCase::Upper, true); } } } } fn _test_check_sse_false(s: &String) { if is_x86_feature_detected!("sse4.1") { assert!(!unsafe { hex_check_sse(s.as_bytes()) }); } } proptest! { #[test] fn test_check_sse_false(ref s in ".{16}[^0-9a-fA-F]+") { _test_check_sse_false(s); _test_check_sse_with_case(s, CheckCase::None, false); _test_check_sse_with_case(s, CheckCase::Lower, false); _test_check_sse_with_case(s, CheckCase::Upper, false); } } #[test] fn test_decode_zero_length_src_should_be_ok() { let src = b""; let mut dst = [0u8; 10]; assert!(hex_decode(src, &mut dst).is_ok()); assert!(hex_decode_with_case(src, &mut dst, CheckCase::None).is_ok()); assert!(hex_check(src)); assert!(hex_check_with_case(src, CheckCase::None)); assert!(hex_check_fallback(src)); assert!(hex_check_fallback_with_case(src, CheckCase::None)); if is_x86_feature_detected!("sse4.1") { assert!(unsafe { hex_check_sse_with_case(src, CheckCase::None) }); assert!(unsafe { hex_check_sse(src) }); } // this function have no return value, so we just execute it and expect no panic hex_decode_unchecked(src, &mut dst); } }