diff options
Diffstat (limited to 'third_party/rust/base64/src/engine')
-rw-r--r-- | third_party/rust/base64/src/engine/general_purpose/decode.rs | 348 | ||||
-rw-r--r-- | third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs | 161 | ||||
-rw-r--r-- | third_party/rust/base64/src/engine/general_purpose/mod.rs | 349 | ||||
-rw-r--r-- | third_party/rust/base64/src/engine/mod.rs | 410 | ||||
-rw-r--r-- | third_party/rust/base64/src/engine/naive.rs | 219 | ||||
-rw-r--r-- | third_party/rust/base64/src/engine/tests.rs | 1430 |
6 files changed, 2917 insertions, 0 deletions
diff --git a/third_party/rust/base64/src/engine/general_purpose/decode.rs b/third_party/rust/base64/src/engine/general_purpose/decode.rs new file mode 100644 index 0000000000..e9fd78877b --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/decode.rs @@ -0,0 +1,348 @@ +use crate::{ + engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode}, + DecodeError, PAD_BYTE, +}; + +// decode logic operates on chunks of 8 input bytes without padding +const INPUT_CHUNK_LEN: usize = 8; +const DECODED_CHUNK_LEN: usize = 6; + +// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last +// 2 bytes of any output u64 should not be counted as written to (but must be available in a +// slice). +const DECODED_CHUNK_SUFFIX: usize = 2; + +// how many u64's of input to handle at a time +const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; + +const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; + +// includes the trailing 2 bytes for the final u64 write +const DECODED_BLOCK_LEN: usize = + CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; + +#[doc(hidden)] +pub struct GeneralPurposeEstimate { + /// Total number of decode chunks, including a possibly partial last chunk + num_chunks: usize, + decoded_len_estimate: usize, +} + +impl GeneralPurposeEstimate { + pub(crate) fn new(encoded_len: usize) -> Self { + Self { + num_chunks: encoded_len + .checked_add(INPUT_CHUNK_LEN - 1) + .expect("Overflow when calculating number of chunks in input") + / INPUT_CHUNK_LEN, + decoded_len_estimate: encoded_len + .checked_add(3) + .expect("Overflow when calculating decoded len estimate") + / 4 + * 3, + } + } +} + +impl DecodeEstimate for GeneralPurposeEstimate { + fn decoded_len_estimate(&self) -> usize { + self.decoded_len_estimate + } +} + +/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. +/// Returns the number of bytes written, or an error. +// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is +// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment, +// but this is fragile and the best setting changes with only minor code modifications. +#[inline] +pub(crate) fn decode_helper( + input: &[u8], + estimate: GeneralPurposeEstimate, + output: &mut [u8], + decode_table: &[u8; 256], + decode_allow_trailing_bits: bool, + padding_mode: DecodePaddingMode, +) -> Result<usize, DecodeError> { + let remainder_len = input.len() % INPUT_CHUNK_LEN; + + // Because the fast decode loop writes in groups of 8 bytes (unrolled to + // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of + // which only 6 are valid data), we need to be sure that we stop using the fast decode loop + // soon enough that there will always be 2 more bytes of valid data written after that loop. + let trailing_bytes_to_skip = match remainder_len { + // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, + // and the fast decode logic cannot handle padding + 0 => INPUT_CHUNK_LEN, + // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte + 1 | 5 => { + // trailing whitespace is so common that it's worth it to check the last byte to + // possibly return a better error message + if let Some(b) = input.last() { + if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, *b)); + } + } + + return Err(DecodeError::InvalidLength); + } + // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes + // written by the fast decode loop. So, we have to ignore both these 2 bytes and the + // previous chunk. + 2 => INPUT_CHUNK_LEN + 2, + // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this + // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail + // with an error, not panic from going past the bounds of the output slice, so we let it + // use stage 3 + 4. + 3 => INPUT_CHUNK_LEN + 3, + // This can also decode to one output byte because it may be 2 input chars + 2 padding + // chars, which would decode to 1 byte. + 4 => INPUT_CHUNK_LEN + 4, + // Everything else is a legal decode len (given that we don't require padding), and will + // decode to at least 2 bytes of output. + _ => remainder_len, + }; + + // rounded up to include partial chunks + let mut remaining_chunks = estimate.num_chunks; + + let mut input_index = 0; + let mut output_index = 0; + + { + let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); + + // Fast loop, stage 1 + // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks + if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { + while input_index <= max_start_index { + let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; + let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; + + decode_chunk( + &input_slice[0..], + input_index, + decode_table, + &mut output_slice[0..], + )?; + decode_chunk( + &input_slice[8..], + input_index + 8, + decode_table, + &mut output_slice[6..], + )?; + decode_chunk( + &input_slice[16..], + input_index + 16, + decode_table, + &mut output_slice[12..], + )?; + decode_chunk( + &input_slice[24..], + input_index + 24, + decode_table, + &mut output_slice[18..], + )?; + + input_index += INPUT_BLOCK_LEN; + output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; + remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; + } + } + + // Fast loop, stage 2 (aka still pretty fast loop) + // 8 bytes at a time for whatever we didn't do in stage 1. + if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { + while input_index < max_start_index { + decode_chunk( + &input[input_index..(input_index + INPUT_CHUNK_LEN)], + input_index, + decode_table, + &mut output + [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], + )?; + + output_index += DECODED_CHUNK_LEN; + input_index += INPUT_CHUNK_LEN; + remaining_chunks -= 1; + } + } + } + + // Stage 3 + // If input length was such that a chunk had to be deferred until after the fast loop + // because decoding it would have produced 2 trailing bytes that wouldn't then be + // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 + // trailing bytes. + // However, we still need to avoid the last chunk (partial or complete) because it could + // have padding, so we always do 1 fewer to avoid the last chunk. + for _ in 1..remaining_chunks { + decode_chunk_precise( + &input[input_index..], + input_index, + decode_table, + &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + )?; + + input_index += INPUT_CHUNK_LEN; + output_index += DECODED_CHUNK_LEN; + } + + // always have one more (possibly partial) block of 8 input + debug_assert!(input.len() - input_index > 1 || input.is_empty()); + debug_assert!(input.len() - input_index <= 8); + + super::decode_suffix::decode_suffix( + input, + input_index, + output, + output_index, + decode_table, + decode_allow_trailing_bits, + padding_mode, + ) +} + +/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the +/// first 6 of those contain meaningful data. +/// +/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors +/// accurately) +/// `decode_table` is the lookup table for the particular base64 alphabet. +/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded +/// data. +// yes, really inline (worth 30-50% speedup) +#[inline(always)] +fn decode_chunk( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let morsel = decode_table[input[0] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = (morsel as u64) << 58; + + let morsel = decode_table[input[1] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= (morsel as u64) << 52; + + let morsel = decode_table[input[2] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= (morsel as u64) << 46; + + let morsel = decode_table[input[3] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= (morsel as u64) << 40; + + let morsel = decode_table[input[4] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 4, + input[4], + )); + } + accum |= (morsel as u64) << 34; + + let morsel = decode_table[input[5] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 5, + input[5], + )); + } + accum |= (morsel as u64) << 28; + + let morsel = decode_table[input[6] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 6, + input[6], + )); + } + accum |= (morsel as u64) << 22; + + let morsel = decode_table[input[7] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 7, + input[7], + )); + } + accum |= (morsel as u64) << 16; + + write_u64(output, accum); + + Ok(()) +} + +/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 +/// trailing garbage bytes. +#[inline] +fn decode_chunk_precise( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let mut tmp_buf = [0_u8; 8]; + + decode_chunk( + input, + index_at_start_of_input, + decode_table, + &mut tmp_buf[..], + )?; + + output[0..6].copy_from_slice(&tmp_buf[0..6]); + + Ok(()) +} + +#[inline] +fn write_u64(output: &mut [u8], value: u64) { + output[..8].copy_from_slice(&value.to_be_bytes()); +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::engine::general_purpose::STANDARD; + + #[test] + fn decode_chunk_precise_writes_only_6_bytes() { + let input = b"Zm9vYmFy"; // "foobar" + let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + + decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); + } + + #[test] + fn decode_chunk_writes_8_bytes() { + let input = b"Zm9vYmFy"; // "foobar" + let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + + decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + } +} diff --git a/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs b/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs new file mode 100644 index 0000000000..5652035d0e --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs @@ -0,0 +1,161 @@ +use crate::{ + engine::{general_purpose::INVALID_VALUE, DecodePaddingMode}, + DecodeError, PAD_BYTE, +}; + +/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// parameters. +/// +/// Returns the total number of bytes decoded, including the ones indicated as already written by +/// `output_index`. +pub(crate) fn decode_suffix( + input: &[u8], + input_index: usize, + output: &mut [u8], + mut output_index: usize, + decode_table: &[u8; 256], + decode_allow_trailing_bits: bool, + padding_mode: DecodePaddingMode, +) -> Result<usize, DecodeError> { + // Decode any leftovers that aren't a complete input block of 8 bytes. + // Use a u64 as a stack-resident 8 byte buffer. + let mut leftover_bits: u64 = 0; + let mut morsels_in_leftover = 0; + let mut padding_bytes = 0; + let mut first_padding_index: usize = 0; + let mut last_symbol = 0_u8; + let start_of_leftovers = input_index; + + for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + // '=' padding + if b == PAD_BYTE { + // There can be bad padding bytes in a few ways: + // 1 - Padding with non-padding characters after it + // 2 - Padding after zero or one characters in the current quad (should only + // be after 2 or 3 chars) + // 3 - More than two characters of padding. If 3 or 4 padding chars + // are in the same quad, that implies it will be caught by #2. + // If it spreads from one quad to another, it will be an invalid byte + // in the first quad. + // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. + // Per config, non-canonical but still functional non- or partially-padded base64 + // may be treated as an error condition. + + if i % 4 < 2 { + // Check for case #2. + let bad_padding_index = start_of_leftovers + + if padding_bytes > 0 { + // If we've already seen padding, report the first padding index. + // This is to be consistent with the normal decode logic: it will report an + // error on the first padding character (since it doesn't expect to see + // anything but actual encoded data). + // This could only happen if the padding started in the previous quad since + // otherwise this case would have been hit at i % 4 == 0 if it was the same + // quad. + first_padding_index + } else { + // haven't seen padding before, just use where we are now + i + }; + return Err(DecodeError::InvalidByte(bad_padding_index, b)); + } + + if padding_bytes == 0 { + first_padding_index = i; + } + + padding_bytes += 1; + continue; + } + + // Check for case #1. + // To make '=' handling consistent with the main loop, don't allow + // non-suffix '=' in trailing chunk either. Report error as first + // erroneous padding. + if padding_bytes > 0 { + return Err(DecodeError::InvalidByte( + start_of_leftovers + first_padding_index, + PAD_BYTE, + )); + } + + last_symbol = b; + + // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. + // Pack the leftovers from left to right. + let shift = 64 - (morsels_in_leftover + 1) * 6; + let morsel = decode_table[b as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + } + + leftover_bits |= (morsel as u64) << shift; + morsels_in_leftover += 1; + } + + match padding_mode { + DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } + DecodePaddingMode::RequireCanonical => { + if (padding_bytes + morsels_in_leftover) % 4 != 0 { + return Err(DecodeError::InvalidPadding); + } + } + DecodePaddingMode::RequireNone => { + if padding_bytes > 0 { + // check at the end to make sure we let the cases of padding that should be InvalidByte + // get hit + return Err(DecodeError::InvalidPadding); + } + } + } + + // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. + // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits + // of bottom 6 bits set). + // When decoding two symbols back to one trailing byte, any final symbol higher than + // w would still decode to the original byte because we only care about the top two + // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a + // mask based on how many bits are used for just the canonical encoding, and optionally + // error if any other bits are set. In the example of one encoded byte -> 2 symbols, + // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and + // useless since there are no more symbols to provide the necessary 4 additional bits + // to finish the second original byte. + + let leftover_bits_ready_to_append = match morsels_in_leftover { + 0 => 0, + 2 => 8, + 3 => 16, + 4 => 24, + 6 => 32, + 7 => 40, + 8 => 48, + // can also be detected as case #2 bad padding above + _ => unreachable!( + "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" + ), + }; + + // if there are bits set outside the bits we care about, last symbol encodes trailing bits that + // will not be included in the output + let mask = !0 >> leftover_bits_ready_to_append; + if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { + // last morsel is at `morsels_in_leftover` - 1 + return Err(DecodeError::InvalidLastSymbol( + start_of_leftovers + morsels_in_leftover - 1, + last_symbol, + )); + } + + // TODO benchmark simply converting to big endian bytes + let mut leftover_bits_appended_to_buf = 0; + while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { + // `as` simply truncates the higher bits, which is what we want here + let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; + output[output_index] = selected_bits; + output_index += 1; + + leftover_bits_appended_to_buf += 8; + } + + Ok(output_index) +} diff --git a/third_party/rust/base64/src/engine/general_purpose/mod.rs b/third_party/rust/base64/src/engine/general_purpose/mod.rs new file mode 100644 index 0000000000..af8897bc2b --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/mod.rs @@ -0,0 +1,349 @@ +//! Provides the [GeneralPurpose] engine and associated config types. +use crate::{ + alphabet, + alphabet::Alphabet, + engine::{Config, DecodePaddingMode}, + DecodeError, +}; +use core::convert::TryInto; + +mod decode; +pub(crate) mod decode_suffix; +pub use decode::GeneralPurposeEstimate; + +pub(crate) const INVALID_VALUE: u8 = 255; + +/// A general-purpose base64 engine. +/// +/// - It uses no vector CPU instructions, so it will work on any system. +/// - It is reasonably fast (~2-3GiB/s). +/// - It is not constant-time, though, so it is vulnerable to timing side-channel attacks. For loading cryptographic keys, etc, it is suggested to use the forthcoming constant-time implementation. +pub struct GeneralPurpose { + encode_table: [u8; 64], + decode_table: [u8; 256], + config: GeneralPurposeConfig, +} + +impl GeneralPurpose { + /// Create a `GeneralPurpose` engine from an [Alphabet]. + /// + /// While not very expensive to initialize, ideally these should be cached + /// if the engine will be used repeatedly. + pub const fn new(alphabet: &Alphabet, config: GeneralPurposeConfig) -> Self { + Self { + encode_table: encode_table(alphabet), + decode_table: decode_table(alphabet), + config, + } + } +} + +impl super::Engine for GeneralPurpose { + type Config = GeneralPurposeConfig; + type DecodeEstimate = GeneralPurposeEstimate; + + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { + let mut input_index: usize = 0; + + const BLOCKS_PER_FAST_LOOP: usize = 4; + const LOW_SIX_BITS: u64 = 0x3F; + + // we read 8 bytes at a time (u64) but only actually consume 6 of those bytes. Thus, we need + // 2 trailing bytes to be available to read.. + let last_fast_index = input.len().saturating_sub(BLOCKS_PER_FAST_LOOP * 6 + 2); + let mut output_index = 0; + + if last_fast_index > 0 { + while input_index <= last_fast_index { + // Major performance wins from letting the optimizer do the bounds check once, mostly + // on the output side + let input_chunk = + &input[input_index..(input_index + (BLOCKS_PER_FAST_LOOP * 6 + 2))]; + let output_chunk = + &mut output[output_index..(output_index + BLOCKS_PER_FAST_LOOP * 8)]; + + // Hand-unrolling for 32 vs 16 or 8 bytes produces yields performance about equivalent + // to unsafe pointer code on a Xeon E5-1650v3. 64 byte unrolling was slightly better for + // large inputs but significantly worse for 50-byte input, unsurprisingly. I suspect + // that it's a not uncommon use case to encode smallish chunks of data (e.g. a 64-byte + // SHA-512 digest), so it would be nice if that fit in the unrolled loop at least once. + // Plus, single-digit percentage performance differences might well be quite different + // on different hardware. + + let input_u64 = read_u64(&input_chunk[0..]); + + output_chunk[0] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[1] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[2] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[3] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[4] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[5] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[6] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[7] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[6..]); + + output_chunk[8] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[9] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[10] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[11] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[12] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[13] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[14] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[15] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[12..]); + + output_chunk[16] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[17] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[18] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[19] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[20] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[21] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[22] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[23] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[18..]); + + output_chunk[24] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[25] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[26] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[27] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[28] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[29] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[30] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[31] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + output_index += BLOCKS_PER_FAST_LOOP * 8; + input_index += BLOCKS_PER_FAST_LOOP * 6; + } + } + + // Encode what's left after the fast loop. + + const LOW_SIX_BITS_U8: u8 = 0x3F; + + let rem = input.len() % 3; + let start_of_rem = input.len() - rem; + + // start at the first index not handled by fast loop, which may be 0. + + while input_index < start_of_rem { + let input_chunk = &input[input_index..(input_index + 3)]; + let output_chunk = &mut output[output_index..(output_index + 4)]; + + output_chunk[0] = self.encode_table[(input_chunk[0] >> 2) as usize]; + output_chunk[1] = self.encode_table + [((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize]; + output_chunk[2] = self.encode_table + [((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize]; + output_chunk[3] = self.encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize]; + + input_index += 3; + output_index += 4; + } + + if rem == 2 { + output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize]; + output[output_index + 1] = + self.encode_table[((input[start_of_rem] << 4 | input[start_of_rem + 1] >> 4) + & LOW_SIX_BITS_U8) as usize]; + output[output_index + 2] = + self.encode_table[((input[start_of_rem + 1] << 2) & LOW_SIX_BITS_U8) as usize]; + output_index += 3; + } else if rem == 1 { + output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize]; + output[output_index + 1] = + self.encode_table[((input[start_of_rem] << 4) & LOW_SIX_BITS_U8) as usize]; + output_index += 2; + } + + output_index + } + + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { + GeneralPurposeEstimate::new(input_len) + } + + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError> { + decode::decode_helper( + input, + estimate, + output, + &self.decode_table, + self.config.decode_allow_trailing_bits, + self.config.decode_padding_mode, + ) + } + + fn config(&self) -> &Self::Config { + &self.config + } +} + +/// Returns a table mapping a 6-bit index to the ASCII byte encoding of the index +pub(crate) const fn encode_table(alphabet: &Alphabet) -> [u8; 64] { + // the encode table is just the alphabet: + // 6-bit index lookup -> printable byte + let mut encode_table = [0_u8; 64]; + { + let mut index = 0; + while index < 64 { + encode_table[index] = alphabet.symbols[index]; + index += 1; + } + } + + encode_table +} + +/// Returns a table mapping base64 bytes as the lookup index to either: +/// - [INVALID_VALUE] for bytes that aren't members of the alphabet +/// - a byte whose lower 6 bits are the value that was encoded into the index byte +pub(crate) const fn decode_table(alphabet: &Alphabet) -> [u8; 256] { + let mut decode_table = [INVALID_VALUE; 256]; + + // Since the table is full of `INVALID_VALUE` already, we only need to overwrite + // the parts that are valid. + let mut index = 0; + while index < 64 { + // The index in the alphabet is the 6-bit value we care about. + // Since the index is in 0-63, it is safe to cast to u8. + decode_table[alphabet.symbols[index] as usize] = index as u8; + index += 1; + } + + decode_table +} + +#[inline] +fn read_u64(s: &[u8]) -> u64 { + u64::from_be_bytes(s[..8].try_into().unwrap()) +} + +/// Contains configuration parameters for base64 encoding and decoding. +/// +/// ``` +/// # use base64::engine::GeneralPurposeConfig; +/// let config = GeneralPurposeConfig::new() +/// .with_encode_padding(false); +/// // further customize using `.with_*` methods as needed +/// ``` +/// +/// The constants [PAD] and [NO_PAD] cover most use cases. +/// +/// To specify the characters used, see [Alphabet]. +#[derive(Clone, Copy, Debug)] +pub struct GeneralPurposeConfig { + encode_padding: bool, + decode_allow_trailing_bits: bool, + decode_padding_mode: DecodePaddingMode, +} + +impl GeneralPurposeConfig { + /// Create a new config with `padding` = `true`, `decode_allow_trailing_bits` = `false`, and + /// `decode_padding_mode = DecodePaddingMode::RequireCanonicalPadding`. + /// + /// This probably matches most people's expectations, but consider disabling padding to save + /// a few bytes unless you specifically need it for compatibility with some legacy system. + pub const fn new() -> Self { + Self { + // RFC states that padding must be applied by default + encode_padding: true, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + } + } + + /// Create a new config based on `self` with an updated `padding` setting. + /// + /// If `padding` is `true`, encoding will append either 1 or 2 `=` padding characters as needed + /// to produce an output whose length is a multiple of 4. + /// + /// Padding is not needed for correct decoding and only serves to waste bytes, but it's in the + /// [spec](https://datatracker.ietf.org/doc/html/rfc4648#section-3.2). + /// + /// For new applications, consider not using padding if the decoders you're using don't require + /// padding to be present. + pub const fn with_encode_padding(self, padding: bool) -> Self { + Self { + encode_padding: padding, + ..self + } + } + + /// Create a new config based on `self` with an updated `decode_allow_trailing_bits` setting. + /// + /// Most users will not need to configure this. It's useful if you need to decode base64 + /// produced by a buggy encoder that has bits set in the unused space on the last base64 + /// character as per [forgiving-base64 decode](https://infra.spec.whatwg.org/#forgiving-base64-decode). + /// If invalid trailing bits are present and this is `true`, those bits will + /// be silently ignored, else `DecodeError::InvalidLastSymbol` will be emitted. + pub const fn with_decode_allow_trailing_bits(self, allow: bool) -> Self { + Self { + decode_allow_trailing_bits: allow, + ..self + } + } + + /// Create a new config based on `self` with an updated `decode_padding_mode` setting. + /// + /// Padding is not useful in terms of representing encoded data -- it makes no difference to + /// the decoder if padding is present or not, so if you have some un-padded input to decode, it + /// is perfectly fine to use `DecodePaddingMode::Indifferent` to prevent errors from being + /// emitted. + /// + /// However, since in practice + /// [people who learned nothing from BER vs DER seem to expect base64 to have one canonical encoding](https://eprint.iacr.org/2022/361), + /// the default setting is the stricter `DecodePaddingMode::RequireCanonicalPadding`. + /// + /// Or, if "canonical" in your circumstance means _no_ padding rather than padding to the + /// next multiple of four, there's `DecodePaddingMode::RequireNoPadding`. + pub const fn with_decode_padding_mode(self, mode: DecodePaddingMode) -> Self { + Self { + decode_padding_mode: mode, + ..self + } + } +} + +impl Default for GeneralPurposeConfig { + /// Delegates to [GeneralPurposeConfig::new]. + fn default() -> Self { + Self::new() + } +} + +impl Config for GeneralPurposeConfig { + fn encode_padding(&self) -> bool { + self.encode_padding + } +} + +/// A [GeneralPurpose] engine using the [alphabet::STANDARD] base64 alphabet and [PAD] config. +pub const STANDARD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, PAD); + +/// A [GeneralPurpose] engine using the [alphabet::STANDARD] base64 alphabet and [NO_PAD] config. +pub const STANDARD_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD); + +/// A [GeneralPurpose] engine using the [alphabet::URL_SAFE] base64 alphabet and [PAD] config. +pub const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, PAD); + +/// A [GeneralPurpose] engine using the [alphabet::URL_SAFE] base64 alphabet and [NO_PAD] config. +pub const URL_SAFE_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, NO_PAD); + +/// Include padding bytes when encoding, and require that they be present when decoding. +/// +/// This is the standard per the base64 RFC, but consider using [NO_PAD] instead as padding serves +/// little purpose in practice. +pub const PAD: GeneralPurposeConfig = GeneralPurposeConfig::new(); + +/// Don't add padding when encoding, and require no padding when decoding. +pub const NO_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new() + .with_encode_padding(false) + .with_decode_padding_mode(DecodePaddingMode::RequireNone); diff --git a/third_party/rust/base64/src/engine/mod.rs b/third_party/rust/base64/src/engine/mod.rs new file mode 100644 index 0000000000..12dfaa8845 --- /dev/null +++ b/third_party/rust/base64/src/engine/mod.rs @@ -0,0 +1,410 @@ +//! Provides the [Engine] abstraction and out of the box implementations. +#[cfg(any(feature = "alloc", feature = "std", test))] +use crate::chunked_encoder; +use crate::{ + encode::{encode_with_padding, EncodeSliceError}, + encoded_len, DecodeError, DecodeSliceError, +}; +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::vec::Vec; + +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::{string::String, vec}; + +pub mod general_purpose; + +#[cfg(test)] +mod naive; + +#[cfg(test)] +mod tests; + +pub use general_purpose::{GeneralPurpose, GeneralPurposeConfig}; + +/// An `Engine` provides low-level encoding and decoding operations that all other higher-level parts of the API use. Users of the library will generally not need to implement this. +/// +/// Different implementations offer different characteristics. The library currently ships with +/// [GeneralPurpose] that offers good speed and works on any CPU, with more choices +/// coming later, like a constant-time one when side channel resistance is called for, and vendor-specific vectorized ones for more speed. +/// +/// See [general_purpose::STANDARD_NO_PAD] if you just want standard base64. Otherwise, when possible, it's +/// recommended to store the engine in a `const` so that references to it won't pose any lifetime +/// issues, and to avoid repeating the cost of engine setup. +/// +/// Since almost nobody will need to implement `Engine`, docs for internal methods are hidden. +// When adding an implementation of Engine, include them in the engine test suite: +// - add an implementation of [engine::tests::EngineWrapper] +// - add the implementation to the `all_engines` macro +// All tests run on all engines listed in the macro. +pub trait Engine: Send + Sync { + /// The config type used by this engine + type Config: Config; + /// The decode estimate used by this engine + type DecodeEstimate: DecodeEstimate; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// See the other `encode*` functions on this trait. + /// + /// Encode the `input` bytes into the `output` buffer based on the mapping in `encode_table`. + /// + /// `output` will be long enough to hold the encoded data. + /// + /// Returns the number of bytes written. + /// + /// No padding should be written; that is handled separately. + /// + /// Must not write any bytes into the output slice other than the encoded data. + #[doc(hidden)] + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// + /// As an optimization to prevent the decoded length from being calculated twice, it is + /// sometimes helpful to have a conservative estimate of the decoded size before doing the + /// decoding, so this calculation is done separately and passed to [Engine::decode()] as needed. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + #[doc(hidden)] + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// See the other `decode*` functions on this trait. + /// + /// Decode `input` base64 bytes into the `output` buffer. + /// + /// `decode_estimate` is the result of [Engine::internal_decoded_len_estimate()], which is passed in to avoid + /// calculating it again (expensive on short inputs).` + /// + /// Returns the number of bytes written to `output`. + /// + /// Each complete 4-byte chunk of encoded data decodes to 3 bytes of decoded data, but this + /// function must also handle the final possibly partial chunk. + /// If the input length is not a multiple of 4, or uses padding bytes to reach a multiple of 4, + /// the trailing 2 or 3 bytes must decode to 1 or 2 bytes, respectively, as per the + /// [RFC](https://tools.ietf.org/html/rfc4648#section-3.5). + /// + /// Decoding must not write any bytes into the output slice other than the decoded data. + /// + /// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as + /// errors unless the engine is configured otherwise. + /// + /// # Panics + /// + /// Panics if `output` is too small. + #[doc(hidden)] + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + decode_estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError>; + + /// Returns the config for this engine. + fn config(&self) -> &Self::Config; + + /// Encode arbitrary octets as base64 using the provided `Engine`. + /// Returns a `String`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::{self, general_purpose}, alphabet}; + /// + /// let b64 = general_purpose::STANDARD.encode(b"hello world~"); + /// println!("{}", b64); + /// + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); + /// + /// let b64_url = CUSTOM_ENGINE.encode(b"hello internet~"); + #[cfg(any(feature = "alloc", feature = "std", test))] + fn encode<T: AsRef<[u8]>>(&self, input: T) -> String { + let encoded_size = encoded_len(input.as_ref().len(), self.config().encode_padding()) + .expect("integer overflow when calculating buffer size"); + let mut buf = vec![0; encoded_size]; + + encode_with_padding(input.as_ref(), &mut buf[..], self, encoded_size); + + String::from_utf8(buf).expect("Invalid UTF8") + } + + /// Encode arbitrary octets as base64 into a supplied `String`. + /// Writes into the supplied `String`, which may allocate if its internal buffer isn't big enough. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::{self, general_purpose}, alphabet}; + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); + /// + /// fn main() { + /// let mut buf = String::new(); + /// general_purpose::STANDARD.encode_string(b"hello world~", &mut buf); + /// println!("{}", buf); + /// + /// buf.clear(); + /// CUSTOM_ENGINE.encode_string(b"hello internet~", &mut buf); + /// println!("{}", buf); + /// } + /// ``` + #[cfg(any(feature = "alloc", feature = "std", test))] + fn encode_string<T: AsRef<[u8]>>(&self, input: T, output_buf: &mut String) { + let input_bytes = input.as_ref(); + + { + let mut sink = chunked_encoder::StringSink::new(output_buf); + + chunked_encoder::ChunkedEncoder::new(self) + .encode(input_bytes, &mut sink) + .expect("Writing to a String shouldn't fail"); + } + } + + /// Encode arbitrary octets as base64 into a supplied slice. + /// Writes into the supplied output buffer. + /// + /// This is useful if you wish to avoid allocation entirely (e.g. encoding into a stack-resident + /// or statically-allocated buffer). + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::general_purpose}; + /// let s = b"hello internet!"; + /// let mut buf = Vec::new(); + /// // make sure we'll have a slice big enough for base64 + padding + /// buf.resize(s.len() * 4 / 3 + 4, 0); + /// + /// let bytes_written = general_purpose::STANDARD.encode_slice(s, &mut buf).unwrap(); + /// + /// // shorten our vec down to just what was written + /// buf.truncate(bytes_written); + /// + /// assert_eq!(s, general_purpose::STANDARD.decode(&buf).unwrap().as_slice()); + /// ``` + fn encode_slice<T: AsRef<[u8]>>( + &self, + input: T, + output_buf: &mut [u8], + ) -> Result<usize, EncodeSliceError> { + let input_bytes = input.as_ref(); + + let encoded_size = encoded_len(input_bytes.len(), self.config().encode_padding()) + .expect("usize overflow when calculating buffer size"); + + if output_buf.len() < encoded_size { + return Err(EncodeSliceError::OutputSliceTooSmall); + } + + let b64_output = &mut output_buf[0..encoded_size]; + + encode_with_padding(input_bytes, b64_output, self, encoded_size); + + Ok(encoded_size) + } + + /// Decode from string reference as octets using the specified [Engine]. + /// Returns a `Result` containing a `Vec<u8>`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, alphabet, engine::{self, general_purpose}}; + /// + /// let bytes = general_purpose::STANDARD + /// .decode("aGVsbG8gd29ybGR+Cg==").unwrap(); + /// println!("{:?}", bytes); + /// + /// // custom engine setup + /// let bytes_url = engine::GeneralPurpose::new( + /// &alphabet::URL_SAFE, + /// general_purpose::NO_PAD) + /// .decode("aGVsbG8gaW50ZXJuZXR-Cg").unwrap(); + /// println!("{:?}", bytes_url); + /// ``` + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + #[cfg(any(feature = "alloc", feature = "std", test))] + fn decode<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, DecodeError> { + let input_bytes = input.as_ref(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + let mut buffer = vec![0; estimate.decoded_len_estimate()]; + + let bytes_written = self.internal_decode(input_bytes, &mut buffer, estimate)?; + buffer.truncate(bytes_written); + + Ok(buffer) + } + + /// Decode from string reference as octets. + /// Writes into the supplied `Vec`, which may allocate if its internal buffer isn't big enough. + /// Returns a `Result` containing an empty tuple, aka `()`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, alphabet, engine::{self, general_purpose}}; + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::PAD); + /// + /// fn main() { + /// use base64::Engine; + /// let mut buffer = Vec::<u8>::new(); + /// // with the default engine + /// general_purpose::STANDARD + /// .decode_vec("aGVsbG8gd29ybGR+Cg==", &mut buffer,).unwrap(); + /// println!("{:?}", buffer); + /// + /// buffer.clear(); + /// + /// // with a custom engine + /// CUSTOM_ENGINE.decode_vec( + /// "aGVsbG8gaW50ZXJuZXR-Cg==", + /// &mut buffer, + /// ).unwrap(); + /// println!("{:?}", buffer); + /// } + /// ``` + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + #[cfg(any(feature = "alloc", feature = "std", test))] + fn decode_vec<T: AsRef<[u8]>>( + &self, + input: T, + buffer: &mut Vec<u8>, + ) -> Result<(), DecodeError> { + let input_bytes = input.as_ref(); + + let starting_output_len = buffer.len(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + let total_len_estimate = estimate + .decoded_len_estimate() + .checked_add(starting_output_len) + .expect("Overflow when calculating output buffer length"); + buffer.resize(total_len_estimate, 0); + + let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; + let bytes_written = self.internal_decode(input_bytes, buffer_slice, estimate)?; + + buffer.truncate(starting_output_len + bytes_written); + + Ok(()) + } + + /// Decode the input into the provided output slice. + /// + /// Returns an error if `output` is smaller than the estimated decoded length. + /// + /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). + /// + /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// + /// See [Engine::decode_slice_unchecked] for a version that panics instead of returning an error + /// if the output buffer is too small. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + fn decode_slice<T: AsRef<[u8]>>( + &self, + input: T, + output: &mut [u8], + ) -> Result<usize, DecodeSliceError> { + let input_bytes = input.as_ref(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + if output.len() < estimate.decoded_len_estimate() { + return Err(DecodeSliceError::OutputSliceTooSmall); + } + + self.internal_decode(input_bytes, output, estimate) + .map_err(|e| e.into()) + } + + /// Decode the input into the provided output slice. + /// + /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). + /// + /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// + /// See [Engine::decode_slice] for a version that returns an error instead of panicking if the output + /// buffer is too small. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + /// + /// Panics if the provided output buffer is too small for the decoded data. + fn decode_slice_unchecked<T: AsRef<[u8]>>( + &self, + input: T, + output: &mut [u8], + ) -> Result<usize, DecodeError> { + let input_bytes = input.as_ref(); + + self.internal_decode( + input_bytes, + output, + self.internal_decoded_len_estimate(input_bytes.len()), + ) + } +} + +/// The minimal level of configuration that engines must support. +pub trait Config { + /// Returns `true` if padding should be added after the encoded output. + /// + /// Padding is added outside the engine's encode() since the engine may be used + /// to encode only a chunk of the overall output, so it can't always know when + /// the output is "done" and would therefore need padding (if configured). + // It could be provided as a separate parameter when encoding, but that feels like + // leaking an implementation detail to the user, and it's hopefully more convenient + // to have to only pass one thing (the engine) to any part of the API. + fn encode_padding(&self) -> bool; +} + +/// The decode estimate used by an engine implementation. Users do not need to interact with this; +/// it is only for engine implementors. +/// +/// Implementors may store relevant data here when constructing this to avoid having to calculate +/// them again during actual decoding. +pub trait DecodeEstimate { + /// Returns a conservative (err on the side of too big) estimate of the decoded length to use + /// for pre-allocating buffers, etc. + /// + /// The estimate must be no larger than the next largest complete triple of decoded bytes. + /// That is, the final quad of tokens to decode may be assumed to be complete with no padding. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + fn decoded_len_estimate(&self) -> usize; +} + +/// Controls how pad bytes are handled when decoding. +/// +/// Each [Engine] must support at least the behavior indicated by +/// [DecodePaddingMode::RequireCanonical], and may support other modes. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DecodePaddingMode { + /// Canonical padding is allowed, but any fewer padding bytes than that is also allowed. + Indifferent, + /// Padding must be canonical (0, 1, or 2 `=` as needed to produce a 4 byte suffix). + RequireCanonical, + /// Padding must be absent -- for when you want predictable padding, without any wasted bytes. + RequireNone, +} diff --git a/third_party/rust/base64/src/engine/naive.rs b/third_party/rust/base64/src/engine/naive.rs new file mode 100644 index 0000000000..6665c5eb41 --- /dev/null +++ b/third_party/rust/base64/src/engine/naive.rs @@ -0,0 +1,219 @@ +use crate::{ + alphabet::Alphabet, + engine::{ + general_purpose::{self, decode_table, encode_table}, + Config, DecodeEstimate, DecodePaddingMode, Engine, + }, + DecodeError, PAD_BYTE, +}; +use alloc::ops::BitOr; +use std::ops::{BitAnd, Shl, Shr}; + +/// Comparatively simple implementation that can be used as something to compare against in tests +pub struct Naive { + encode_table: [u8; 64], + decode_table: [u8; 256], + config: NaiveConfig, +} + +impl Naive { + const ENCODE_INPUT_CHUNK_SIZE: usize = 3; + const DECODE_INPUT_CHUNK_SIZE: usize = 4; + + pub const fn new(alphabet: &Alphabet, config: NaiveConfig) -> Self { + Self { + encode_table: encode_table(alphabet), + decode_table: decode_table(alphabet), + config, + } + } + + fn decode_byte_into_u32(&self, offset: usize, byte: u8) -> Result<u32, DecodeError> { + let decoded = self.decode_table[byte as usize]; + + if decoded == general_purpose::INVALID_VALUE { + return Err(DecodeError::InvalidByte(offset, byte)); + } + + Ok(decoded as u32) + } +} + +impl Engine for Naive { + type Config = NaiveConfig; + type DecodeEstimate = NaiveEstimate; + + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { + // complete chunks first + + const LOW_SIX_BITS: u32 = 0x3F; + + let rem = input.len() % Self::ENCODE_INPUT_CHUNK_SIZE; + // will never underflow + let complete_chunk_len = input.len() - rem; + + let mut input_index = 0_usize; + let mut output_index = 0_usize; + if let Some(last_complete_chunk_index) = + complete_chunk_len.checked_sub(Self::ENCODE_INPUT_CHUNK_SIZE) + { + while input_index <= last_complete_chunk_index { + let chunk = &input[input_index..input_index + Self::ENCODE_INPUT_CHUNK_SIZE]; + + // populate low 24 bits from 3 bytes + let chunk_int: u32 = + (chunk[0] as u32).shl(16) | (chunk[1] as u32).shl(8) | (chunk[2] as u32); + // encode 4x 6-bit output bytes + output[output_index] = self.encode_table[chunk_int.shr(18) as usize]; + output[output_index + 1] = + self.encode_table[chunk_int.shr(12_u8).bitand(LOW_SIX_BITS) as usize]; + output[output_index + 2] = + self.encode_table[chunk_int.shr(6_u8).bitand(LOW_SIX_BITS) as usize]; + output[output_index + 3] = + self.encode_table[chunk_int.bitand(LOW_SIX_BITS) as usize]; + + input_index += Self::ENCODE_INPUT_CHUNK_SIZE; + output_index += 4; + } + } + + // then leftovers + if rem == 2 { + let chunk = &input[input_index..input_index + 2]; + + // high six bits of chunk[0] + output[output_index] = self.encode_table[chunk[0].shr(2) as usize]; + // bottom 2 bits of [0], high 4 bits of [1] + output[output_index + 1] = + self.encode_table[(chunk[0].shl(4_u8).bitor(chunk[1].shr(4_u8)) as u32) + .bitand(LOW_SIX_BITS) as usize]; + // bottom 4 bits of [1], with the 2 bottom bits as zero + output[output_index + 2] = + self.encode_table[(chunk[1].shl(2_u8) as u32).bitand(LOW_SIX_BITS) as usize]; + + output_index += 3; + } else if rem == 1 { + let byte = input[input_index]; + output[output_index] = self.encode_table[byte.shr(2) as usize]; + output[output_index + 1] = + self.encode_table[(byte.shl(4_u8) as u32).bitand(LOW_SIX_BITS) as usize]; + output_index += 2; + } + + output_index + } + + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { + NaiveEstimate::new(input_len) + } + + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError> { + if estimate.rem == 1 { + // trailing whitespace is so common that it's worth it to check the last byte to + // possibly return a better error message + if let Some(b) = input.last() { + if *b != PAD_BYTE + && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE + { + return Err(DecodeError::InvalidByte(input.len() - 1, *b)); + } + } + + return Err(DecodeError::InvalidLength); + } + + let mut input_index = 0_usize; + let mut output_index = 0_usize; + const BOTTOM_BYTE: u32 = 0xFF; + + // can only use the main loop on non-trailing chunks + if input.len() > Self::DECODE_INPUT_CHUNK_SIZE { + // skip the last chunk, whether it's partial or full, since it might + // have padding, and start at the beginning of the chunk before that + let last_complete_chunk_start_index = estimate.complete_chunk_len + - if estimate.rem == 0 { + // Trailing chunk is also full chunk, so there must be at least 2 chunks, and + // this won't underflow + Self::DECODE_INPUT_CHUNK_SIZE * 2 + } else { + // Trailing chunk is partial, so it's already excluded in + // complete_chunk_len + Self::DECODE_INPUT_CHUNK_SIZE + }; + + while input_index <= last_complete_chunk_start_index { + let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; + let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) + | self + .decode_byte_into_u32(input_index + 1, chunk[1])? + .shl(12) + | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) + | self.decode_byte_into_u32(input_index + 3, chunk[3])?; + + output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; + + input_index += Self::DECODE_INPUT_CHUNK_SIZE; + output_index += 3; + } + } + + general_purpose::decode_suffix::decode_suffix( + input, + input_index, + output, + output_index, + &self.decode_table, + self.config.decode_allow_trailing_bits, + self.config.decode_padding_mode, + ) + } + + fn config(&self) -> &Self::Config { + &self.config + } +} + +pub struct NaiveEstimate { + /// remainder from dividing input by `Naive::DECODE_CHUNK_SIZE` + rem: usize, + /// Length of input that is in complete `Naive::DECODE_CHUNK_SIZE`-length chunks + complete_chunk_len: usize, +} + +impl NaiveEstimate { + fn new(input_len: usize) -> Self { + let rem = input_len % Naive::DECODE_INPUT_CHUNK_SIZE; + let complete_chunk_len = input_len - rem; + + Self { + rem, + complete_chunk_len, + } + } +} + +impl DecodeEstimate for NaiveEstimate { + fn decoded_len_estimate(&self) -> usize { + ((self.complete_chunk_len / 4) + ((self.rem > 0) as usize)) * 3 + } +} + +#[derive(Clone, Copy, Debug)] +pub struct NaiveConfig { + pub encode_padding: bool, + pub decode_allow_trailing_bits: bool, + pub decode_padding_mode: DecodePaddingMode, +} + +impl Config for NaiveConfig { + fn encode_padding(&self) -> bool { + self.encode_padding + } +} diff --git a/third_party/rust/base64/src/engine/tests.rs b/third_party/rust/base64/src/engine/tests.rs new file mode 100644 index 0000000000..906bba04d8 --- /dev/null +++ b/third_party/rust/base64/src/engine/tests.rs @@ -0,0 +1,1430 @@ +// rstest_reuse template functions have unused variables +#![allow(unused_variables)] + +use rand::{ + self, + distributions::{self, Distribution as _}, + rngs, Rng as _, SeedableRng as _, +}; +use rstest::rstest; +use rstest_reuse::{apply, template}; +use std::{collections, fmt}; + +use crate::{ + alphabet::{Alphabet, STANDARD}, + encode::add_padding, + encoded_len, + engine::{general_purpose, naive, Config, DecodeEstimate, DecodePaddingMode, Engine}, + tests::{assert_encode_sanity, random_alphabet, random_config}, + DecodeError, PAD_BYTE, +}; + +// the case::foo syntax includes the "foo" in the generated test method names +#[template] +#[rstest(engine_wrapper, +case::general_purpose(GeneralPurposeWrapper {}), +case::naive(NaiveWrapper {}), +)] +fn all_engines<E: EngineWrapper>(engine_wrapper: E) {} + +#[apply(all_engines)] +fn rfc_test_vectors_std_alphabet<E: EngineWrapper>(engine_wrapper: E) { + let data = vec![ + ("", ""), + ("f", "Zg=="), + ("fo", "Zm8="), + ("foo", "Zm9v"), + ("foob", "Zm9vYg=="), + ("fooba", "Zm9vYmE="), + ("foobar", "Zm9vYmFy"), + ]; + + let engine = E::standard(); + let engine_no_padding = E::standard_unpadded(); + + for (orig, encoded) in &data { + let encoded_without_padding = encoded.trim_end_matches('='); + + // unpadded + { + let mut encode_buf = [0_u8; 8]; + let mut decode_buf = [0_u8; 6]; + + let encode_len = + engine_no_padding.internal_encode(orig.as_bytes(), &mut encode_buf[..]); + assert_eq!( + &encoded_without_padding, + &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap() + ); + let decode_len = engine_no_padding + .decode_slice_unchecked(encoded_without_padding.as_bytes(), &mut decode_buf[..]) + .unwrap(); + assert_eq!(orig.len(), decode_len); + + assert_eq!( + orig, + &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap() + ); + + // if there was any padding originally, the no padding engine won't decode it + if encoded.as_bytes().contains(&PAD_BYTE) { + assert_eq!( + Err(DecodeError::InvalidPadding), + engine_no_padding.decode(encoded) + ) + } + } + + // padded + { + let mut encode_buf = [0_u8; 8]; + let mut decode_buf = [0_u8; 6]; + + let encode_len = engine.internal_encode(orig.as_bytes(), &mut encode_buf[..]); + assert_eq!( + // doesn't have padding added yet + &encoded_without_padding, + &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap() + ); + let pad_len = add_padding(orig.len(), &mut encode_buf[encode_len..]); + assert_eq!(encoded.as_bytes(), &encode_buf[..encode_len + pad_len]); + + let decode_len = engine + .decode_slice_unchecked(encoded.as_bytes(), &mut decode_buf[..]) + .unwrap(); + assert_eq!(orig.len(), decode_len); + + assert_eq!( + orig, + &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap() + ); + + // if there was (canonical) padding, and we remove it, the standard engine won't decode + if encoded.as_bytes().contains(&PAD_BYTE) { + assert_eq!( + Err(DecodeError::InvalidPadding), + engine.decode(encoded_without_padding) + ) + } + } + } +} + +#[apply(all_engines)] +fn roundtrip_random<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + + let (orig_len, _, encoded_len) = generate_random_encoded_data( + &engine, + &mut orig_data, + &mut encode_buf, + &mut rng, + &len_range, + ); + + // exactly the right size + decode_buf.resize(orig_len, 0); + + let dec_len = engine + .decode_slice_unchecked(&encode_buf[0..encoded_len], &mut decode_buf[..]) + .unwrap(); + + assert_eq!(orig_len, dec_len); + assert_eq!(&orig_data[..], &decode_buf[..dec_len]); + } +} + +#[apply(all_engines)] +fn encode_doesnt_write_extra_bytes<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut encode_buf_backup = Vec::<u8>::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + let padded = engine.config().encode_padding(); + + orig_data.clear(); + encode_buf.clear(); + encode_buf_backup.clear(); + + let orig_len = fill_rand(&mut orig_data, &mut rng, &input_len_range); + + let prefix_len = 1024; + // plenty of prefix and suffix + fill_rand_len(&mut encode_buf, &mut rng, prefix_len * 2 + orig_len * 2); + encode_buf_backup.extend_from_slice(&encode_buf[..]); + + let expected_encode_len_no_pad = encoded_len(orig_len, false).unwrap(); + + let encoded_len_no_pad = + engine.internal_encode(&orig_data[..], &mut encode_buf[prefix_len..]); + assert_eq!(expected_encode_len_no_pad, encoded_len_no_pad); + + // no writes past what it claimed to write + assert_eq!(&encode_buf_backup[..prefix_len], &encode_buf[..prefix_len]); + assert_eq!( + &encode_buf_backup[(prefix_len + encoded_len_no_pad)..], + &encode_buf[(prefix_len + encoded_len_no_pad)..] + ); + + let encoded_data = &encode_buf[prefix_len..(prefix_len + encoded_len_no_pad)]; + assert_encode_sanity( + std::str::from_utf8(encoded_data).unwrap(), + // engines don't pad + false, + orig_len, + ); + + // pad so we can decode it in case our random engine requires padding + let pad_len = if padded { + add_padding(orig_len, &mut encode_buf[prefix_len + encoded_len_no_pad..]) + } else { + 0 + }; + + assert_eq!( + orig_data, + engine + .decode(&encode_buf[prefix_len..(prefix_len + encoded_len_no_pad + pad_len)],) + .unwrap() + ); + } +} + +#[apply(all_engines)] +fn encode_engine_slice_fits_into_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) { + let mut orig_data = Vec::new(); + let mut encoded_data = Vec::new(); + let mut decoded = Vec::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + + let mut rng = rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decoded.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = E::random(&mut rng); + + let encoded_size = encoded_len(input_len, engine.config().encode_padding()).unwrap(); + + encoded_data.resize(encoded_size, 0); + + assert_eq!( + encoded_size, + engine.encode_slice(&orig_data, &mut encoded_data).unwrap() + ); + + assert_encode_sanity( + std::str::from_utf8(&encoded_data[0..encoded_size]).unwrap(), + engine.config().encode_padding(), + input_len, + ); + + engine + .decode_vec(&encoded_data[0..encoded_size], &mut decoded) + .unwrap(); + assert_eq!(orig_data, decoded); + } +} + +#[apply(all_engines)] +fn decode_doesnt_write_extra_bytes<E>(engine_wrapper: E) +where + E: EngineWrapper, + <<E as EngineWrapper>::Engine as Engine>::Config: fmt::Debug, +{ + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + let mut decode_buf_backup = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + decode_buf_backup.clear(); + + let orig_len = fill_rand(&mut orig_data, &mut rng, &len_range); + encode_buf.resize(orig_len * 2 + 100, 0); + + let encoded_len = engine + .encode_slice(&orig_data[..], &mut encode_buf[..]) + .unwrap(); + encode_buf.truncate(encoded_len); + + // oversize decode buffer so we can easily tell if it writes anything more than + // just the decoded data + let prefix_len = 1024; + // plenty of prefix and suffix + fill_rand_len(&mut decode_buf, &mut rng, prefix_len * 2 + orig_len * 2); + decode_buf_backup.extend_from_slice(&decode_buf[..]); + + let dec_len = engine + .decode_slice_unchecked(&encode_buf, &mut decode_buf[prefix_len..]) + .unwrap(); + + assert_eq!(orig_len, dec_len); + assert_eq!( + &orig_data[..], + &decode_buf[prefix_len..prefix_len + dec_len] + ); + assert_eq!(&decode_buf_backup[..prefix_len], &decode_buf[..prefix_len]); + assert_eq!( + &decode_buf_backup[prefix_len + dec_len..], + &decode_buf[prefix_len + dec_len..] + ); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol<E: EngineWrapper>(engine_wrapper: E) { + // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol + let engine = E::standard(); + + assert_eq!(Ok(vec![0x89, 0x85]), engine.decode("iYU=")); + assert_eq!(Ok(vec![0xFF]), engine.decode("/w==")); + + for (suffix, offset) in vec![ + // suffix, offset of bad byte from start of suffix + ("/x==", 1_usize), + ("/z==", 1_usize), + ("/0==", 1_usize), + ("/9==", 1_usize), + ("/+==", 1_usize), + ("//==", 1_usize), + // trailing 01 + ("iYV=", 2_usize), + // trailing 10 + ("iYW=", 2_usize), + // trailing 11 + ("iYX=", 2_usize), + ] { + for prefix_quads in 0..256 { + let mut encoded = "AAAA".repeat(prefix_quads); + encoded.push_str(suffix); + + assert_eq!( + Err(DecodeError::InvalidLastSymbol( + encoded.len() - 4 + offset, + suffix.as_bytes()[offset], + )), + engine.decode(encoded.as_str()) + ); + } + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_when_length_is_also_invalid<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // check across enough lengths that it would likely cover any implementation's various internal + // small/large input division + for len in (0_usize..256).map(|len| len * 4 + 1) { + let engine = E::random_alphabet(&mut rng, &STANDARD); + + let mut input = vec![b'A'; len]; + + // with a valid last char, it's InvalidLength + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); + // after mangling the last char, it's InvalidByte + input[len - 1] = b'"'; + assert_eq!( + Err(DecodeError::InvalidByte(len - 1, b'"')), + engine.decode(&input) + ); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_every_possible_two_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + let engine = E::standard(); + + let mut base64_to_bytes = collections::HashMap::new(); + + for b in 0_u8..=255 { + let mut b64 = vec![0_u8; 4]; + assert_eq!(2, engine.internal_encode(&[b], &mut b64[..])); + let _ = add_padding(1, &mut b64[2..]); + + assert!(base64_to_bytes.insert(b64, vec![b]).is_none()); + } + + // every possible combination of trailing symbols must either decode to 1 byte or get InvalidLastSymbol, with or without any leading chunks + + let mut prefix = Vec::new(); + for _ in 0..256 { + let mut clone = prefix.clone(); + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.symbols.iter() { + symbols[0] = s1; + for &s2 in STANDARD.symbols.iter() { + symbols[1] = s2; + symbols[2] = PAD_BYTE; + symbols[3] = PAD_BYTE; + + // chop off previous symbols + clone.truncate(prefix.len()); + clone.extend_from_slice(&symbols[..]); + let decoded_prefix_len = prefix.len() / 4 * 3; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => { + let res = engine + .decode(&clone) + // remove prefix + .map(|decoded| decoded[decoded_prefix_len..].to_vec()); + + assert_eq!(Ok(bytes.clone()), res); + } + None => assert_eq!( + Err(DecodeError::InvalidLastSymbol(1, s2)), + engine.decode(&symbols[..]) + ), + } + } + } + + prefix.extend_from_slice(b"AAAA"); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_every_possible_three_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + let engine = E::standard(); + + let mut base64_to_bytes = collections::HashMap::new(); + + let mut bytes = [0_u8; 2]; + for b1 in 0_u8..=255 { + bytes[0] = b1; + for b2 in 0_u8..=255 { + bytes[1] = b2; + let mut b64 = vec![0_u8; 4]; + assert_eq!(3, engine.internal_encode(&bytes, &mut b64[..])); + let _ = add_padding(2, &mut b64[3..]); + + let mut v = Vec::with_capacity(2); + v.extend_from_slice(&bytes[..]); + + assert!(base64_to_bytes.insert(b64, v).is_none()); + } + } + + // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol, with or without any leading chunks + + let mut prefix = Vec::new(); + for _ in 0..256 { + let mut input = prefix.clone(); + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.symbols.iter() { + symbols[0] = s1; + for &s2 in STANDARD.symbols.iter() { + symbols[1] = s2; + for &s3 in STANDARD.symbols.iter() { + symbols[2] = s3; + symbols[3] = PAD_BYTE; + + // chop off previous symbols + input.truncate(prefix.len()); + input.extend_from_slice(&symbols[..]); + let decoded_prefix_len = prefix.len() / 4 * 3; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => { + let res = engine + .decode(&input) + // remove prefix + .map(|decoded| decoded[decoded_prefix_len..].to_vec()); + + assert_eq!(Ok(bytes.clone()), res); + } + None => assert_eq!( + Err(DecodeError::InvalidLastSymbol(2, s3)), + engine.decode(&symbols[..]) + ), + } + } + } + } + prefix.extend_from_slice(b"AAAA"); + } +} + +#[apply(all_engines)] +fn decode_invalid_trailing_bits_ignored_when_configured<E: EngineWrapper>(engine_wrapper: E) { + let strict = E::standard(); + let forgiving = E::standard_allow_trailing_bits(); + + fn assert_tolerant_decode<E: Engine>( + engine: &E, + input: &mut String, + b64_prefix_len: usize, + expected_decode_bytes: Vec<u8>, + data: &str, + ) { + let prefixed = prefixed_data(input, b64_prefix_len, data); + let decoded = engine.decode(prefixed); + // prefix is always complete chunks + let decoded_prefix_len = b64_prefix_len / 4 * 3; + assert_eq!( + Ok(expected_decode_bytes), + decoded.map(|v| v[decoded_prefix_len..].to_vec()) + ); + } + + let mut prefix = String::new(); + for _ in 0..256 { + let mut input = prefix.clone(); + + // example from https://github.com/marshallpierce/rust-base64/issues/75 + assert!(strict + .decode(prefixed_data(&mut input, prefix.len(), "/w==")) + .is_ok()); + assert!(strict + .decode(prefixed_data(&mut input, prefix.len(), "iYU=")) + .is_ok()); + // trailing 01 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/x=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYV="); + // trailing 10 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/y=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYW="); + // trailing 11 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/z=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYX="); + + prefix.push_str("AAAA"); + } +} + +#[apply(all_engines)] +fn decode_invalid_byte_error<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let alphabet = random_alphabet(&mut rng); + let engine = E::random_alphabet(&mut rng, alphabet); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + + let (orig_len, encoded_len_just_data, encoded_len_with_padding) = + generate_random_encoded_data( + &engine, + &mut orig_data, + &mut encode_buf, + &mut rng, + &len_range, + ); + + // exactly the right size + decode_buf.resize(orig_len, 0); + + // replace one encoded byte with an invalid byte + let invalid_byte: u8 = loop { + let byte: u8 = rng.gen(); + + if alphabet.symbols.contains(&byte) { + continue; + } else { + break byte; + } + }; + + let invalid_range = distributions::Uniform::new(0, orig_len); + let invalid_index = invalid_range.sample(&mut rng); + encode_buf[invalid_index] = invalid_byte; + + assert_eq!( + Err(DecodeError::InvalidByte(invalid_index, invalid_byte)), + engine.decode_slice_unchecked( + &encode_buf[0..encoded_len_with_padding], + &mut decode_buf[..], + ) + ); + } +} + +/// Any amount of padding anywhere before the final non padding character = invalid byte at first +/// pad byte. +/// From this, we know padding must extend to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // the different amounts of proper padding, w/ offset from end for the last non-padding char + let suffixes = vec![("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + + let prefix_quads_range = distributions::Uniform::from(0..=256); + + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for _ in 0..100_000 { + for (suffix, offset) in suffixes.iter() { + let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); + + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - offset; + + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); + + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + ); + } + } + } +} + +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = +/// invalid byte at first pad byte (except for 1 byte suffix = invalid length). +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // must have at least one prefix quad + let prefix_quads_range = distributions::Uniform::from(1..256); + // including 1 just to make sure that it really does produce invalid length + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + for _ in 0..100_000 { + let suffix_len = suffix_pad_len_range.sample(&mut rng); + let mut encoded = "ABCD" + .repeat(prefix_quads_range.sample(&mut rng)) + .into_bytes(); + encoded.resize(encoded.len() + suffix_len, PAD_BYTE); + + // amount of padding must be long enough to extend back from suffix into previous + // quads + let padding_len = rng.gen_range(suffix_len + 1..encoded.len()); + // no non-padding after padding in this test, so padding goes to the end + let padding_start = encoded.len() - padding_len; + encoded[padding_start..].fill(PAD_BYTE); + + if suffix_len == 1 { + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); + } else { + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + ); + } + } + } +} + +/// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding +/// is not valid data (consistent with error for pad bytes in earlier chunks). +/// From this we know there must be 2-3 bytes of data before padding +#[apply(all_engines)] +fn decode_too_little_data_before_padding_error_invalid_byte<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + // want to test no prefix quad case, so start at 0 + let prefix_quads_range = distributions::Uniform::from(0_usize..256); + let suffix_data_len_range = distributions::Uniform::from(0_usize..=1); + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + for _ in 0..100_000 { + let suffix_data_len = suffix_data_len_range.sample(&mut rng); + let prefix_quad_len = prefix_quads_range.sample(&mut rng); + + // ensure there is a suffix quad + let min_padding = usize::from(suffix_data_len == 0); + + // for all possible padding lengths + for padding_len in min_padding..=(4 - suffix_data_len) { + let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); + encoded.resize(encoded.len() + suffix_data_len, b'A'); + encoded.resize(encoded.len() + padding_len, PAD_BYTE); + + if suffix_data_len + padding_len == 1 { + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); + } else { + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "suffix data len {} pad len {}", + suffix_data_len, + padding_len + ); + } + } + } + } +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 1 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_valid<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + b"Hello".as_slice(), + &E::standard().decode("SGVsbG8=").unwrap() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 2 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_invalid_trailing_symbol<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + DecodeError::InvalidLastSymbol(6, 0x39), + E::standard().decode("SGVsbG9=").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 3 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbG9").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 4 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_valid_two_padding_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + b"Hell".as_slice(), + &E::standard().decode("SGVsbA==").unwrap() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 5 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_short_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbA=").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 6 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbA").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 7 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_too_much_padding<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + DecodeError::InvalidByte(6, PAD_BYTE), + E::standard().decode("SGVsbA====").unwrap_err() + ); +} + +/// Requires canonical padding -> accepts 2 + 2, 3 + 1, 4 + 0 final quad configurations +#[apply(all_engines)] +fn decode_pad_mode_requires_canonical_accepts_canonical<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical), + vec!["/w==", "iYU=", "AAAA"], + ); +} + +/// Requires canonical padding -> rejects 2 + 0-1, 3 + 0 final chunk configurations +#[apply(all_engines)] +fn decode_pad_mode_requires_canonical_rejects_non_canonical<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical); + + let suffixes = vec!["/w", "/w=", "iYU"]; + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = engine.decode(&encoded); + + assert_eq!(Err(DecodeError::InvalidPadding), res); + } + } +} + +/// Requires no padding -> accepts 2 + 0, 3 + 0, 4 + 0 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_requires_no_padding_accepts_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone), + vec!["/w", "iYU", "AAAA"], + ); +} + +/// Requires no padding -> rejects 2 + 1-2, 3 + 1 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_requires_no_padding_rejects_any_padding<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone); + + let suffixes = vec!["/w=", "/w==", "iYU="]; + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = engine.decode(&encoded); + + assert_eq!(Err(DecodeError::InvalidPadding), res); + } + } +} + +/// Indifferent padding accepts 2 + 0-2, 3 + 0-1, 4 + 0 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_indifferent_padding_accepts_anything<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent), + vec!["/w", "/w=", "/w==", "iYU", "iYU=", "AAAA"], + ); +} + +//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 +#[apply(all_engines)] +fn decode_pad_byte_in_penultimate_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // leave room for at least one pad byte in penultimate quad + for num_valid_bytes_penultimate_quad in 0..4 { + // can't have 1 or it would be invalid length + for num_pad_bytes_in_final_quad in 2..=4 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + // varying amounts of padding in the penultimate quad + for _ in 0..num_valid_bytes_penultimate_quad { + s.push('A'); + } + // finish penultimate quad with padding + for _ in num_valid_bytes_penultimate_quad..4 { + s.push('='); + } + // and more padding in the final quad + for _ in 0..num_pad_bytes_in_final_quad { + s.push('='); + } + + // padding should be an invalid byte before the final quad. + // Could argue that the *next* padding byte (in the next quad) is technically the first + // erroneous one, but reporting that accurately is more complex and probably nobody cares + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, + b'=', + ), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_bytes_after_padding_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // leave at least one byte in the quad for padding + for bytes_after_padding in 1..4 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding + for _ in 0..(3 - bytes_after_padding) { + s.push('A'); + } + s.push('='); + for _ in 0..bytes_after_padding { + s.push('A'); + } + + // First (and only) padding byte is invalid. + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + (3 - bytes_after_padding), + b'=' + ), + engine.decode(&s).unwrap_err() + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_absurd_pad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("==Y=Wx===pY=2U====="); + + // first padding byte + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } +} + +#[apply(all_engines)] +fn decode_too_much_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // add enough padding to ensure that we'll hit all decode stages at the different lengths + for pad_bytes in 1..=64 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + let padding: String = "=".repeat(pad_bytes); + s.push_str(&padding); + + if pad_bytes % 4 == 1 { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } else { + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_padding_followed_by_non_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + for pad_bytes in 0..=32 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + let padding: String = "=".repeat(pad_bytes); + s.push_str(&padding); + s.push('E'); + + if pad_bytes % 4 == 0 { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } else { + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_one_char_in_final_quad_with_padding_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("E="); + + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + + // more padding doesn't change the error + s.push('='); + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + + s.push('='); + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + } + } +} + +#[apply(all_engines)] +fn decode_too_few_symbols_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // <2 is invalid + for final_quad_symbols in 0..2 { + for padding_symbols in 0..=(4 - final_quad_symbols) { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + for _ in 0..final_quad_symbols { + s.push('A'); + } + for _ in 0..padding_symbols { + s.push('='); + } + + match final_quad_symbols + padding_symbols { + 0 => continue, + 1 => { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } + _ => { + // error reported at first padding byte + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + final_quad_symbols, + b'=', + ), + engine.decode(&s).unwrap_err() + ); + } + } + } + } + } + } +} + +#[apply(all_engines)] +fn decode_invalid_trailing_bytes<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg==\n"); + + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), + engine.decode(&s) + ); + + // extra padding, however, is still InvalidLength + let s = s.replace('\n', "="); + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(s)); + } + } +} + +#[apply(all_engines)] +fn decode_wrong_length_error<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); + + for num_prefix_quads in 0..256 { + // at least one token, otherwise it wouldn't be a final quad + for num_tokens_final_quad in 1..=4 { + for num_padding in 0..=(4 - num_tokens_final_quad) { + let mut s: String = "IIII".repeat(num_prefix_quads); + for _ in 0..num_tokens_final_quad { + s.push('g'); + } + for _ in 0..num_padding { + s.push('='); + } + + let res = engine.decode(&s); + if num_tokens_final_quad >= 2 { + assert!(res.is_ok()); + } else if num_tokens_final_quad == 1 && num_padding > 0 { + // = is invalid if it's too early + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + num_tokens_final_quad, + 61 + )), + res + ); + } else if num_padding > 2 { + assert_eq!(Err(DecodeError::InvalidPadding), res); + } else { + assert_eq!(Err(DecodeError::InvalidLength), res); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_into_slice_fits_in_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) { + let mut orig_data = Vec::new(); + let mut encoded_data = String::new(); + let mut decode_buf = Vec::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + let mut rng = rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decode_buf.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = E::random(&mut rng); + engine.encode_string(&orig_data, &mut encoded_data); + assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); + + decode_buf.resize(input_len, 0); + + // decode into the non-empty buf + let decode_bytes_written = engine + .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) + .unwrap(); + + assert_eq!(orig_data.len(), decode_bytes_written); + assert_eq!(orig_data, decode_buf); + } +} + +#[apply(all_engines)] +fn decode_length_estimate_delta<E: EngineWrapper>(engine_wrapper: E) { + for engine in [E::standard(), E::standard_unpadded()] { + for &padding in &[true, false] { + for orig_len in 0..1000 { + let encoded_len = encoded_len(orig_len, padding).unwrap(); + + let decoded_estimate = engine + .internal_decoded_len_estimate(encoded_len) + .decoded_len_estimate(); + assert!(decoded_estimate >= orig_len); + assert!( + decoded_estimate - orig_len < 3, + "estimate: {}, encoded: {}, orig: {}", + decoded_estimate, + encoded_len, + orig_len + ); + } + } + } +} + +/// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. +/// +/// Vecs provided should be empty. +fn generate_random_encoded_data<E: Engine, R: rand::Rng, D: distributions::Distribution<usize>>( + engine: &E, + orig_data: &mut Vec<u8>, + encode_buf: &mut Vec<u8>, + rng: &mut R, + length_distribution: &D, +) -> (usize, usize, usize) { + let padding: bool = engine.config().encode_padding(); + + let orig_len = fill_rand(orig_data, rng, length_distribution); + let expected_encoded_len = encoded_len(orig_len, padding).unwrap(); + encode_buf.resize(expected_encoded_len, 0); + + let base_encoded_len = engine.internal_encode(&orig_data[..], &mut encode_buf[..]); + + let enc_len_with_padding = if padding { + base_encoded_len + add_padding(orig_len, &mut encode_buf[base_encoded_len..]) + } else { + base_encoded_len + }; + + assert_eq!(expected_encoded_len, enc_len_with_padding); + + (orig_len, base_encoded_len, enc_len_with_padding) +} + +// fill to a random length +fn fill_rand<R: rand::Rng, D: distributions::Distribution<usize>>( + vec: &mut Vec<u8>, + rng: &mut R, + length_distribution: &D, +) -> usize { + let len = length_distribution.sample(rng); + for _ in 0..len { + vec.push(rng.gen()); + } + + len +} + +fn fill_rand_len<R: rand::Rng>(vec: &mut Vec<u8>, rng: &mut R, len: usize) { + for _ in 0..len { + vec.push(rng.gen()); + } +} + +fn prefixed_data<'i, 'd>( + input_with_prefix: &'i mut String, + prefix_len: usize, + data: &'d str, +) -> &'i str { + input_with_prefix.truncate(prefix_len); + input_with_prefix.push_str(data); + input_with_prefix.as_str() +} + +/// A wrapper to make using engines in rstest fixtures easier. +/// The functions don't need to be instance methods, but rstest does seem +/// to want an instance, so instances are passed to test functions and then ignored. +trait EngineWrapper { + type Engine: Engine; + + /// Return an engine configured for RFC standard base64 + fn standard() -> Self::Engine; + + /// Return an engine configured for RFC standard base64, except with no padding appended on + /// encode, and required no padding on decode. + fn standard_unpadded() -> Self::Engine; + + /// Return an engine configured for RFC standard alphabet with the provided encode and decode + /// pad settings + fn standard_with_pad_mode(encode_pad: bool, decode_pad_mode: DecodePaddingMode) + -> Self::Engine; + + /// Return an engine configured for RFC standard base64 that allows invalid trailing bits + fn standard_allow_trailing_bits() -> Self::Engine; + + /// Return an engine configured with a randomized alphabet and config + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine; + + /// Return an engine configured with the specified alphabet and randomized config + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine; +} + +struct GeneralPurposeWrapper {} + +impl EngineWrapper for GeneralPurposeWrapper { + type Engine = general_purpose::GeneralPurpose; + + fn standard() -> Self::Engine { + general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::PAD) + } + + fn standard_unpadded() -> Self::Engine { + general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::NO_PAD) + } + + fn standard_with_pad_mode( + encode_pad: bool, + decode_pad_mode: DecodePaddingMode, + ) -> Self::Engine { + general_purpose::GeneralPurpose::new( + &STANDARD, + general_purpose::GeneralPurposeConfig::new() + .with_encode_padding(encode_pad) + .with_decode_padding_mode(decode_pad_mode), + ) + } + + fn standard_allow_trailing_bits() -> Self::Engine { + general_purpose::GeneralPurpose::new( + &STANDARD, + general_purpose::GeneralPurposeConfig::new().with_decode_allow_trailing_bits(true), + ) + } + + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine { + let alphabet = random_alphabet(rng); + + Self::random_alphabet(rng, alphabet) + } + + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine { + general_purpose::GeneralPurpose::new(alphabet, random_config(rng)) + } +} + +struct NaiveWrapper {} + +impl EngineWrapper for NaiveWrapper { + type Engine = naive::Naive; + + fn standard() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: true, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + }, + ) + } + + fn standard_unpadded() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: false, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireNone, + }, + ) + } + + fn standard_with_pad_mode( + encode_pad: bool, + decode_pad_mode: DecodePaddingMode, + ) -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: false, + decode_allow_trailing_bits: false, + decode_padding_mode: decode_pad_mode, + }, + ) + } + + fn standard_allow_trailing_bits() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: true, + decode_allow_trailing_bits: true, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + }, + ) + } + + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine { + let alphabet = random_alphabet(rng); + + Self::random_alphabet(rng, alphabet) + } + + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine { + let mode = rng.gen(); + + let config = naive::NaiveConfig { + encode_padding: match mode { + DecodePaddingMode::Indifferent => rng.gen(), + DecodePaddingMode::RequireCanonical => true, + DecodePaddingMode::RequireNone => false, + }, + decode_allow_trailing_bits: rng.gen(), + decode_padding_mode: mode, + }; + + naive::Naive::new(alphabet, config) + } +} + +fn seeded_rng() -> impl rand::Rng { + rngs::SmallRng::from_entropy() +} + +fn all_pad_modes() -> Vec<DecodePaddingMode> { + vec![ + DecodePaddingMode::Indifferent, + DecodePaddingMode::RequireCanonical, + DecodePaddingMode::RequireNone, + ] +} + +fn assert_all_suffixes_ok<E: Engine>(engine: E, suffixes: Vec<&str>) { + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = &engine.decode(&encoded); + assert!(res.is_ok()); + } + } +} |