diff options
Diffstat (limited to 'third_party/rust/base64/src')
21 files changed, 6426 insertions, 0 deletions
diff --git a/third_party/rust/base64/src/alphabet.rs b/third_party/rust/base64/src/alphabet.rs new file mode 100644 index 0000000000..7cd1b57073 --- /dev/null +++ b/third_party/rust/base64/src/alphabet.rs @@ -0,0 +1,241 @@ +//! Provides [Alphabet] and constants for alphabets commonly used in the wild. + +use crate::PAD_BYTE; +use core::fmt; +#[cfg(any(feature = "std", test))] +use std::error; + +const ALPHABET_SIZE: usize = 64; + +/// An alphabet defines the 64 ASCII characters (symbols) used for base64. +/// +/// Common alphabets are provided as constants, and custom alphabets +/// can be made via `from_str` or the `TryFrom<str>` implementation. +/// +/// ``` +/// let custom = base64::alphabet::Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/").unwrap(); +/// +/// let engine = base64::engine::GeneralPurpose::new( +/// &custom, +/// base64::engine::general_purpose::PAD); +/// ``` +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Alphabet { + pub(crate) symbols: [u8; ALPHABET_SIZE], +} + +impl Alphabet { + /// Performs no checks so that it can be const. + /// Used only for known-valid strings. + const fn from_str_unchecked(alphabet: &str) -> Self { + let mut symbols = [0_u8; ALPHABET_SIZE]; + let source_bytes = alphabet.as_bytes(); + + // a way to copy that's allowed in const fn + let mut index = 0; + while index < ALPHABET_SIZE { + symbols[index] = source_bytes[index]; + index += 1; + } + + Self { symbols } + } + + /// Create an `Alphabet` from a string of 64 unique printable ASCII bytes. + /// + /// The `=` byte is not allowed as it is used for padding. + pub const fn new(alphabet: &str) -> Result<Self, ParseAlphabetError> { + let bytes = alphabet.as_bytes(); + if bytes.len() != ALPHABET_SIZE { + return Err(ParseAlphabetError::InvalidLength); + } + + { + let mut index = 0; + while index < ALPHABET_SIZE { + let byte = bytes[index]; + + // must be ascii printable. 127 (DEL) is commonly considered printable + // for some reason but clearly unsuitable for base64. + if !(byte >= 32_u8 && byte <= 126_u8) { + return Err(ParseAlphabetError::UnprintableByte(byte)); + } + // = is assumed to be padding, so cannot be used as a symbol + if byte == PAD_BYTE { + return Err(ParseAlphabetError::ReservedByte(byte)); + } + + // Check for duplicates while staying within what const allows. + // It's n^2, but only over 64 hot bytes, and only once, so it's likely in the single digit + // microsecond range. + + let mut probe_index = 0; + while probe_index < ALPHABET_SIZE { + if probe_index == index { + probe_index += 1; + continue; + } + + let probe_byte = bytes[probe_index]; + + if byte == probe_byte { + return Err(ParseAlphabetError::DuplicatedByte(byte)); + } + + probe_index += 1; + } + + index += 1; + } + } + + Ok(Self::from_str_unchecked(alphabet)) + } +} + +impl TryFrom<&str> for Alphabet { + type Error = ParseAlphabetError; + + fn try_from(value: &str) -> Result<Self, Self::Error> { + Self::new(value) + } +} + +/// Possible errors when constructing an [Alphabet] from a `str`. +#[derive(Debug, Eq, PartialEq)] +pub enum ParseAlphabetError { + /// Alphabets must be 64 ASCII bytes + InvalidLength, + /// All bytes must be unique + DuplicatedByte(u8), + /// All bytes must be printable (in the range `[32, 126]`). + UnprintableByte(u8), + /// `=` cannot be used + ReservedByte(u8), +} + +impl fmt::Display for ParseAlphabetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidLength => write!(f, "Invalid length - must be 64 bytes"), + Self::DuplicatedByte(b) => write!(f, "Duplicated byte: {:#04x}", b), + Self::UnprintableByte(b) => write!(f, "Unprintable byte: {:#04x}", b), + Self::ReservedByte(b) => write!(f, "Reserved byte: {:#04x}", b), + } + } +} + +#[cfg(any(feature = "std", test))] +impl error::Error for ParseAlphabetError {} + +/// The standard alphabet (uses `+` and `/`). +/// +/// See [RFC 3548](https://tools.ietf.org/html/rfc3548#section-3). +pub const STANDARD: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", +); + +/// The URL safe alphabet (uses `-` and `_`). +/// +/// See [RFC 3548](https://tools.ietf.org/html/rfc3548#section-4). +pub const URL_SAFE: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", +); + +/// The `crypt(3)` alphabet (uses `.` and `/` as the first two values). +/// +/// Not standardized, but folk wisdom on the net asserts that this alphabet is what crypt uses. +pub const CRYPT: Alphabet = Alphabet::from_str_unchecked( + "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", +); + +/// The bcrypt alphabet. +pub const BCRYPT: Alphabet = Alphabet::from_str_unchecked( + "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", +); + +/// The alphabet used in IMAP-modified UTF-7 (uses `+` and `,`). +/// +/// See [RFC 3501](https://tools.ietf.org/html/rfc3501#section-5.1.3) +pub const IMAP_MUTF7: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+,", +); + +/// The alphabet used in BinHex 4.0 files. +/// +/// See [BinHex 4.0 Definition](http://files.stairways.com/other/binhex-40-specs-info.txt) +pub const BIN_HEX: Alphabet = Alphabet::from_str_unchecked( + "!\"#$%&'()*+,-0123456789@ABCDEFGHIJKLMNPQRSTUVXYZ[`abcdehijklmpqr", +); + +#[cfg(test)] +mod tests { + use crate::alphabet::*; + use std::convert::TryFrom as _; + + #[test] + fn detects_duplicate_start() { + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'A'), + Alphabet::new("AACDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap_err() + ); + } + + #[test] + fn detects_duplicate_end() { + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'/'), + Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789//") + .unwrap_err() + ); + } + + #[test] + fn detects_duplicate_middle() { + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'Z'), + Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap_err() + ); + } + + #[test] + fn detects_length() { + assert_eq!( + ParseAlphabetError::InvalidLength, + Alphabet::new( + "xxxxxxxxxABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/", + ) + .unwrap_err() + ); + } + + #[test] + fn detects_padding() { + assert_eq!( + ParseAlphabetError::ReservedByte(b'='), + Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+=") + .unwrap_err() + ); + } + + #[test] + fn detects_unprintable() { + // form feed + assert_eq!( + ParseAlphabetError::UnprintableByte(0xc), + Alphabet::new("\x0cBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap_err() + ); + } + + #[test] + fn same_as_unchecked() { + assert_eq!( + STANDARD, + Alphabet::try_from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap() + ); + } +} diff --git a/third_party/rust/base64/src/chunked_encoder.rs b/third_party/rust/base64/src/chunked_encoder.rs new file mode 100644 index 0000000000..0457259744 --- /dev/null +++ b/third_party/rust/base64/src/chunked_encoder.rs @@ -0,0 +1,231 @@ +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::string::String; +use core::cmp; +#[cfg(any(feature = "alloc", feature = "std", test))] +use core::str; + +use crate::encode::add_padding; +use crate::engine::{Config, Engine}; + +/// The output mechanism for ChunkedEncoder's encoded bytes. +pub trait Sink { + type Error; + + /// Handle a chunk of encoded base64 data (as UTF-8 bytes) + fn write_encoded_bytes(&mut self, encoded: &[u8]) -> Result<(), Self::Error>; +} + +const BUF_SIZE: usize = 1024; + +/// A base64 encoder that emits encoded bytes in chunks without heap allocation. +pub struct ChunkedEncoder<'e, E: Engine + ?Sized> { + engine: &'e E, + max_input_chunk_len: usize, +} + +impl<'e, E: Engine + ?Sized> ChunkedEncoder<'e, E> { + pub fn new(engine: &'e E) -> ChunkedEncoder<'e, E> { + ChunkedEncoder { + engine, + max_input_chunk_len: max_input_length(BUF_SIZE, engine.config().encode_padding()), + } + } + + pub fn encode<S: Sink>(&self, bytes: &[u8], sink: &mut S) -> Result<(), S::Error> { + let mut encode_buf: [u8; BUF_SIZE] = [0; BUF_SIZE]; + let mut input_index = 0; + + while input_index < bytes.len() { + // either the full input chunk size, or it's the last iteration + let input_chunk_len = cmp::min(self.max_input_chunk_len, bytes.len() - input_index); + + let chunk = &bytes[input_index..(input_index + input_chunk_len)]; + + let mut b64_bytes_written = self.engine.internal_encode(chunk, &mut encode_buf); + + input_index += input_chunk_len; + let more_input_left = input_index < bytes.len(); + + if self.engine.config().encode_padding() && !more_input_left { + // no more input, add padding if needed. Buffer will have room because + // max_input_length leaves room for it. + b64_bytes_written += add_padding(bytes.len(), &mut encode_buf[b64_bytes_written..]); + } + + sink.write_encoded_bytes(&encode_buf[0..b64_bytes_written])?; + } + + Ok(()) + } +} + +/// Calculate the longest input that can be encoded for the given output buffer size. +/// +/// If the config requires padding, two bytes of buffer space will be set aside so that the last +/// chunk of input can be encoded safely. +/// +/// The input length will always be a multiple of 3 so that no encoding state has to be carried over +/// between chunks. +fn max_input_length(encoded_buf_len: usize, padded: bool) -> usize { + let effective_buf_len = if padded { + // make room for padding + encoded_buf_len + .checked_sub(2) + .expect("Don't use a tiny buffer") + } else { + encoded_buf_len + }; + + // No padding, so just normal base64 expansion. + (effective_buf_len / 4) * 3 +} + +// A really simple sink that just appends to a string +#[cfg(any(feature = "alloc", feature = "std", test))] +pub(crate) struct StringSink<'a> { + string: &'a mut String, +} + +#[cfg(any(feature = "alloc", feature = "std", test))] +impl<'a> StringSink<'a> { + pub(crate) fn new(s: &mut String) -> StringSink { + StringSink { string: s } + } +} + +#[cfg(any(feature = "alloc", feature = "std", test))] +impl<'a> Sink for StringSink<'a> { + type Error = (); + + fn write_encoded_bytes(&mut self, s: &[u8]) -> Result<(), Self::Error> { + self.string.push_str(str::from_utf8(s).unwrap()); + + Ok(()) + } +} + +#[cfg(test)] +pub mod tests { + use rand::{ + distributions::{Distribution, Uniform}, + Rng, SeedableRng, + }; + + use crate::{ + alphabet::STANDARD, + engine::general_purpose::{GeneralPurpose, GeneralPurposeConfig, PAD}, + tests::random_engine, + }; + + use super::*; + + #[test] + fn chunked_encode_empty() { + assert_eq!("", chunked_encode_str(&[], PAD)); + } + + #[test] + fn chunked_encode_intermediate_fast_loop() { + // > 8 bytes input, will enter the pretty fast loop + assert_eq!("Zm9vYmFyYmF6cXV4", chunked_encode_str(b"foobarbazqux", PAD)); + } + + #[test] + fn chunked_encode_fast_loop() { + // > 32 bytes input, will enter the uber fast loop + assert_eq!( + "Zm9vYmFyYmF6cXV4cXV1eGNvcmdlZ3JhdWx0Z2FycGx5eg==", + chunked_encode_str(b"foobarbazquxquuxcorgegraultgarplyz", PAD) + ); + } + + #[test] + fn chunked_encode_slow_loop_only() { + // < 8 bytes input, slow loop only + assert_eq!("Zm9vYmFy", chunked_encode_str(b"foobar", PAD)); + } + + #[test] + fn chunked_encode_matches_normal_encode_random_string_sink() { + let helper = StringSinkTestHelper; + chunked_encode_matches_normal_encode_random(&helper); + } + + #[test] + fn max_input_length_no_pad() { + assert_eq!(768, max_input_length(1024, false)); + } + + #[test] + fn max_input_length_with_pad_decrements_one_triple() { + assert_eq!(765, max_input_length(1024, true)); + } + + #[test] + fn max_input_length_with_pad_one_byte_short() { + assert_eq!(765, max_input_length(1025, true)); + } + + #[test] + fn max_input_length_with_pad_fits_exactly() { + assert_eq!(768, max_input_length(1026, true)); + } + + #[test] + fn max_input_length_cant_use_extra_single_encoded_byte() { + assert_eq!(300, max_input_length(401, false)); + } + + pub fn chunked_encode_matches_normal_encode_random<S: SinkTestHelper>(sink_test_helper: &S) { + let mut input_buf: Vec<u8> = Vec::new(); + let mut output_buf = String::new(); + let mut rng = rand::rngs::SmallRng::from_entropy(); + let input_len_range = Uniform::new(1, 10_000); + + for _ in 0..5_000 { + input_buf.clear(); + output_buf.clear(); + + let buf_len = input_len_range.sample(&mut rng); + for _ in 0..buf_len { + input_buf.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + + let chunk_encoded_string = sink_test_helper.encode_to_string(&engine, &input_buf); + engine.encode_string(&input_buf, &mut output_buf); + + assert_eq!(output_buf, chunk_encoded_string, "input len={}", buf_len); + } + } + + fn chunked_encode_str(bytes: &[u8], config: GeneralPurposeConfig) -> String { + let mut s = String::new(); + + let mut sink = StringSink::new(&mut s); + let engine = GeneralPurpose::new(&STANDARD, config); + let encoder = ChunkedEncoder::new(&engine); + encoder.encode(bytes, &mut sink).unwrap(); + + s + } + + // An abstraction around sinks so that we can have tests that easily to any sink implementation + pub trait SinkTestHelper { + fn encode_to_string<E: Engine>(&self, engine: &E, bytes: &[u8]) -> String; + } + + struct StringSinkTestHelper; + + impl SinkTestHelper for StringSinkTestHelper { + fn encode_to_string<E: Engine>(&self, engine: &E, bytes: &[u8]) -> String { + let encoder = ChunkedEncoder::new(engine); + let mut s = String::new(); + let mut sink = StringSink::new(&mut s); + encoder.encode(bytes, &mut sink).unwrap(); + + s + } + } +} diff --git a/third_party/rust/base64/src/decode.rs b/third_party/rust/base64/src/decode.rs new file mode 100644 index 0000000000..047151840c --- /dev/null +++ b/third_party/rust/base64/src/decode.rs @@ -0,0 +1,349 @@ +use crate::engine::{general_purpose::STANDARD, DecodeEstimate, Engine}; +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::vec::Vec; +use core::fmt; +#[cfg(any(feature = "std", test))] +use std::error; + +/// Errors that can occur while decoding. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum DecodeError { + /// An invalid byte was found in the input. The offset and offending byte are provided. + /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes. + InvalidByte(usize, u8), + /// The length of the input is invalid. + /// A typical cause of this is stray trailing whitespace or other separator bytes. + /// In the case where excess trailing bytes have produced an invalid length *and* the last byte + /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` + /// will be emitted instead of `InvalidLength` to make the issue easier to debug. + InvalidLength, + /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. + /// This is indicative of corrupted or truncated Base64. + /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for + /// symbols that are in the alphabet but represent nonsensical encodings. + InvalidLastSymbol(usize, u8), + /// The nature of the padding was not as configured: absent or incorrect when it must be + /// canonical, or present when it must be absent, etc. + InvalidPadding, +} + +impl fmt::Display for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index), + Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + Self::InvalidLastSymbol(index, byte) => { + write!(f, "Invalid last symbol {}, offset {}.", byte, index) + } + Self::InvalidPadding => write!(f, "Invalid padding"), + } + } +} + +#[cfg(any(feature = "std", test))] +impl error::Error for DecodeError { + fn cause(&self) -> Option<&dyn error::Error> { + None + } +} + +/// Errors that can occur while decoding into a slice. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum DecodeSliceError { + /// A [DecodeError] occurred + DecodeError(DecodeError), + /// The provided slice _may_ be too small. + /// + /// The check is conservative (assumes the last triplet of output bytes will all be needed). + OutputSliceTooSmall, +} + +impl fmt::Display for DecodeSliceError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::DecodeError(e) => write!(f, "DecodeError: {}", e), + Self::OutputSliceTooSmall => write!(f, "Output slice too small"), + } + } +} + +#[cfg(any(feature = "std", test))] +impl error::Error for DecodeSliceError { + fn cause(&self) -> Option<&dyn error::Error> { + match self { + DecodeSliceError::DecodeError(e) => Some(e), + DecodeSliceError::OutputSliceTooSmall => None, + } + } +} + +impl From<DecodeError> for DecodeSliceError { + fn from(e: DecodeError) -> Self { + DecodeSliceError::DecodeError(e) + } +} + +/// Decode base64 using the [`STANDARD` engine](STANDARD). +/// +/// See [Engine::decode]. +#[deprecated(since = "0.21.0", note = "Use Engine::decode")] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> { + STANDARD.decode(input) +} + +/// Decode from string reference as octets using the specified [Engine]. +/// +/// See [Engine::decode]. +///Returns a `Result` containing a `Vec<u8>`. +#[deprecated(since = "0.21.0", note = "Use Engine::decode")] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub fn decode_engine<E: Engine, T: AsRef<[u8]>>( + input: T, + engine: &E, +) -> Result<Vec<u8>, DecodeError> { + engine.decode(input) +} + +/// Decode from string reference as octets. +/// +/// See [Engine::decode_vec]. +#[cfg(any(feature = "alloc", feature = "std", test))] +#[deprecated(since = "0.21.0", note = "Use Engine::decode_vec")] +pub fn decode_engine_vec<E: Engine, T: AsRef<[u8]>>( + input: T, + buffer: &mut Vec<u8>, + engine: &E, +) -> Result<(), DecodeError> { + engine.decode_vec(input, buffer) +} + +/// Decode the input into the provided output slice. +/// +/// See [Engine::decode_slice]. +#[deprecated(since = "0.21.0", note = "Use Engine::decode_slice")] +pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>( + input: T, + output: &mut [u8], + engine: &E, +) -> Result<usize, DecodeSliceError> { + engine.decode_slice(input, output) +} + +/// Returns a conservative estimate of the decoded size of `encoded_len` base64 symbols (rounded up +/// to the next group of 3 decoded bytes). +/// +/// The resulting length will be a safe choice for the size of a decode buffer, but may have up to +/// 2 trailing bytes that won't end up being needed. +/// +/// # Examples +/// +/// ``` +/// use base64::decoded_len_estimate; +/// +/// assert_eq!(3, decoded_len_estimate(1)); +/// assert_eq!(3, decoded_len_estimate(2)); +/// assert_eq!(3, decoded_len_estimate(3)); +/// assert_eq!(3, decoded_len_estimate(4)); +/// // start of the next quad of encoded symbols +/// assert_eq!(6, decoded_len_estimate(5)); +/// ``` +/// +/// # Panics +/// +/// Panics if decoded length estimation overflows. +/// This would happen for sizes within a few bytes of the maximum value of `usize`. +pub fn decoded_len_estimate(encoded_len: usize) -> usize { + STANDARD + .internal_decoded_len_estimate(encoded_len) + .decoded_len_estimate() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + alphabet, + engine::{general_purpose, Config, GeneralPurpose}, + tests::{assert_encode_sanity, random_engine}, + }; + use rand::{ + distributions::{Distribution, Uniform}, + Rng, SeedableRng, + }; + + #[test] + fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() { + let mut orig_data = Vec::new(); + let mut encoded_data = String::new(); + let mut decoded_with_prefix = Vec::new(); + let mut decoded_without_prefix = Vec::new(); + let mut prefix = Vec::new(); + + let prefix_len_range = Uniform::new(0, 1000); + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decoded_with_prefix.clear(); + decoded_without_prefix.clear(); + prefix.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut encoded_data); + assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); + + let prefix_len = prefix_len_range.sample(&mut rng); + + // fill the buf with a prefix + for _ in 0..prefix_len { + prefix.push(rng.gen()); + } + + decoded_with_prefix.resize(prefix_len, 0); + decoded_with_prefix.copy_from_slice(&prefix); + + // decode into the non-empty buf + engine + .decode_vec(&encoded_data, &mut decoded_with_prefix) + .unwrap(); + // also decode into the empty buf + engine + .decode_vec(&encoded_data, &mut decoded_without_prefix) + .unwrap(); + + assert_eq!( + prefix_len + decoded_without_prefix.len(), + decoded_with_prefix.len() + ); + assert_eq!(orig_data, decoded_without_prefix); + + // append plain decode onto prefix + prefix.append(&mut decoded_without_prefix); + + assert_eq!(prefix, decoded_with_prefix); + } + } + + #[test] + fn decode_slice_doesnt_clobber_existing_prefix_or_suffix() { + do_decode_slice_doesnt_clobber_existing_prefix_or_suffix(|e, input, output| { + e.decode_slice(input, output).unwrap() + }) + } + + #[test] + fn decode_slice_unchecked_doesnt_clobber_existing_prefix_or_suffix() { + do_decode_slice_doesnt_clobber_existing_prefix_or_suffix(|e, input, output| { + e.decode_slice_unchecked(input, output).unwrap() + }) + } + + #[test] + fn decode_engine_estimation_works_for_various_lengths() { + let engine = GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD); + for num_prefix_quads in 0..100 { + for suffix in &["AA", "AAA", "AAAA"] { + let mut prefix = "AAAA".repeat(num_prefix_quads); + prefix.push_str(suffix); + // make sure no overflow (and thus a panic) occurs + let res = engine.decode(prefix); + assert!(res.is_ok()); + } + } + } + + #[test] + fn decode_slice_output_length_errors() { + for num_quads in 1..100 { + let input = "AAAA".repeat(num_quads); + let mut vec = vec![0; (num_quads - 1) * 3]; + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + STANDARD.decode_slice(&input, &mut vec).unwrap_err() + ); + vec.push(0); + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + STANDARD.decode_slice(&input, &mut vec).unwrap_err() + ); + vec.push(0); + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + STANDARD.decode_slice(&input, &mut vec).unwrap_err() + ); + vec.push(0); + // now it works + assert_eq!( + num_quads * 3, + STANDARD.decode_slice(&input, &mut vec).unwrap() + ); + } + } + + fn do_decode_slice_doesnt_clobber_existing_prefix_or_suffix< + F: Fn(&GeneralPurpose, &[u8], &mut [u8]) -> usize, + >( + call_decode: F, + ) { + let mut orig_data = Vec::new(); + let mut encoded_data = String::new(); + let mut decode_buf = Vec::new(); + let mut decode_buf_copy: Vec<u8> = Vec::new(); + + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decode_buf.clear(); + decode_buf_copy.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut encoded_data); + assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); + + // fill the buffer with random garbage, long enough to have some room before and after + for _ in 0..5000 { + decode_buf.push(rng.gen()); + } + + // keep a copy for later comparison + decode_buf_copy.extend(decode_buf.iter()); + + let offset = 1000; + + // decode into the non-empty buf + let decode_bytes_written = + call_decode(&engine, encoded_data.as_bytes(), &mut decode_buf[offset..]); + + assert_eq!(orig_data.len(), decode_bytes_written); + assert_eq!( + orig_data, + &decode_buf[offset..(offset + decode_bytes_written)] + ); + assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]); + assert_eq!( + &decode_buf_copy[offset + decode_bytes_written..], + &decode_buf[offset + decode_bytes_written..] + ); + } + } +} diff --git a/third_party/rust/base64/src/display.rs b/third_party/rust/base64/src/display.rs new file mode 100644 index 0000000000..fc292f1b00 --- /dev/null +++ b/third_party/rust/base64/src/display.rs @@ -0,0 +1,88 @@ +//! Enables base64'd output anywhere you might use a `Display` implementation, like a format string. +//! +//! ``` +//! use base64::{display::Base64Display, engine::general_purpose::STANDARD}; +//! +//! let data = vec![0x0, 0x1, 0x2, 0x3]; +//! let wrapper = Base64Display::new(&data, &STANDARD); +//! +//! assert_eq!("base64: AAECAw==", format!("base64: {}", wrapper)); +//! ``` + +use super::chunked_encoder::ChunkedEncoder; +use crate::engine::Engine; +use core::fmt::{Display, Formatter}; +use core::{fmt, str}; + +/// A convenience wrapper for base64'ing bytes into a format string without heap allocation. +pub struct Base64Display<'a, 'e, E: Engine> { + bytes: &'a [u8], + chunked_encoder: ChunkedEncoder<'e, E>, +} + +impl<'a, 'e, E: Engine> Base64Display<'a, 'e, E> { + /// Create a `Base64Display` with the provided engine. + pub fn new(bytes: &'a [u8], engine: &'e E) -> Base64Display<'a, 'e, E> { + Base64Display { + bytes, + chunked_encoder: ChunkedEncoder::new(engine), + } + } +} + +impl<'a, 'e, E: Engine> Display for Base64Display<'a, 'e, E> { + fn fmt(&self, formatter: &mut Formatter) -> Result<(), fmt::Error> { + let mut sink = FormatterSink { f: formatter }; + self.chunked_encoder.encode(self.bytes, &mut sink) + } +} + +struct FormatterSink<'a, 'b: 'a> { + f: &'a mut Formatter<'b>, +} + +impl<'a, 'b: 'a> super::chunked_encoder::Sink for FormatterSink<'a, 'b> { + type Error = fmt::Error; + + fn write_encoded_bytes(&mut self, encoded: &[u8]) -> Result<(), Self::Error> { + // Avoid unsafe. If max performance is needed, write your own display wrapper that uses + // unsafe here to gain about 10-15%. + self.f + .write_str(str::from_utf8(encoded).expect("base64 data was not utf8")) + } +} + +#[cfg(test)] +mod tests { + use super::super::chunked_encoder::tests::{ + chunked_encode_matches_normal_encode_random, SinkTestHelper, + }; + use super::*; + use crate::engine::general_purpose::STANDARD; + + #[test] + fn basic_display() { + assert_eq!( + "~$Zm9vYmFy#*", + format!("~${}#*", Base64Display::new(b"foobar", &STANDARD)) + ); + assert_eq!( + "~$Zm9vYmFyZg==#*", + format!("~${}#*", Base64Display::new(b"foobarf", &STANDARD)) + ); + } + + #[test] + fn display_encode_matches_normal_encode() { + let helper = DisplaySinkTestHelper; + chunked_encode_matches_normal_encode_random(&helper); + } + + struct DisplaySinkTestHelper; + + impl SinkTestHelper for DisplaySinkTestHelper { + fn encode_to_string<E: Engine>(&self, engine: &E, bytes: &[u8]) -> String { + format!("{}", Base64Display::new(bytes, engine)) + } + } +} diff --git a/third_party/rust/base64/src/encode.rs b/third_party/rust/base64/src/encode.rs new file mode 100644 index 0000000000..cb176504a1 --- /dev/null +++ b/third_party/rust/base64/src/encode.rs @@ -0,0 +1,488 @@ +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::string::String; +use core::fmt; +#[cfg(any(feature = "std", test))] +use std::error; + +#[cfg(any(feature = "alloc", feature = "std", test))] +use crate::engine::general_purpose::STANDARD; +use crate::engine::{Config, Engine}; +use crate::PAD_BYTE; + +/// Encode arbitrary octets as base64 using the [`STANDARD` engine](STANDARD). +/// +/// See [Engine::encode]. +#[allow(unused)] +#[deprecated(since = "0.21.0", note = "Use Engine::encode")] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub fn encode<T: AsRef<[u8]>>(input: T) -> String { + STANDARD.encode(input) +} + +///Encode arbitrary octets as base64 using the provided `Engine` into a new `String`. +/// +/// See [Engine::encode]. +#[allow(unused)] +#[deprecated(since = "0.21.0", note = "Use Engine::encode")] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub fn encode_engine<E: Engine, T: AsRef<[u8]>>(input: T, engine: &E) -> String { + engine.encode(input) +} + +///Encode arbitrary octets as base64 into a supplied `String`. +/// +/// See [Engine::encode_string]. +#[allow(unused)] +#[deprecated(since = "0.21.0", note = "Use Engine::encode_string")] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub fn encode_engine_string<E: Engine, T: AsRef<[u8]>>( + input: T, + output_buf: &mut String, + engine: &E, +) { + engine.encode_string(input, output_buf) +} + +/// Encode arbitrary octets as base64 into a supplied slice. +/// +/// See [Engine::encode_slice]. +#[allow(unused)] +#[deprecated(since = "0.21.0", note = "Use Engine::encode_slice")] +pub fn encode_engine_slice<E: Engine, T: AsRef<[u8]>>( + input: T, + output_buf: &mut [u8], + engine: &E, +) -> Result<usize, EncodeSliceError> { + engine.encode_slice(input, output_buf) +} + +/// B64-encode and pad (if configured). +/// +/// This helper exists to avoid recalculating encoded_size, which is relatively expensive on short +/// inputs. +/// +/// `encoded_size` is the encoded size calculated for `input`. +/// +/// `output` must be of size `encoded_size`. +/// +/// All bytes in `output` will be written to since it is exactly the size of the output. +pub(crate) fn encode_with_padding<E: Engine + ?Sized>( + input: &[u8], + output: &mut [u8], + engine: &E, + expected_encoded_size: usize, +) { + debug_assert_eq!(expected_encoded_size, output.len()); + + let b64_bytes_written = engine.internal_encode(input, output); + + let padding_bytes = if engine.config().encode_padding() { + add_padding(input.len(), &mut output[b64_bytes_written..]) + } else { + 0 + }; + + let encoded_bytes = b64_bytes_written + .checked_add(padding_bytes) + .expect("usize overflow when calculating b64 length"); + + debug_assert_eq!(expected_encoded_size, encoded_bytes); +} + +/// Calculate the base64 encoded length for a given input length, optionally including any +/// appropriate padding bytes. +/// +/// Returns `None` if the encoded length can't be represented in `usize`. This will happen for +/// input lengths in approximately the top quarter of the range of `usize`. +pub fn encoded_len(bytes_len: usize, padding: bool) -> Option<usize> { + let rem = bytes_len % 3; + + let complete_input_chunks = bytes_len / 3; + let complete_chunk_output = complete_input_chunks.checked_mul(4); + + if rem > 0 { + if padding { + complete_chunk_output.and_then(|c| c.checked_add(4)) + } else { + let encoded_rem = match rem { + 1 => 2, + 2 => 3, + _ => unreachable!("Impossible remainder"), + }; + complete_chunk_output.and_then(|c| c.checked_add(encoded_rem)) + } + } else { + complete_chunk_output + } +} + +/// Write padding characters. +/// `input_len` is the size of the original, not encoded, input. +/// `output` is the slice where padding should be written, of length at least 2. +/// +/// Returns the number of padding bytes written. +pub(crate) fn add_padding(input_len: usize, output: &mut [u8]) -> usize { + // TODO base on encoded len to use cheaper mod by 4 (aka & 7) + let rem = input_len % 3; + let mut bytes_written = 0; + for _ in 0..((3 - rem) % 3) { + output[bytes_written] = PAD_BYTE; + bytes_written += 1; + } + + bytes_written +} + +/// Errors that can occur while encoding into a slice. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum EncodeSliceError { + /// The provided slice is too small. + OutputSliceTooSmall, +} + +impl fmt::Display for EncodeSliceError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::OutputSliceTooSmall => write!(f, "Output slice too small"), + } + } +} + +#[cfg(any(feature = "std", test))] +impl error::Error for EncodeSliceError { + fn cause(&self) -> Option<&dyn error::Error> { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{ + alphabet, + engine::general_purpose::{GeneralPurpose, NO_PAD, STANDARD}, + tests::{assert_encode_sanity, random_config, random_engine}, + }; + use rand::{ + distributions::{Distribution, Uniform}, + Rng, SeedableRng, + }; + use std::str; + + const URL_SAFE_NO_PAD_ENGINE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, NO_PAD); + + #[test] + fn encoded_size_correct_standard() { + assert_encoded_length(0, 0, &STANDARD, true); + + assert_encoded_length(1, 4, &STANDARD, true); + assert_encoded_length(2, 4, &STANDARD, true); + assert_encoded_length(3, 4, &STANDARD, true); + + assert_encoded_length(4, 8, &STANDARD, true); + assert_encoded_length(5, 8, &STANDARD, true); + assert_encoded_length(6, 8, &STANDARD, true); + + assert_encoded_length(7, 12, &STANDARD, true); + assert_encoded_length(8, 12, &STANDARD, true); + assert_encoded_length(9, 12, &STANDARD, true); + + assert_encoded_length(54, 72, &STANDARD, true); + + assert_encoded_length(55, 76, &STANDARD, true); + assert_encoded_length(56, 76, &STANDARD, true); + assert_encoded_length(57, 76, &STANDARD, true); + + assert_encoded_length(58, 80, &STANDARD, true); + } + + #[test] + fn encoded_size_correct_no_pad() { + assert_encoded_length(0, 0, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(1, 2, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(2, 3, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(3, 4, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(4, 6, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(5, 7, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(6, 8, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(7, 10, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(8, 11, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(9, 12, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(54, 72, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(55, 74, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(56, 75, &URL_SAFE_NO_PAD_ENGINE, false); + assert_encoded_length(57, 76, &URL_SAFE_NO_PAD_ENGINE, false); + + assert_encoded_length(58, 78, &URL_SAFE_NO_PAD_ENGINE, false); + } + + #[test] + fn encoded_size_overflow() { + assert_eq!(None, encoded_len(usize::MAX, true)); + } + + #[test] + fn encode_engine_string_into_nonempty_buffer_doesnt_clobber_prefix() { + let mut orig_data = Vec::new(); + let mut prefix = String::new(); + let mut encoded_data_no_prefix = String::new(); + let mut encoded_data_with_prefix = String::new(); + let mut decoded = Vec::new(); + + let prefix_len_range = Uniform::new(0, 1000); + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + prefix.clear(); + encoded_data_no_prefix.clear(); + encoded_data_with_prefix.clear(); + decoded.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let prefix_len = prefix_len_range.sample(&mut rng); + for _ in 0..prefix_len { + // getting convenient random single-byte printable chars that aren't base64 is + // annoying + prefix.push('#'); + } + encoded_data_with_prefix.push_str(&prefix); + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut encoded_data_no_prefix); + engine.encode_string(&orig_data, &mut encoded_data_with_prefix); + + assert_eq!( + encoded_data_no_prefix.len() + prefix_len, + encoded_data_with_prefix.len() + ); + assert_encode_sanity( + &encoded_data_no_prefix, + engine.config().encode_padding(), + input_len, + ); + assert_encode_sanity( + &encoded_data_with_prefix[prefix_len..], + engine.config().encode_padding(), + input_len, + ); + + // append plain encode onto prefix + prefix.push_str(&encoded_data_no_prefix); + + assert_eq!(prefix, encoded_data_with_prefix); + + engine + .decode_vec(&encoded_data_no_prefix, &mut decoded) + .unwrap(); + assert_eq!(orig_data, decoded); + } + } + + #[test] + fn encode_engine_slice_into_nonempty_buffer_doesnt_clobber_suffix() { + let mut orig_data = Vec::new(); + let mut encoded_data = Vec::new(); + let mut encoded_data_original_state = Vec::new(); + let mut decoded = Vec::new(); + + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + encoded_data_original_state.clear(); + decoded.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + // plenty of existing garbage in the encoded buffer + for _ in 0..10 * input_len { + encoded_data.push(rng.gen()); + } + + encoded_data_original_state.extend_from_slice(&encoded_data); + + let engine = random_engine(&mut rng); + + let encoded_size = encoded_len(input_len, engine.config().encode_padding()).unwrap(); + + assert_eq!( + encoded_size, + engine.encode_slice(&orig_data, &mut encoded_data).unwrap() + ); + + assert_encode_sanity( + str::from_utf8(&encoded_data[0..encoded_size]).unwrap(), + engine.config().encode_padding(), + input_len, + ); + + assert_eq!( + &encoded_data[encoded_size..], + &encoded_data_original_state[encoded_size..] + ); + + engine + .decode_vec(&encoded_data[0..encoded_size], &mut decoded) + .unwrap(); + assert_eq!(orig_data, decoded); + } + } + + #[test] + fn encode_to_slice_random_valid_utf8() { + let mut input = Vec::new(); + let mut output = Vec::new(); + + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + input.clear(); + output.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + input.push(rng.gen()); + } + + let config = random_config(&mut rng); + let engine = random_engine(&mut rng); + + // fill up the output buffer with garbage + let encoded_size = encoded_len(input_len, config.encode_padding()).unwrap(); + for _ in 0..encoded_size { + output.push(rng.gen()); + } + + let orig_output_buf = output.clone(); + + let bytes_written = engine.internal_encode(&input, &mut output); + + // make sure the part beyond bytes_written is the same garbage it was before + assert_eq!(orig_output_buf[bytes_written..], output[bytes_written..]); + + // make sure the encoded bytes are UTF-8 + let _ = str::from_utf8(&output[0..bytes_written]).unwrap(); + } + } + + #[test] + fn encode_with_padding_random_valid_utf8() { + let mut input = Vec::new(); + let mut output = Vec::new(); + + let input_len_range = Uniform::new(0, 1000); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + input.clear(); + output.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + input.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + + // fill up the output buffer with garbage + let encoded_size = encoded_len(input_len, engine.config().encode_padding()).unwrap(); + for _ in 0..encoded_size + 1000 { + output.push(rng.gen()); + } + + let orig_output_buf = output.clone(); + + encode_with_padding(&input, &mut output[0..encoded_size], &engine, encoded_size); + + // make sure the part beyond b64 is the same garbage it was before + assert_eq!(orig_output_buf[encoded_size..], output[encoded_size..]); + + // make sure the encoded bytes are UTF-8 + let _ = str::from_utf8(&output[0..encoded_size]).unwrap(); + } + } + + #[test] + fn add_padding_random_valid_utf8() { + let mut output = Vec::new(); + + let mut rng = rand::rngs::SmallRng::from_entropy(); + + // cover our bases for length % 3 + for input_len in 0..10 { + output.clear(); + + // fill output with random + for _ in 0..10 { + output.push(rng.gen()); + } + + let orig_output_buf = output.clone(); + + let bytes_written = add_padding(input_len, &mut output); + + // make sure the part beyond bytes_written is the same garbage it was before + assert_eq!(orig_output_buf[bytes_written..], output[bytes_written..]); + + // make sure the encoded bytes are UTF-8 + let _ = str::from_utf8(&output[0..bytes_written]).unwrap(); + } + } + + fn assert_encoded_length<E: Engine>( + input_len: usize, + enc_len: usize, + engine: &E, + padded: bool, + ) { + assert_eq!(enc_len, encoded_len(input_len, padded).unwrap()); + + let mut bytes: Vec<u8> = Vec::new(); + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..input_len { + bytes.push(rng.gen()); + } + + let encoded = engine.encode(&bytes); + assert_encode_sanity(&encoded, padded, input_len); + + assert_eq!(enc_len, encoded.len()); + } + + #[test] + fn encode_imap() { + assert_eq!( + &GeneralPurpose::new(&alphabet::IMAP_MUTF7, NO_PAD).encode(b"\xFB\xFF"), + &GeneralPurpose::new(&alphabet::STANDARD, NO_PAD) + .encode(b"\xFB\xFF") + .replace('/', ",") + ); + } +} diff --git a/third_party/rust/base64/src/engine/general_purpose/decode.rs b/third_party/rust/base64/src/engine/general_purpose/decode.rs new file mode 100644 index 0000000000..e9fd78877b --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/decode.rs @@ -0,0 +1,348 @@ +use crate::{ + engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode}, + DecodeError, PAD_BYTE, +}; + +// decode logic operates on chunks of 8 input bytes without padding +const INPUT_CHUNK_LEN: usize = 8; +const DECODED_CHUNK_LEN: usize = 6; + +// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last +// 2 bytes of any output u64 should not be counted as written to (but must be available in a +// slice). +const DECODED_CHUNK_SUFFIX: usize = 2; + +// how many u64's of input to handle at a time +const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; + +const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; + +// includes the trailing 2 bytes for the final u64 write +const DECODED_BLOCK_LEN: usize = + CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; + +#[doc(hidden)] +pub struct GeneralPurposeEstimate { + /// Total number of decode chunks, including a possibly partial last chunk + num_chunks: usize, + decoded_len_estimate: usize, +} + +impl GeneralPurposeEstimate { + pub(crate) fn new(encoded_len: usize) -> Self { + Self { + num_chunks: encoded_len + .checked_add(INPUT_CHUNK_LEN - 1) + .expect("Overflow when calculating number of chunks in input") + / INPUT_CHUNK_LEN, + decoded_len_estimate: encoded_len + .checked_add(3) + .expect("Overflow when calculating decoded len estimate") + / 4 + * 3, + } + } +} + +impl DecodeEstimate for GeneralPurposeEstimate { + fn decoded_len_estimate(&self) -> usize { + self.decoded_len_estimate + } +} + +/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. +/// Returns the number of bytes written, or an error. +// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is +// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment, +// but this is fragile and the best setting changes with only minor code modifications. +#[inline] +pub(crate) fn decode_helper( + input: &[u8], + estimate: GeneralPurposeEstimate, + output: &mut [u8], + decode_table: &[u8; 256], + decode_allow_trailing_bits: bool, + padding_mode: DecodePaddingMode, +) -> Result<usize, DecodeError> { + let remainder_len = input.len() % INPUT_CHUNK_LEN; + + // Because the fast decode loop writes in groups of 8 bytes (unrolled to + // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of + // which only 6 are valid data), we need to be sure that we stop using the fast decode loop + // soon enough that there will always be 2 more bytes of valid data written after that loop. + let trailing_bytes_to_skip = match remainder_len { + // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, + // and the fast decode logic cannot handle padding + 0 => INPUT_CHUNK_LEN, + // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte + 1 | 5 => { + // trailing whitespace is so common that it's worth it to check the last byte to + // possibly return a better error message + if let Some(b) = input.last() { + if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, *b)); + } + } + + return Err(DecodeError::InvalidLength); + } + // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes + // written by the fast decode loop. So, we have to ignore both these 2 bytes and the + // previous chunk. + 2 => INPUT_CHUNK_LEN + 2, + // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this + // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail + // with an error, not panic from going past the bounds of the output slice, so we let it + // use stage 3 + 4. + 3 => INPUT_CHUNK_LEN + 3, + // This can also decode to one output byte because it may be 2 input chars + 2 padding + // chars, which would decode to 1 byte. + 4 => INPUT_CHUNK_LEN + 4, + // Everything else is a legal decode len (given that we don't require padding), and will + // decode to at least 2 bytes of output. + _ => remainder_len, + }; + + // rounded up to include partial chunks + let mut remaining_chunks = estimate.num_chunks; + + let mut input_index = 0; + let mut output_index = 0; + + { + let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); + + // Fast loop, stage 1 + // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks + if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { + while input_index <= max_start_index { + let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; + let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; + + decode_chunk( + &input_slice[0..], + input_index, + decode_table, + &mut output_slice[0..], + )?; + decode_chunk( + &input_slice[8..], + input_index + 8, + decode_table, + &mut output_slice[6..], + )?; + decode_chunk( + &input_slice[16..], + input_index + 16, + decode_table, + &mut output_slice[12..], + )?; + decode_chunk( + &input_slice[24..], + input_index + 24, + decode_table, + &mut output_slice[18..], + )?; + + input_index += INPUT_BLOCK_LEN; + output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; + remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; + } + } + + // Fast loop, stage 2 (aka still pretty fast loop) + // 8 bytes at a time for whatever we didn't do in stage 1. + if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { + while input_index < max_start_index { + decode_chunk( + &input[input_index..(input_index + INPUT_CHUNK_LEN)], + input_index, + decode_table, + &mut output + [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], + )?; + + output_index += DECODED_CHUNK_LEN; + input_index += INPUT_CHUNK_LEN; + remaining_chunks -= 1; + } + } + } + + // Stage 3 + // If input length was such that a chunk had to be deferred until after the fast loop + // because decoding it would have produced 2 trailing bytes that wouldn't then be + // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 + // trailing bytes. + // However, we still need to avoid the last chunk (partial or complete) because it could + // have padding, so we always do 1 fewer to avoid the last chunk. + for _ in 1..remaining_chunks { + decode_chunk_precise( + &input[input_index..], + input_index, + decode_table, + &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + )?; + + input_index += INPUT_CHUNK_LEN; + output_index += DECODED_CHUNK_LEN; + } + + // always have one more (possibly partial) block of 8 input + debug_assert!(input.len() - input_index > 1 || input.is_empty()); + debug_assert!(input.len() - input_index <= 8); + + super::decode_suffix::decode_suffix( + input, + input_index, + output, + output_index, + decode_table, + decode_allow_trailing_bits, + padding_mode, + ) +} + +/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the +/// first 6 of those contain meaningful data. +/// +/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors +/// accurately) +/// `decode_table` is the lookup table for the particular base64 alphabet. +/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded +/// data. +// yes, really inline (worth 30-50% speedup) +#[inline(always)] +fn decode_chunk( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let morsel = decode_table[input[0] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = (morsel as u64) << 58; + + let morsel = decode_table[input[1] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= (morsel as u64) << 52; + + let morsel = decode_table[input[2] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= (morsel as u64) << 46; + + let morsel = decode_table[input[3] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= (morsel as u64) << 40; + + let morsel = decode_table[input[4] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 4, + input[4], + )); + } + accum |= (morsel as u64) << 34; + + let morsel = decode_table[input[5] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 5, + input[5], + )); + } + accum |= (morsel as u64) << 28; + + let morsel = decode_table[input[6] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 6, + input[6], + )); + } + accum |= (morsel as u64) << 22; + + let morsel = decode_table[input[7] as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 7, + input[7], + )); + } + accum |= (morsel as u64) << 16; + + write_u64(output, accum); + + Ok(()) +} + +/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 +/// trailing garbage bytes. +#[inline] +fn decode_chunk_precise( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let mut tmp_buf = [0_u8; 8]; + + decode_chunk( + input, + index_at_start_of_input, + decode_table, + &mut tmp_buf[..], + )?; + + output[0..6].copy_from_slice(&tmp_buf[0..6]); + + Ok(()) +} + +#[inline] +fn write_u64(output: &mut [u8], value: u64) { + output[..8].copy_from_slice(&value.to_be_bytes()); +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::engine::general_purpose::STANDARD; + + #[test] + fn decode_chunk_precise_writes_only_6_bytes() { + let input = b"Zm9vYmFy"; // "foobar" + let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + + decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); + } + + #[test] + fn decode_chunk_writes_8_bytes() { + let input = b"Zm9vYmFy"; // "foobar" + let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + + decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + } +} diff --git a/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs b/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs new file mode 100644 index 0000000000..5652035d0e --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/decode_suffix.rs @@ -0,0 +1,161 @@ +use crate::{ + engine::{general_purpose::INVALID_VALUE, DecodePaddingMode}, + DecodeError, PAD_BYTE, +}; + +/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// parameters. +/// +/// Returns the total number of bytes decoded, including the ones indicated as already written by +/// `output_index`. +pub(crate) fn decode_suffix( + input: &[u8], + input_index: usize, + output: &mut [u8], + mut output_index: usize, + decode_table: &[u8; 256], + decode_allow_trailing_bits: bool, + padding_mode: DecodePaddingMode, +) -> Result<usize, DecodeError> { + // Decode any leftovers that aren't a complete input block of 8 bytes. + // Use a u64 as a stack-resident 8 byte buffer. + let mut leftover_bits: u64 = 0; + let mut morsels_in_leftover = 0; + let mut padding_bytes = 0; + let mut first_padding_index: usize = 0; + let mut last_symbol = 0_u8; + let start_of_leftovers = input_index; + + for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + // '=' padding + if b == PAD_BYTE { + // There can be bad padding bytes in a few ways: + // 1 - Padding with non-padding characters after it + // 2 - Padding after zero or one characters in the current quad (should only + // be after 2 or 3 chars) + // 3 - More than two characters of padding. If 3 or 4 padding chars + // are in the same quad, that implies it will be caught by #2. + // If it spreads from one quad to another, it will be an invalid byte + // in the first quad. + // 4 - Non-canonical padding -- 1 byte when it should be 2, etc. + // Per config, non-canonical but still functional non- or partially-padded base64 + // may be treated as an error condition. + + if i % 4 < 2 { + // Check for case #2. + let bad_padding_index = start_of_leftovers + + if padding_bytes > 0 { + // If we've already seen padding, report the first padding index. + // This is to be consistent with the normal decode logic: it will report an + // error on the first padding character (since it doesn't expect to see + // anything but actual encoded data). + // This could only happen if the padding started in the previous quad since + // otherwise this case would have been hit at i % 4 == 0 if it was the same + // quad. + first_padding_index + } else { + // haven't seen padding before, just use where we are now + i + }; + return Err(DecodeError::InvalidByte(bad_padding_index, b)); + } + + if padding_bytes == 0 { + first_padding_index = i; + } + + padding_bytes += 1; + continue; + } + + // Check for case #1. + // To make '=' handling consistent with the main loop, don't allow + // non-suffix '=' in trailing chunk either. Report error as first + // erroneous padding. + if padding_bytes > 0 { + return Err(DecodeError::InvalidByte( + start_of_leftovers + first_padding_index, + PAD_BYTE, + )); + } + + last_symbol = b; + + // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. + // Pack the leftovers from left to right. + let shift = 64 - (morsels_in_leftover + 1) * 6; + let morsel = decode_table[b as usize]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + } + + leftover_bits |= (morsel as u64) << shift; + morsels_in_leftover += 1; + } + + match padding_mode { + DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } + DecodePaddingMode::RequireCanonical => { + if (padding_bytes + morsels_in_leftover) % 4 != 0 { + return Err(DecodeError::InvalidPadding); + } + } + DecodePaddingMode::RequireNone => { + if padding_bytes > 0 { + // check at the end to make sure we let the cases of padding that should be InvalidByte + // get hit + return Err(DecodeError::InvalidPadding); + } + } + } + + // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed. + // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits + // of bottom 6 bits set). + // When decoding two symbols back to one trailing byte, any final symbol higher than + // w would still decode to the original byte because we only care about the top two + // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a + // mask based on how many bits are used for just the canonical encoding, and optionally + // error if any other bits are set. In the example of one encoded byte -> 2 symbols, + // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and + // useless since there are no more symbols to provide the necessary 4 additional bits + // to finish the second original byte. + + let leftover_bits_ready_to_append = match morsels_in_leftover { + 0 => 0, + 2 => 8, + 3 => 16, + 4 => 24, + 6 => 32, + 7 => 40, + 8 => 48, + // can also be detected as case #2 bad padding above + _ => unreachable!( + "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" + ), + }; + + // if there are bits set outside the bits we care about, last symbol encodes trailing bits that + // will not be included in the output + let mask = !0 >> leftover_bits_ready_to_append; + if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { + // last morsel is at `morsels_in_leftover` - 1 + return Err(DecodeError::InvalidLastSymbol( + start_of_leftovers + morsels_in_leftover - 1, + last_symbol, + )); + } + + // TODO benchmark simply converting to big endian bytes + let mut leftover_bits_appended_to_buf = 0; + while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { + // `as` simply truncates the higher bits, which is what we want here + let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; + output[output_index] = selected_bits; + output_index += 1; + + leftover_bits_appended_to_buf += 8; + } + + Ok(output_index) +} diff --git a/third_party/rust/base64/src/engine/general_purpose/mod.rs b/third_party/rust/base64/src/engine/general_purpose/mod.rs new file mode 100644 index 0000000000..af8897bc2b --- /dev/null +++ b/third_party/rust/base64/src/engine/general_purpose/mod.rs @@ -0,0 +1,349 @@ +//! Provides the [GeneralPurpose] engine and associated config types. +use crate::{ + alphabet, + alphabet::Alphabet, + engine::{Config, DecodePaddingMode}, + DecodeError, +}; +use core::convert::TryInto; + +mod decode; +pub(crate) mod decode_suffix; +pub use decode::GeneralPurposeEstimate; + +pub(crate) const INVALID_VALUE: u8 = 255; + +/// A general-purpose base64 engine. +/// +/// - It uses no vector CPU instructions, so it will work on any system. +/// - It is reasonably fast (~2-3GiB/s). +/// - It is not constant-time, though, so it is vulnerable to timing side-channel attacks. For loading cryptographic keys, etc, it is suggested to use the forthcoming constant-time implementation. +pub struct GeneralPurpose { + encode_table: [u8; 64], + decode_table: [u8; 256], + config: GeneralPurposeConfig, +} + +impl GeneralPurpose { + /// Create a `GeneralPurpose` engine from an [Alphabet]. + /// + /// While not very expensive to initialize, ideally these should be cached + /// if the engine will be used repeatedly. + pub const fn new(alphabet: &Alphabet, config: GeneralPurposeConfig) -> Self { + Self { + encode_table: encode_table(alphabet), + decode_table: decode_table(alphabet), + config, + } + } +} + +impl super::Engine for GeneralPurpose { + type Config = GeneralPurposeConfig; + type DecodeEstimate = GeneralPurposeEstimate; + + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { + let mut input_index: usize = 0; + + const BLOCKS_PER_FAST_LOOP: usize = 4; + const LOW_SIX_BITS: u64 = 0x3F; + + // we read 8 bytes at a time (u64) but only actually consume 6 of those bytes. Thus, we need + // 2 trailing bytes to be available to read.. + let last_fast_index = input.len().saturating_sub(BLOCKS_PER_FAST_LOOP * 6 + 2); + let mut output_index = 0; + + if last_fast_index > 0 { + while input_index <= last_fast_index { + // Major performance wins from letting the optimizer do the bounds check once, mostly + // on the output side + let input_chunk = + &input[input_index..(input_index + (BLOCKS_PER_FAST_LOOP * 6 + 2))]; + let output_chunk = + &mut output[output_index..(output_index + BLOCKS_PER_FAST_LOOP * 8)]; + + // Hand-unrolling for 32 vs 16 or 8 bytes produces yields performance about equivalent + // to unsafe pointer code on a Xeon E5-1650v3. 64 byte unrolling was slightly better for + // large inputs but significantly worse for 50-byte input, unsurprisingly. I suspect + // that it's a not uncommon use case to encode smallish chunks of data (e.g. a 64-byte + // SHA-512 digest), so it would be nice if that fit in the unrolled loop at least once. + // Plus, single-digit percentage performance differences might well be quite different + // on different hardware. + + let input_u64 = read_u64(&input_chunk[0..]); + + output_chunk[0] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[1] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[2] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[3] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[4] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[5] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[6] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[7] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[6..]); + + output_chunk[8] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[9] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[10] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[11] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[12] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[13] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[14] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[15] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[12..]); + + output_chunk[16] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[17] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[18] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[19] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[20] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[21] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[22] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[23] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + let input_u64 = read_u64(&input_chunk[18..]); + + output_chunk[24] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize]; + output_chunk[25] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize]; + output_chunk[26] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize]; + output_chunk[27] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize]; + output_chunk[28] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize]; + output_chunk[29] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize]; + output_chunk[30] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize]; + output_chunk[31] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize]; + + output_index += BLOCKS_PER_FAST_LOOP * 8; + input_index += BLOCKS_PER_FAST_LOOP * 6; + } + } + + // Encode what's left after the fast loop. + + const LOW_SIX_BITS_U8: u8 = 0x3F; + + let rem = input.len() % 3; + let start_of_rem = input.len() - rem; + + // start at the first index not handled by fast loop, which may be 0. + + while input_index < start_of_rem { + let input_chunk = &input[input_index..(input_index + 3)]; + let output_chunk = &mut output[output_index..(output_index + 4)]; + + output_chunk[0] = self.encode_table[(input_chunk[0] >> 2) as usize]; + output_chunk[1] = self.encode_table + [((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize]; + output_chunk[2] = self.encode_table + [((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize]; + output_chunk[3] = self.encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize]; + + input_index += 3; + output_index += 4; + } + + if rem == 2 { + output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize]; + output[output_index + 1] = + self.encode_table[((input[start_of_rem] << 4 | input[start_of_rem + 1] >> 4) + & LOW_SIX_BITS_U8) as usize]; + output[output_index + 2] = + self.encode_table[((input[start_of_rem + 1] << 2) & LOW_SIX_BITS_U8) as usize]; + output_index += 3; + } else if rem == 1 { + output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize]; + output[output_index + 1] = + self.encode_table[((input[start_of_rem] << 4) & LOW_SIX_BITS_U8) as usize]; + output_index += 2; + } + + output_index + } + + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { + GeneralPurposeEstimate::new(input_len) + } + + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError> { + decode::decode_helper( + input, + estimate, + output, + &self.decode_table, + self.config.decode_allow_trailing_bits, + self.config.decode_padding_mode, + ) + } + + fn config(&self) -> &Self::Config { + &self.config + } +} + +/// Returns a table mapping a 6-bit index to the ASCII byte encoding of the index +pub(crate) const fn encode_table(alphabet: &Alphabet) -> [u8; 64] { + // the encode table is just the alphabet: + // 6-bit index lookup -> printable byte + let mut encode_table = [0_u8; 64]; + { + let mut index = 0; + while index < 64 { + encode_table[index] = alphabet.symbols[index]; + index += 1; + } + } + + encode_table +} + +/// Returns a table mapping base64 bytes as the lookup index to either: +/// - [INVALID_VALUE] for bytes that aren't members of the alphabet +/// - a byte whose lower 6 bits are the value that was encoded into the index byte +pub(crate) const fn decode_table(alphabet: &Alphabet) -> [u8; 256] { + let mut decode_table = [INVALID_VALUE; 256]; + + // Since the table is full of `INVALID_VALUE` already, we only need to overwrite + // the parts that are valid. + let mut index = 0; + while index < 64 { + // The index in the alphabet is the 6-bit value we care about. + // Since the index is in 0-63, it is safe to cast to u8. + decode_table[alphabet.symbols[index] as usize] = index as u8; + index += 1; + } + + decode_table +} + +#[inline] +fn read_u64(s: &[u8]) -> u64 { + u64::from_be_bytes(s[..8].try_into().unwrap()) +} + +/// Contains configuration parameters for base64 encoding and decoding. +/// +/// ``` +/// # use base64::engine::GeneralPurposeConfig; +/// let config = GeneralPurposeConfig::new() +/// .with_encode_padding(false); +/// // further customize using `.with_*` methods as needed +/// ``` +/// +/// The constants [PAD] and [NO_PAD] cover most use cases. +/// +/// To specify the characters used, see [Alphabet]. +#[derive(Clone, Copy, Debug)] +pub struct GeneralPurposeConfig { + encode_padding: bool, + decode_allow_trailing_bits: bool, + decode_padding_mode: DecodePaddingMode, +} + +impl GeneralPurposeConfig { + /// Create a new config with `padding` = `true`, `decode_allow_trailing_bits` = `false`, and + /// `decode_padding_mode = DecodePaddingMode::RequireCanonicalPadding`. + /// + /// This probably matches most people's expectations, but consider disabling padding to save + /// a few bytes unless you specifically need it for compatibility with some legacy system. + pub const fn new() -> Self { + Self { + // RFC states that padding must be applied by default + encode_padding: true, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + } + } + + /// Create a new config based on `self` with an updated `padding` setting. + /// + /// If `padding` is `true`, encoding will append either 1 or 2 `=` padding characters as needed + /// to produce an output whose length is a multiple of 4. + /// + /// Padding is not needed for correct decoding and only serves to waste bytes, but it's in the + /// [spec](https://datatracker.ietf.org/doc/html/rfc4648#section-3.2). + /// + /// For new applications, consider not using padding if the decoders you're using don't require + /// padding to be present. + pub const fn with_encode_padding(self, padding: bool) -> Self { + Self { + encode_padding: padding, + ..self + } + } + + /// Create a new config based on `self` with an updated `decode_allow_trailing_bits` setting. + /// + /// Most users will not need to configure this. It's useful if you need to decode base64 + /// produced by a buggy encoder that has bits set in the unused space on the last base64 + /// character as per [forgiving-base64 decode](https://infra.spec.whatwg.org/#forgiving-base64-decode). + /// If invalid trailing bits are present and this is `true`, those bits will + /// be silently ignored, else `DecodeError::InvalidLastSymbol` will be emitted. + pub const fn with_decode_allow_trailing_bits(self, allow: bool) -> Self { + Self { + decode_allow_trailing_bits: allow, + ..self + } + } + + /// Create a new config based on `self` with an updated `decode_padding_mode` setting. + /// + /// Padding is not useful in terms of representing encoded data -- it makes no difference to + /// the decoder if padding is present or not, so if you have some un-padded input to decode, it + /// is perfectly fine to use `DecodePaddingMode::Indifferent` to prevent errors from being + /// emitted. + /// + /// However, since in practice + /// [people who learned nothing from BER vs DER seem to expect base64 to have one canonical encoding](https://eprint.iacr.org/2022/361), + /// the default setting is the stricter `DecodePaddingMode::RequireCanonicalPadding`. + /// + /// Or, if "canonical" in your circumstance means _no_ padding rather than padding to the + /// next multiple of four, there's `DecodePaddingMode::RequireNoPadding`. + pub const fn with_decode_padding_mode(self, mode: DecodePaddingMode) -> Self { + Self { + decode_padding_mode: mode, + ..self + } + } +} + +impl Default for GeneralPurposeConfig { + /// Delegates to [GeneralPurposeConfig::new]. + fn default() -> Self { + Self::new() + } +} + +impl Config for GeneralPurposeConfig { + fn encode_padding(&self) -> bool { + self.encode_padding + } +} + +/// A [GeneralPurpose] engine using the [alphabet::STANDARD] base64 alphabet and [PAD] config. +pub const STANDARD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, PAD); + +/// A [GeneralPurpose] engine using the [alphabet::STANDARD] base64 alphabet and [NO_PAD] config. +pub const STANDARD_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD); + +/// A [GeneralPurpose] engine using the [alphabet::URL_SAFE] base64 alphabet and [PAD] config. +pub const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, PAD); + +/// A [GeneralPurpose] engine using the [alphabet::URL_SAFE] base64 alphabet and [NO_PAD] config. +pub const URL_SAFE_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, NO_PAD); + +/// Include padding bytes when encoding, and require that they be present when decoding. +/// +/// This is the standard per the base64 RFC, but consider using [NO_PAD] instead as padding serves +/// little purpose in practice. +pub const PAD: GeneralPurposeConfig = GeneralPurposeConfig::new(); + +/// Don't add padding when encoding, and require no padding when decoding. +pub const NO_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new() + .with_encode_padding(false) + .with_decode_padding_mode(DecodePaddingMode::RequireNone); diff --git a/third_party/rust/base64/src/engine/mod.rs b/third_party/rust/base64/src/engine/mod.rs new file mode 100644 index 0000000000..12dfaa8845 --- /dev/null +++ b/third_party/rust/base64/src/engine/mod.rs @@ -0,0 +1,410 @@ +//! Provides the [Engine] abstraction and out of the box implementations. +#[cfg(any(feature = "alloc", feature = "std", test))] +use crate::chunked_encoder; +use crate::{ + encode::{encode_with_padding, EncodeSliceError}, + encoded_len, DecodeError, DecodeSliceError, +}; +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::vec::Vec; + +#[cfg(any(feature = "alloc", feature = "std", test))] +use alloc::{string::String, vec}; + +pub mod general_purpose; + +#[cfg(test)] +mod naive; + +#[cfg(test)] +mod tests; + +pub use general_purpose::{GeneralPurpose, GeneralPurposeConfig}; + +/// An `Engine` provides low-level encoding and decoding operations that all other higher-level parts of the API use. Users of the library will generally not need to implement this. +/// +/// Different implementations offer different characteristics. The library currently ships with +/// [GeneralPurpose] that offers good speed and works on any CPU, with more choices +/// coming later, like a constant-time one when side channel resistance is called for, and vendor-specific vectorized ones for more speed. +/// +/// See [general_purpose::STANDARD_NO_PAD] if you just want standard base64. Otherwise, when possible, it's +/// recommended to store the engine in a `const` so that references to it won't pose any lifetime +/// issues, and to avoid repeating the cost of engine setup. +/// +/// Since almost nobody will need to implement `Engine`, docs for internal methods are hidden. +// When adding an implementation of Engine, include them in the engine test suite: +// - add an implementation of [engine::tests::EngineWrapper] +// - add the implementation to the `all_engines` macro +// All tests run on all engines listed in the macro. +pub trait Engine: Send + Sync { + /// The config type used by this engine + type Config: Config; + /// The decode estimate used by this engine + type DecodeEstimate: DecodeEstimate; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// See the other `encode*` functions on this trait. + /// + /// Encode the `input` bytes into the `output` buffer based on the mapping in `encode_table`. + /// + /// `output` will be long enough to hold the encoded data. + /// + /// Returns the number of bytes written. + /// + /// No padding should be written; that is handled separately. + /// + /// Must not write any bytes into the output slice other than the encoded data. + #[doc(hidden)] + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// + /// As an optimization to prevent the decoded length from being calculated twice, it is + /// sometimes helpful to have a conservative estimate of the decoded size before doing the + /// decoding, so this calculation is done separately and passed to [Engine::decode()] as needed. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + #[doc(hidden)] + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate; + + /// This is not meant to be called directly; it is only for `Engine` implementors. + /// See the other `decode*` functions on this trait. + /// + /// Decode `input` base64 bytes into the `output` buffer. + /// + /// `decode_estimate` is the result of [Engine::internal_decoded_len_estimate()], which is passed in to avoid + /// calculating it again (expensive on short inputs).` + /// + /// Returns the number of bytes written to `output`. + /// + /// Each complete 4-byte chunk of encoded data decodes to 3 bytes of decoded data, but this + /// function must also handle the final possibly partial chunk. + /// If the input length is not a multiple of 4, or uses padding bytes to reach a multiple of 4, + /// the trailing 2 or 3 bytes must decode to 1 or 2 bytes, respectively, as per the + /// [RFC](https://tools.ietf.org/html/rfc4648#section-3.5). + /// + /// Decoding must not write any bytes into the output slice other than the decoded data. + /// + /// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as + /// errors unless the engine is configured otherwise. + /// + /// # Panics + /// + /// Panics if `output` is too small. + #[doc(hidden)] + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + decode_estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError>; + + /// Returns the config for this engine. + fn config(&self) -> &Self::Config; + + /// Encode arbitrary octets as base64 using the provided `Engine`. + /// Returns a `String`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::{self, general_purpose}, alphabet}; + /// + /// let b64 = general_purpose::STANDARD.encode(b"hello world~"); + /// println!("{}", b64); + /// + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); + /// + /// let b64_url = CUSTOM_ENGINE.encode(b"hello internet~"); + #[cfg(any(feature = "alloc", feature = "std", test))] + fn encode<T: AsRef<[u8]>>(&self, input: T) -> String { + let encoded_size = encoded_len(input.as_ref().len(), self.config().encode_padding()) + .expect("integer overflow when calculating buffer size"); + let mut buf = vec![0; encoded_size]; + + encode_with_padding(input.as_ref(), &mut buf[..], self, encoded_size); + + String::from_utf8(buf).expect("Invalid UTF8") + } + + /// Encode arbitrary octets as base64 into a supplied `String`. + /// Writes into the supplied `String`, which may allocate if its internal buffer isn't big enough. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::{self, general_purpose}, alphabet}; + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); + /// + /// fn main() { + /// let mut buf = String::new(); + /// general_purpose::STANDARD.encode_string(b"hello world~", &mut buf); + /// println!("{}", buf); + /// + /// buf.clear(); + /// CUSTOM_ENGINE.encode_string(b"hello internet~", &mut buf); + /// println!("{}", buf); + /// } + /// ``` + #[cfg(any(feature = "alloc", feature = "std", test))] + fn encode_string<T: AsRef<[u8]>>(&self, input: T, output_buf: &mut String) { + let input_bytes = input.as_ref(); + + { + let mut sink = chunked_encoder::StringSink::new(output_buf); + + chunked_encoder::ChunkedEncoder::new(self) + .encode(input_bytes, &mut sink) + .expect("Writing to a String shouldn't fail"); + } + } + + /// Encode arbitrary octets as base64 into a supplied slice. + /// Writes into the supplied output buffer. + /// + /// This is useful if you wish to avoid allocation entirely (e.g. encoding into a stack-resident + /// or statically-allocated buffer). + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, engine::general_purpose}; + /// let s = b"hello internet!"; + /// let mut buf = Vec::new(); + /// // make sure we'll have a slice big enough for base64 + padding + /// buf.resize(s.len() * 4 / 3 + 4, 0); + /// + /// let bytes_written = general_purpose::STANDARD.encode_slice(s, &mut buf).unwrap(); + /// + /// // shorten our vec down to just what was written + /// buf.truncate(bytes_written); + /// + /// assert_eq!(s, general_purpose::STANDARD.decode(&buf).unwrap().as_slice()); + /// ``` + fn encode_slice<T: AsRef<[u8]>>( + &self, + input: T, + output_buf: &mut [u8], + ) -> Result<usize, EncodeSliceError> { + let input_bytes = input.as_ref(); + + let encoded_size = encoded_len(input_bytes.len(), self.config().encode_padding()) + .expect("usize overflow when calculating buffer size"); + + if output_buf.len() < encoded_size { + return Err(EncodeSliceError::OutputSliceTooSmall); + } + + let b64_output = &mut output_buf[0..encoded_size]; + + encode_with_padding(input_bytes, b64_output, self, encoded_size); + + Ok(encoded_size) + } + + /// Decode from string reference as octets using the specified [Engine]. + /// Returns a `Result` containing a `Vec<u8>`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, alphabet, engine::{self, general_purpose}}; + /// + /// let bytes = general_purpose::STANDARD + /// .decode("aGVsbG8gd29ybGR+Cg==").unwrap(); + /// println!("{:?}", bytes); + /// + /// // custom engine setup + /// let bytes_url = engine::GeneralPurpose::new( + /// &alphabet::URL_SAFE, + /// general_purpose::NO_PAD) + /// .decode("aGVsbG8gaW50ZXJuZXR-Cg").unwrap(); + /// println!("{:?}", bytes_url); + /// ``` + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + #[cfg(any(feature = "alloc", feature = "std", test))] + fn decode<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, DecodeError> { + let input_bytes = input.as_ref(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + let mut buffer = vec![0; estimate.decoded_len_estimate()]; + + let bytes_written = self.internal_decode(input_bytes, &mut buffer, estimate)?; + buffer.truncate(bytes_written); + + Ok(buffer) + } + + /// Decode from string reference as octets. + /// Writes into the supplied `Vec`, which may allocate if its internal buffer isn't big enough. + /// Returns a `Result` containing an empty tuple, aka `()`. + /// + /// # Example + /// + /// ```rust + /// use base64::{Engine as _, alphabet, engine::{self, general_purpose}}; + /// const CUSTOM_ENGINE: engine::GeneralPurpose = + /// engine::GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::PAD); + /// + /// fn main() { + /// use base64::Engine; + /// let mut buffer = Vec::<u8>::new(); + /// // with the default engine + /// general_purpose::STANDARD + /// .decode_vec("aGVsbG8gd29ybGR+Cg==", &mut buffer,).unwrap(); + /// println!("{:?}", buffer); + /// + /// buffer.clear(); + /// + /// // with a custom engine + /// CUSTOM_ENGINE.decode_vec( + /// "aGVsbG8gaW50ZXJuZXR-Cg==", + /// &mut buffer, + /// ).unwrap(); + /// println!("{:?}", buffer); + /// } + /// ``` + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + #[cfg(any(feature = "alloc", feature = "std", test))] + fn decode_vec<T: AsRef<[u8]>>( + &self, + input: T, + buffer: &mut Vec<u8>, + ) -> Result<(), DecodeError> { + let input_bytes = input.as_ref(); + + let starting_output_len = buffer.len(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + let total_len_estimate = estimate + .decoded_len_estimate() + .checked_add(starting_output_len) + .expect("Overflow when calculating output buffer length"); + buffer.resize(total_len_estimate, 0); + + let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; + let bytes_written = self.internal_decode(input_bytes, buffer_slice, estimate)?; + + buffer.truncate(starting_output_len + bytes_written); + + Ok(()) + } + + /// Decode the input into the provided output slice. + /// + /// Returns an error if `output` is smaller than the estimated decoded length. + /// + /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). + /// + /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// + /// See [Engine::decode_slice_unchecked] for a version that panics instead of returning an error + /// if the output buffer is too small. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + fn decode_slice<T: AsRef<[u8]>>( + &self, + input: T, + output: &mut [u8], + ) -> Result<usize, DecodeSliceError> { + let input_bytes = input.as_ref(); + + let estimate = self.internal_decoded_len_estimate(input_bytes.len()); + if output.len() < estimate.decoded_len_estimate() { + return Err(DecodeSliceError::OutputSliceTooSmall); + } + + self.internal_decode(input_bytes, output, estimate) + .map_err(|e| e.into()) + } + + /// Decode the input into the provided output slice. + /// + /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). + /// + /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// + /// See [Engine::decode_slice] for a version that returns an error instead of panicking if the output + /// buffer is too small. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + /// + /// Panics if the provided output buffer is too small for the decoded data. + fn decode_slice_unchecked<T: AsRef<[u8]>>( + &self, + input: T, + output: &mut [u8], + ) -> Result<usize, DecodeError> { + let input_bytes = input.as_ref(); + + self.internal_decode( + input_bytes, + output, + self.internal_decoded_len_estimate(input_bytes.len()), + ) + } +} + +/// The minimal level of configuration that engines must support. +pub trait Config { + /// Returns `true` if padding should be added after the encoded output. + /// + /// Padding is added outside the engine's encode() since the engine may be used + /// to encode only a chunk of the overall output, so it can't always know when + /// the output is "done" and would therefore need padding (if configured). + // It could be provided as a separate parameter when encoding, but that feels like + // leaking an implementation detail to the user, and it's hopefully more convenient + // to have to only pass one thing (the engine) to any part of the API. + fn encode_padding(&self) -> bool; +} + +/// The decode estimate used by an engine implementation. Users do not need to interact with this; +/// it is only for engine implementors. +/// +/// Implementors may store relevant data here when constructing this to avoid having to calculate +/// them again during actual decoding. +pub trait DecodeEstimate { + /// Returns a conservative (err on the side of too big) estimate of the decoded length to use + /// for pre-allocating buffers, etc. + /// + /// The estimate must be no larger than the next largest complete triple of decoded bytes. + /// That is, the final quad of tokens to decode may be assumed to be complete with no padding. + /// + /// # Panics + /// + /// Panics if decoded length estimation overflows. + /// This would happen for sizes within a few bytes of the maximum value of `usize`. + fn decoded_len_estimate(&self) -> usize; +} + +/// Controls how pad bytes are handled when decoding. +/// +/// Each [Engine] must support at least the behavior indicated by +/// [DecodePaddingMode::RequireCanonical], and may support other modes. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DecodePaddingMode { + /// Canonical padding is allowed, but any fewer padding bytes than that is also allowed. + Indifferent, + /// Padding must be canonical (0, 1, or 2 `=` as needed to produce a 4 byte suffix). + RequireCanonical, + /// Padding must be absent -- for when you want predictable padding, without any wasted bytes. + RequireNone, +} diff --git a/third_party/rust/base64/src/engine/naive.rs b/third_party/rust/base64/src/engine/naive.rs new file mode 100644 index 0000000000..6665c5eb41 --- /dev/null +++ b/third_party/rust/base64/src/engine/naive.rs @@ -0,0 +1,219 @@ +use crate::{ + alphabet::Alphabet, + engine::{ + general_purpose::{self, decode_table, encode_table}, + Config, DecodeEstimate, DecodePaddingMode, Engine, + }, + DecodeError, PAD_BYTE, +}; +use alloc::ops::BitOr; +use std::ops::{BitAnd, Shl, Shr}; + +/// Comparatively simple implementation that can be used as something to compare against in tests +pub struct Naive { + encode_table: [u8; 64], + decode_table: [u8; 256], + config: NaiveConfig, +} + +impl Naive { + const ENCODE_INPUT_CHUNK_SIZE: usize = 3; + const DECODE_INPUT_CHUNK_SIZE: usize = 4; + + pub const fn new(alphabet: &Alphabet, config: NaiveConfig) -> Self { + Self { + encode_table: encode_table(alphabet), + decode_table: decode_table(alphabet), + config, + } + } + + fn decode_byte_into_u32(&self, offset: usize, byte: u8) -> Result<u32, DecodeError> { + let decoded = self.decode_table[byte as usize]; + + if decoded == general_purpose::INVALID_VALUE { + return Err(DecodeError::InvalidByte(offset, byte)); + } + + Ok(decoded as u32) + } +} + +impl Engine for Naive { + type Config = NaiveConfig; + type DecodeEstimate = NaiveEstimate; + + fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { + // complete chunks first + + const LOW_SIX_BITS: u32 = 0x3F; + + let rem = input.len() % Self::ENCODE_INPUT_CHUNK_SIZE; + // will never underflow + let complete_chunk_len = input.len() - rem; + + let mut input_index = 0_usize; + let mut output_index = 0_usize; + if let Some(last_complete_chunk_index) = + complete_chunk_len.checked_sub(Self::ENCODE_INPUT_CHUNK_SIZE) + { + while input_index <= last_complete_chunk_index { + let chunk = &input[input_index..input_index + Self::ENCODE_INPUT_CHUNK_SIZE]; + + // populate low 24 bits from 3 bytes + let chunk_int: u32 = + (chunk[0] as u32).shl(16) | (chunk[1] as u32).shl(8) | (chunk[2] as u32); + // encode 4x 6-bit output bytes + output[output_index] = self.encode_table[chunk_int.shr(18) as usize]; + output[output_index + 1] = + self.encode_table[chunk_int.shr(12_u8).bitand(LOW_SIX_BITS) as usize]; + output[output_index + 2] = + self.encode_table[chunk_int.shr(6_u8).bitand(LOW_SIX_BITS) as usize]; + output[output_index + 3] = + self.encode_table[chunk_int.bitand(LOW_SIX_BITS) as usize]; + + input_index += Self::ENCODE_INPUT_CHUNK_SIZE; + output_index += 4; + } + } + + // then leftovers + if rem == 2 { + let chunk = &input[input_index..input_index + 2]; + + // high six bits of chunk[0] + output[output_index] = self.encode_table[chunk[0].shr(2) as usize]; + // bottom 2 bits of [0], high 4 bits of [1] + output[output_index + 1] = + self.encode_table[(chunk[0].shl(4_u8).bitor(chunk[1].shr(4_u8)) as u32) + .bitand(LOW_SIX_BITS) as usize]; + // bottom 4 bits of [1], with the 2 bottom bits as zero + output[output_index + 2] = + self.encode_table[(chunk[1].shl(2_u8) as u32).bitand(LOW_SIX_BITS) as usize]; + + output_index += 3; + } else if rem == 1 { + let byte = input[input_index]; + output[output_index] = self.encode_table[byte.shr(2) as usize]; + output[output_index + 1] = + self.encode_table[(byte.shl(4_u8) as u32).bitand(LOW_SIX_BITS) as usize]; + output_index += 2; + } + + output_index + } + + fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { + NaiveEstimate::new(input_len) + } + + fn internal_decode( + &self, + input: &[u8], + output: &mut [u8], + estimate: Self::DecodeEstimate, + ) -> Result<usize, DecodeError> { + if estimate.rem == 1 { + // trailing whitespace is so common that it's worth it to check the last byte to + // possibly return a better error message + if let Some(b) = input.last() { + if *b != PAD_BYTE + && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE + { + return Err(DecodeError::InvalidByte(input.len() - 1, *b)); + } + } + + return Err(DecodeError::InvalidLength); + } + + let mut input_index = 0_usize; + let mut output_index = 0_usize; + const BOTTOM_BYTE: u32 = 0xFF; + + // can only use the main loop on non-trailing chunks + if input.len() > Self::DECODE_INPUT_CHUNK_SIZE { + // skip the last chunk, whether it's partial or full, since it might + // have padding, and start at the beginning of the chunk before that + let last_complete_chunk_start_index = estimate.complete_chunk_len + - if estimate.rem == 0 { + // Trailing chunk is also full chunk, so there must be at least 2 chunks, and + // this won't underflow + Self::DECODE_INPUT_CHUNK_SIZE * 2 + } else { + // Trailing chunk is partial, so it's already excluded in + // complete_chunk_len + Self::DECODE_INPUT_CHUNK_SIZE + }; + + while input_index <= last_complete_chunk_start_index { + let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; + let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) + | self + .decode_byte_into_u32(input_index + 1, chunk[1])? + .shl(12) + | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) + | self.decode_byte_into_u32(input_index + 3, chunk[3])?; + + output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; + + input_index += Self::DECODE_INPUT_CHUNK_SIZE; + output_index += 3; + } + } + + general_purpose::decode_suffix::decode_suffix( + input, + input_index, + output, + output_index, + &self.decode_table, + self.config.decode_allow_trailing_bits, + self.config.decode_padding_mode, + ) + } + + fn config(&self) -> &Self::Config { + &self.config + } +} + +pub struct NaiveEstimate { + /// remainder from dividing input by `Naive::DECODE_CHUNK_SIZE` + rem: usize, + /// Length of input that is in complete `Naive::DECODE_CHUNK_SIZE`-length chunks + complete_chunk_len: usize, +} + +impl NaiveEstimate { + fn new(input_len: usize) -> Self { + let rem = input_len % Naive::DECODE_INPUT_CHUNK_SIZE; + let complete_chunk_len = input_len - rem; + + Self { + rem, + complete_chunk_len, + } + } +} + +impl DecodeEstimate for NaiveEstimate { + fn decoded_len_estimate(&self) -> usize { + ((self.complete_chunk_len / 4) + ((self.rem > 0) as usize)) * 3 + } +} + +#[derive(Clone, Copy, Debug)] +pub struct NaiveConfig { + pub encode_padding: bool, + pub decode_allow_trailing_bits: bool, + pub decode_padding_mode: DecodePaddingMode, +} + +impl Config for NaiveConfig { + fn encode_padding(&self) -> bool { + self.encode_padding + } +} diff --git a/third_party/rust/base64/src/engine/tests.rs b/third_party/rust/base64/src/engine/tests.rs new file mode 100644 index 0000000000..906bba04d8 --- /dev/null +++ b/third_party/rust/base64/src/engine/tests.rs @@ -0,0 +1,1430 @@ +// rstest_reuse template functions have unused variables +#![allow(unused_variables)] + +use rand::{ + self, + distributions::{self, Distribution as _}, + rngs, Rng as _, SeedableRng as _, +}; +use rstest::rstest; +use rstest_reuse::{apply, template}; +use std::{collections, fmt}; + +use crate::{ + alphabet::{Alphabet, STANDARD}, + encode::add_padding, + encoded_len, + engine::{general_purpose, naive, Config, DecodeEstimate, DecodePaddingMode, Engine}, + tests::{assert_encode_sanity, random_alphabet, random_config}, + DecodeError, PAD_BYTE, +}; + +// the case::foo syntax includes the "foo" in the generated test method names +#[template] +#[rstest(engine_wrapper, +case::general_purpose(GeneralPurposeWrapper {}), +case::naive(NaiveWrapper {}), +)] +fn all_engines<E: EngineWrapper>(engine_wrapper: E) {} + +#[apply(all_engines)] +fn rfc_test_vectors_std_alphabet<E: EngineWrapper>(engine_wrapper: E) { + let data = vec![ + ("", ""), + ("f", "Zg=="), + ("fo", "Zm8="), + ("foo", "Zm9v"), + ("foob", "Zm9vYg=="), + ("fooba", "Zm9vYmE="), + ("foobar", "Zm9vYmFy"), + ]; + + let engine = E::standard(); + let engine_no_padding = E::standard_unpadded(); + + for (orig, encoded) in &data { + let encoded_without_padding = encoded.trim_end_matches('='); + + // unpadded + { + let mut encode_buf = [0_u8; 8]; + let mut decode_buf = [0_u8; 6]; + + let encode_len = + engine_no_padding.internal_encode(orig.as_bytes(), &mut encode_buf[..]); + assert_eq!( + &encoded_without_padding, + &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap() + ); + let decode_len = engine_no_padding + .decode_slice_unchecked(encoded_without_padding.as_bytes(), &mut decode_buf[..]) + .unwrap(); + assert_eq!(orig.len(), decode_len); + + assert_eq!( + orig, + &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap() + ); + + // if there was any padding originally, the no padding engine won't decode it + if encoded.as_bytes().contains(&PAD_BYTE) { + assert_eq!( + Err(DecodeError::InvalidPadding), + engine_no_padding.decode(encoded) + ) + } + } + + // padded + { + let mut encode_buf = [0_u8; 8]; + let mut decode_buf = [0_u8; 6]; + + let encode_len = engine.internal_encode(orig.as_bytes(), &mut encode_buf[..]); + assert_eq!( + // doesn't have padding added yet + &encoded_without_padding, + &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap() + ); + let pad_len = add_padding(orig.len(), &mut encode_buf[encode_len..]); + assert_eq!(encoded.as_bytes(), &encode_buf[..encode_len + pad_len]); + + let decode_len = engine + .decode_slice_unchecked(encoded.as_bytes(), &mut decode_buf[..]) + .unwrap(); + assert_eq!(orig.len(), decode_len); + + assert_eq!( + orig, + &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap() + ); + + // if there was (canonical) padding, and we remove it, the standard engine won't decode + if encoded.as_bytes().contains(&PAD_BYTE) { + assert_eq!( + Err(DecodeError::InvalidPadding), + engine.decode(encoded_without_padding) + ) + } + } + } +} + +#[apply(all_engines)] +fn roundtrip_random<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + + let (orig_len, _, encoded_len) = generate_random_encoded_data( + &engine, + &mut orig_data, + &mut encode_buf, + &mut rng, + &len_range, + ); + + // exactly the right size + decode_buf.resize(orig_len, 0); + + let dec_len = engine + .decode_slice_unchecked(&encode_buf[0..encoded_len], &mut decode_buf[..]) + .unwrap(); + + assert_eq!(orig_len, dec_len); + assert_eq!(&orig_data[..], &decode_buf[..dec_len]); + } +} + +#[apply(all_engines)] +fn encode_doesnt_write_extra_bytes<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut encode_buf_backup = Vec::<u8>::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + let padded = engine.config().encode_padding(); + + orig_data.clear(); + encode_buf.clear(); + encode_buf_backup.clear(); + + let orig_len = fill_rand(&mut orig_data, &mut rng, &input_len_range); + + let prefix_len = 1024; + // plenty of prefix and suffix + fill_rand_len(&mut encode_buf, &mut rng, prefix_len * 2 + orig_len * 2); + encode_buf_backup.extend_from_slice(&encode_buf[..]); + + let expected_encode_len_no_pad = encoded_len(orig_len, false).unwrap(); + + let encoded_len_no_pad = + engine.internal_encode(&orig_data[..], &mut encode_buf[prefix_len..]); + assert_eq!(expected_encode_len_no_pad, encoded_len_no_pad); + + // no writes past what it claimed to write + assert_eq!(&encode_buf_backup[..prefix_len], &encode_buf[..prefix_len]); + assert_eq!( + &encode_buf_backup[(prefix_len + encoded_len_no_pad)..], + &encode_buf[(prefix_len + encoded_len_no_pad)..] + ); + + let encoded_data = &encode_buf[prefix_len..(prefix_len + encoded_len_no_pad)]; + assert_encode_sanity( + std::str::from_utf8(encoded_data).unwrap(), + // engines don't pad + false, + orig_len, + ); + + // pad so we can decode it in case our random engine requires padding + let pad_len = if padded { + add_padding(orig_len, &mut encode_buf[prefix_len + encoded_len_no_pad..]) + } else { + 0 + }; + + assert_eq!( + orig_data, + engine + .decode(&encode_buf[prefix_len..(prefix_len + encoded_len_no_pad + pad_len)],) + .unwrap() + ); + } +} + +#[apply(all_engines)] +fn encode_engine_slice_fits_into_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) { + let mut orig_data = Vec::new(); + let mut encoded_data = Vec::new(); + let mut decoded = Vec::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + + let mut rng = rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decoded.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = E::random(&mut rng); + + let encoded_size = encoded_len(input_len, engine.config().encode_padding()).unwrap(); + + encoded_data.resize(encoded_size, 0); + + assert_eq!( + encoded_size, + engine.encode_slice(&orig_data, &mut encoded_data).unwrap() + ); + + assert_encode_sanity( + std::str::from_utf8(&encoded_data[0..encoded_size]).unwrap(), + engine.config().encode_padding(), + input_len, + ); + + engine + .decode_vec(&encoded_data[0..encoded_size], &mut decoded) + .unwrap(); + assert_eq!(orig_data, decoded); + } +} + +#[apply(all_engines)] +fn decode_doesnt_write_extra_bytes<E>(engine_wrapper: E) +where + E: EngineWrapper, + <<E as EngineWrapper>::Engine as Engine>::Config: fmt::Debug, +{ + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + let mut decode_buf_backup = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let engine = E::random(&mut rng); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + decode_buf_backup.clear(); + + let orig_len = fill_rand(&mut orig_data, &mut rng, &len_range); + encode_buf.resize(orig_len * 2 + 100, 0); + + let encoded_len = engine + .encode_slice(&orig_data[..], &mut encode_buf[..]) + .unwrap(); + encode_buf.truncate(encoded_len); + + // oversize decode buffer so we can easily tell if it writes anything more than + // just the decoded data + let prefix_len = 1024; + // plenty of prefix and suffix + fill_rand_len(&mut decode_buf, &mut rng, prefix_len * 2 + orig_len * 2); + decode_buf_backup.extend_from_slice(&decode_buf[..]); + + let dec_len = engine + .decode_slice_unchecked(&encode_buf, &mut decode_buf[prefix_len..]) + .unwrap(); + + assert_eq!(orig_len, dec_len); + assert_eq!( + &orig_data[..], + &decode_buf[prefix_len..prefix_len + dec_len] + ); + assert_eq!(&decode_buf_backup[..prefix_len], &decode_buf[..prefix_len]); + assert_eq!( + &decode_buf_backup[prefix_len + dec_len..], + &decode_buf[prefix_len + dec_len..] + ); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol<E: EngineWrapper>(engine_wrapper: E) { + // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol + let engine = E::standard(); + + assert_eq!(Ok(vec![0x89, 0x85]), engine.decode("iYU=")); + assert_eq!(Ok(vec![0xFF]), engine.decode("/w==")); + + for (suffix, offset) in vec![ + // suffix, offset of bad byte from start of suffix + ("/x==", 1_usize), + ("/z==", 1_usize), + ("/0==", 1_usize), + ("/9==", 1_usize), + ("/+==", 1_usize), + ("//==", 1_usize), + // trailing 01 + ("iYV=", 2_usize), + // trailing 10 + ("iYW=", 2_usize), + // trailing 11 + ("iYX=", 2_usize), + ] { + for prefix_quads in 0..256 { + let mut encoded = "AAAA".repeat(prefix_quads); + encoded.push_str(suffix); + + assert_eq!( + Err(DecodeError::InvalidLastSymbol( + encoded.len() - 4 + offset, + suffix.as_bytes()[offset], + )), + engine.decode(encoded.as_str()) + ); + } + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_when_length_is_also_invalid<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // check across enough lengths that it would likely cover any implementation's various internal + // small/large input division + for len in (0_usize..256).map(|len| len * 4 + 1) { + let engine = E::random_alphabet(&mut rng, &STANDARD); + + let mut input = vec![b'A'; len]; + + // with a valid last char, it's InvalidLength + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); + // after mangling the last char, it's InvalidByte + input[len - 1] = b'"'; + assert_eq!( + Err(DecodeError::InvalidByte(len - 1, b'"')), + engine.decode(&input) + ); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_every_possible_two_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + let engine = E::standard(); + + let mut base64_to_bytes = collections::HashMap::new(); + + for b in 0_u8..=255 { + let mut b64 = vec![0_u8; 4]; + assert_eq!(2, engine.internal_encode(&[b], &mut b64[..])); + let _ = add_padding(1, &mut b64[2..]); + + assert!(base64_to_bytes.insert(b64, vec![b]).is_none()); + } + + // every possible combination of trailing symbols must either decode to 1 byte or get InvalidLastSymbol, with or without any leading chunks + + let mut prefix = Vec::new(); + for _ in 0..256 { + let mut clone = prefix.clone(); + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.symbols.iter() { + symbols[0] = s1; + for &s2 in STANDARD.symbols.iter() { + symbols[1] = s2; + symbols[2] = PAD_BYTE; + symbols[3] = PAD_BYTE; + + // chop off previous symbols + clone.truncate(prefix.len()); + clone.extend_from_slice(&symbols[..]); + let decoded_prefix_len = prefix.len() / 4 * 3; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => { + let res = engine + .decode(&clone) + // remove prefix + .map(|decoded| decoded[decoded_prefix_len..].to_vec()); + + assert_eq!(Ok(bytes.clone()), res); + } + None => assert_eq!( + Err(DecodeError::InvalidLastSymbol(1, s2)), + engine.decode(&symbols[..]) + ), + } + } + } + + prefix.extend_from_slice(b"AAAA"); + } +} + +#[apply(all_engines)] +fn decode_detect_invalid_last_symbol_every_possible_three_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + let engine = E::standard(); + + let mut base64_to_bytes = collections::HashMap::new(); + + let mut bytes = [0_u8; 2]; + for b1 in 0_u8..=255 { + bytes[0] = b1; + for b2 in 0_u8..=255 { + bytes[1] = b2; + let mut b64 = vec![0_u8; 4]; + assert_eq!(3, engine.internal_encode(&bytes, &mut b64[..])); + let _ = add_padding(2, &mut b64[3..]); + + let mut v = Vec::with_capacity(2); + v.extend_from_slice(&bytes[..]); + + assert!(base64_to_bytes.insert(b64, v).is_none()); + } + } + + // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol, with or without any leading chunks + + let mut prefix = Vec::new(); + for _ in 0..256 { + let mut input = prefix.clone(); + + let mut symbols = [0_u8; 4]; + for &s1 in STANDARD.symbols.iter() { + symbols[0] = s1; + for &s2 in STANDARD.symbols.iter() { + symbols[1] = s2; + for &s3 in STANDARD.symbols.iter() { + symbols[2] = s3; + symbols[3] = PAD_BYTE; + + // chop off previous symbols + input.truncate(prefix.len()); + input.extend_from_slice(&symbols[..]); + let decoded_prefix_len = prefix.len() / 4 * 3; + + match base64_to_bytes.get(&symbols[..]) { + Some(bytes) => { + let res = engine + .decode(&input) + // remove prefix + .map(|decoded| decoded[decoded_prefix_len..].to_vec()); + + assert_eq!(Ok(bytes.clone()), res); + } + None => assert_eq!( + Err(DecodeError::InvalidLastSymbol(2, s3)), + engine.decode(&symbols[..]) + ), + } + } + } + } + prefix.extend_from_slice(b"AAAA"); + } +} + +#[apply(all_engines)] +fn decode_invalid_trailing_bits_ignored_when_configured<E: EngineWrapper>(engine_wrapper: E) { + let strict = E::standard(); + let forgiving = E::standard_allow_trailing_bits(); + + fn assert_tolerant_decode<E: Engine>( + engine: &E, + input: &mut String, + b64_prefix_len: usize, + expected_decode_bytes: Vec<u8>, + data: &str, + ) { + let prefixed = prefixed_data(input, b64_prefix_len, data); + let decoded = engine.decode(prefixed); + // prefix is always complete chunks + let decoded_prefix_len = b64_prefix_len / 4 * 3; + assert_eq!( + Ok(expected_decode_bytes), + decoded.map(|v| v[decoded_prefix_len..].to_vec()) + ); + } + + let mut prefix = String::new(); + for _ in 0..256 { + let mut input = prefix.clone(); + + // example from https://github.com/marshallpierce/rust-base64/issues/75 + assert!(strict + .decode(prefixed_data(&mut input, prefix.len(), "/w==")) + .is_ok()); + assert!(strict + .decode(prefixed_data(&mut input, prefix.len(), "iYU=")) + .is_ok()); + // trailing 01 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/x=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYV="); + // trailing 10 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/y=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYW="); + // trailing 11 + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/z=="); + assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYX="); + + prefix.push_str("AAAA"); + } +} + +#[apply(all_engines)] +fn decode_invalid_byte_error<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + let mut orig_data = Vec::<u8>::new(); + let mut encode_buf = Vec::<u8>::new(); + let mut decode_buf = Vec::<u8>::new(); + + let len_range = distributions::Uniform::new(1, 1_000); + + for _ in 0..10_000 { + let alphabet = random_alphabet(&mut rng); + let engine = E::random_alphabet(&mut rng, alphabet); + + orig_data.clear(); + encode_buf.clear(); + decode_buf.clear(); + + let (orig_len, encoded_len_just_data, encoded_len_with_padding) = + generate_random_encoded_data( + &engine, + &mut orig_data, + &mut encode_buf, + &mut rng, + &len_range, + ); + + // exactly the right size + decode_buf.resize(orig_len, 0); + + // replace one encoded byte with an invalid byte + let invalid_byte: u8 = loop { + let byte: u8 = rng.gen(); + + if alphabet.symbols.contains(&byte) { + continue; + } else { + break byte; + } + }; + + let invalid_range = distributions::Uniform::new(0, orig_len); + let invalid_index = invalid_range.sample(&mut rng); + encode_buf[invalid_index] = invalid_byte; + + assert_eq!( + Err(DecodeError::InvalidByte(invalid_index, invalid_byte)), + engine.decode_slice_unchecked( + &encode_buf[0..encoded_len_with_padding], + &mut decode_buf[..], + ) + ); + } +} + +/// Any amount of padding anywhere before the final non padding character = invalid byte at first +/// pad byte. +/// From this, we know padding must extend to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // the different amounts of proper padding, w/ offset from end for the last non-padding char + let suffixes = vec![("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + + let prefix_quads_range = distributions::Uniform::from(0..=256); + + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for _ in 0..100_000 { + for (suffix, offset) in suffixes.iter() { + let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); + + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - offset; + + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); + + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + ); + } + } + } +} + +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = +/// invalid byte at first pad byte (except for 1 byte suffix = invalid length). +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte<E: EngineWrapper>( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + + // must have at least one prefix quad + let prefix_quads_range = distributions::Uniform::from(1..256); + // including 1 just to make sure that it really does produce invalid length + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + for _ in 0..100_000 { + let suffix_len = suffix_pad_len_range.sample(&mut rng); + let mut encoded = "ABCD" + .repeat(prefix_quads_range.sample(&mut rng)) + .into_bytes(); + encoded.resize(encoded.len() + suffix_len, PAD_BYTE); + + // amount of padding must be long enough to extend back from suffix into previous + // quads + let padding_len = rng.gen_range(suffix_len + 1..encoded.len()); + // no non-padding after padding in this test, so padding goes to the end + let padding_start = encoded.len() - padding_len; + encoded[padding_start..].fill(PAD_BYTE); + + if suffix_len == 1 { + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); + } else { + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + ); + } + } + } +} + +/// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding +/// is not valid data (consistent with error for pad bytes in earlier chunks). +/// From this we know there must be 2-3 bytes of data before padding +#[apply(all_engines)] +fn decode_too_little_data_before_padding_error_invalid_byte<E: EngineWrapper>(engine_wrapper: E) { + let mut rng = seeded_rng(); + + // want to test no prefix quad case, so start at 0 + let prefix_quads_range = distributions::Uniform::from(0_usize..256); + let suffix_data_len_range = distributions::Uniform::from(0_usize..=1); + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + for _ in 0..100_000 { + let suffix_data_len = suffix_data_len_range.sample(&mut rng); + let prefix_quad_len = prefix_quads_range.sample(&mut rng); + + // ensure there is a suffix quad + let min_padding = usize::from(suffix_data_len == 0); + + // for all possible padding lengths + for padding_len in min_padding..=(4 - suffix_data_len) { + let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); + encoded.resize(encoded.len() + suffix_data_len, b'A'); + encoded.resize(encoded.len() + padding_len, PAD_BYTE); + + if suffix_data_len + padding_len == 1 { + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); + } else { + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "suffix data len {} pad len {}", + suffix_data_len, + padding_len + ); + } + } + } + } +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 1 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_valid<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + b"Hello".as_slice(), + &E::standard().decode("SGVsbG8=").unwrap() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 2 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_invalid_trailing_symbol<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + DecodeError::InvalidLastSymbol(6, 0x39), + E::standard().decode("SGVsbG9=").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 3 +#[apply(all_engines)] +fn decode_malleability_test_case_3_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbG9").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 4 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_valid_two_padding_symbols<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + b"Hell".as_slice(), + &E::standard().decode("SGVsbA==").unwrap() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 5 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_short_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbA=").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 6 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_eq!( + DecodeError::InvalidPadding, + E::standard().decode("SGVsbA").unwrap_err() + ); +} + +// https://eprint.iacr.org/2022/361.pdf table 2, test 7 +#[apply(all_engines)] +fn decode_malleability_test_case_2_byte_suffix_too_much_padding<E: EngineWrapper>( + engine_wrapper: E, +) { + assert_eq!( + DecodeError::InvalidByte(6, PAD_BYTE), + E::standard().decode("SGVsbA====").unwrap_err() + ); +} + +/// Requires canonical padding -> accepts 2 + 2, 3 + 1, 4 + 0 final quad configurations +#[apply(all_engines)] +fn decode_pad_mode_requires_canonical_accepts_canonical<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical), + vec!["/w==", "iYU=", "AAAA"], + ); +} + +/// Requires canonical padding -> rejects 2 + 0-1, 3 + 0 final chunk configurations +#[apply(all_engines)] +fn decode_pad_mode_requires_canonical_rejects_non_canonical<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical); + + let suffixes = vec!["/w", "/w=", "iYU"]; + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = engine.decode(&encoded); + + assert_eq!(Err(DecodeError::InvalidPadding), res); + } + } +} + +/// Requires no padding -> accepts 2 + 0, 3 + 0, 4 + 0 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_requires_no_padding_accepts_no_padding<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone), + vec!["/w", "iYU", "AAAA"], + ); +} + +/// Requires no padding -> rejects 2 + 1-2, 3 + 1 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_requires_no_padding_rejects_any_padding<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone); + + let suffixes = vec!["/w=", "/w==", "iYU="]; + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = engine.decode(&encoded); + + assert_eq!(Err(DecodeError::InvalidPadding), res); + } + } +} + +/// Indifferent padding accepts 2 + 0-2, 3 + 0-1, 4 + 0 final chunk configuration +#[apply(all_engines)] +fn decode_pad_mode_indifferent_padding_accepts_anything<E: EngineWrapper>(engine_wrapper: E) { + assert_all_suffixes_ok( + E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent), + vec!["/w", "/w=", "/w==", "iYU", "iYU=", "AAAA"], + ); +} + +//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 +#[apply(all_engines)] +fn decode_pad_byte_in_penultimate_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // leave room for at least one pad byte in penultimate quad + for num_valid_bytes_penultimate_quad in 0..4 { + // can't have 1 or it would be invalid length + for num_pad_bytes_in_final_quad in 2..=4 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + // varying amounts of padding in the penultimate quad + for _ in 0..num_valid_bytes_penultimate_quad { + s.push('A'); + } + // finish penultimate quad with padding + for _ in num_valid_bytes_penultimate_quad..4 { + s.push('='); + } + // and more padding in the final quad + for _ in 0..num_pad_bytes_in_final_quad { + s.push('='); + } + + // padding should be an invalid byte before the final quad. + // Could argue that the *next* padding byte (in the next quad) is technically the first + // erroneous one, but reporting that accurately is more complex and probably nobody cares + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, + b'=', + ), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_bytes_after_padding_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // leave at least one byte in the quad for padding + for bytes_after_padding in 1..4 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding + for _ in 0..(3 - bytes_after_padding) { + s.push('A'); + } + s.push('='); + for _ in 0..bytes_after_padding { + s.push('A'); + } + + // First (and only) padding byte is invalid. + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + (3 - bytes_after_padding), + b'=' + ), + engine.decode(&s).unwrap_err() + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_absurd_pad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("==Y=Wx===pY=2U====="); + + // first padding byte + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } +} + +#[apply(all_engines)] +fn decode_too_much_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // add enough padding to ensure that we'll hit all decode stages at the different lengths + for pad_bytes in 1..=64 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + let padding: String = "=".repeat(pad_bytes); + s.push_str(&padding); + + if pad_bytes % 4 == 1 { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } else { + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_padding_followed_by_non_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + for pad_bytes in 0..=32 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + let padding: String = "=".repeat(pad_bytes); + s.push_str(&padding); + s.push('E'); + + if pad_bytes % 4 == 0 { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } else { + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4, b'='), + engine.decode(&s).unwrap_err() + ); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_one_char_in_final_quad_with_padding_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("E="); + + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + + // more padding doesn't change the error + s.push('='); + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + + s.push('='); + assert_eq!( + DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), + engine.decode(&s).unwrap_err() + ); + } + } +} + +#[apply(all_engines)] +fn decode_too_few_symbols_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + // <2 is invalid + for final_quad_symbols in 0..2 { + for padding_symbols in 0..=(4 - final_quad_symbols) { + let mut s: String = "ABCD".repeat(num_prefix_quads); + + for _ in 0..final_quad_symbols { + s.push('A'); + } + for _ in 0..padding_symbols { + s.push('='); + } + + match final_quad_symbols + padding_symbols { + 0 => continue, + 1 => { + assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); + } + _ => { + // error reported at first padding byte + assert_eq!( + DecodeError::InvalidByte( + num_prefix_quads * 4 + final_quad_symbols, + b'=', + ), + engine.decode(&s).unwrap_err() + ); + } + } + } + } + } + } +} + +#[apply(all_engines)] +fn decode_invalid_trailing_bytes<E: EngineWrapper>(engine_wrapper: E) { + for mode in all_pad_modes() { + // we don't encode so we don't care about encode padding + let engine = E::standard_with_pad_mode(true, mode); + + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg==\n"); + + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), + engine.decode(&s) + ); + + // extra padding, however, is still InvalidLength + let s = s.replace('\n', "="); + assert_eq!(Err(DecodeError::InvalidLength), engine.decode(s)); + } + } +} + +#[apply(all_engines)] +fn decode_wrong_length_error<E: EngineWrapper>(engine_wrapper: E) { + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); + + for num_prefix_quads in 0..256 { + // at least one token, otherwise it wouldn't be a final quad + for num_tokens_final_quad in 1..=4 { + for num_padding in 0..=(4 - num_tokens_final_quad) { + let mut s: String = "IIII".repeat(num_prefix_quads); + for _ in 0..num_tokens_final_quad { + s.push('g'); + } + for _ in 0..num_padding { + s.push('='); + } + + let res = engine.decode(&s); + if num_tokens_final_quad >= 2 { + assert!(res.is_ok()); + } else if num_tokens_final_quad == 1 && num_padding > 0 { + // = is invalid if it's too early + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + num_tokens_final_quad, + 61 + )), + res + ); + } else if num_padding > 2 { + assert_eq!(Err(DecodeError::InvalidPadding), res); + } else { + assert_eq!(Err(DecodeError::InvalidLength), res); + } + } + } + } +} + +#[apply(all_engines)] +fn decode_into_slice_fits_in_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) { + let mut orig_data = Vec::new(); + let mut encoded_data = String::new(); + let mut decode_buf = Vec::new(); + + let input_len_range = distributions::Uniform::new(0, 1000); + let mut rng = rngs::SmallRng::from_entropy(); + + for _ in 0..10_000 { + orig_data.clear(); + encoded_data.clear(); + decode_buf.clear(); + + let input_len = input_len_range.sample(&mut rng); + + for _ in 0..input_len { + orig_data.push(rng.gen()); + } + + let engine = E::random(&mut rng); + engine.encode_string(&orig_data, &mut encoded_data); + assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); + + decode_buf.resize(input_len, 0); + + // decode into the non-empty buf + let decode_bytes_written = engine + .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) + .unwrap(); + + assert_eq!(orig_data.len(), decode_bytes_written); + assert_eq!(orig_data, decode_buf); + } +} + +#[apply(all_engines)] +fn decode_length_estimate_delta<E: EngineWrapper>(engine_wrapper: E) { + for engine in [E::standard(), E::standard_unpadded()] { + for &padding in &[true, false] { + for orig_len in 0..1000 { + let encoded_len = encoded_len(orig_len, padding).unwrap(); + + let decoded_estimate = engine + .internal_decoded_len_estimate(encoded_len) + .decoded_len_estimate(); + assert!(decoded_estimate >= orig_len); + assert!( + decoded_estimate - orig_len < 3, + "estimate: {}, encoded: {}, orig: {}", + decoded_estimate, + encoded_len, + orig_len + ); + } + } + } +} + +/// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. +/// +/// Vecs provided should be empty. +fn generate_random_encoded_data<E: Engine, R: rand::Rng, D: distributions::Distribution<usize>>( + engine: &E, + orig_data: &mut Vec<u8>, + encode_buf: &mut Vec<u8>, + rng: &mut R, + length_distribution: &D, +) -> (usize, usize, usize) { + let padding: bool = engine.config().encode_padding(); + + let orig_len = fill_rand(orig_data, rng, length_distribution); + let expected_encoded_len = encoded_len(orig_len, padding).unwrap(); + encode_buf.resize(expected_encoded_len, 0); + + let base_encoded_len = engine.internal_encode(&orig_data[..], &mut encode_buf[..]); + + let enc_len_with_padding = if padding { + base_encoded_len + add_padding(orig_len, &mut encode_buf[base_encoded_len..]) + } else { + base_encoded_len + }; + + assert_eq!(expected_encoded_len, enc_len_with_padding); + + (orig_len, base_encoded_len, enc_len_with_padding) +} + +// fill to a random length +fn fill_rand<R: rand::Rng, D: distributions::Distribution<usize>>( + vec: &mut Vec<u8>, + rng: &mut R, + length_distribution: &D, +) -> usize { + let len = length_distribution.sample(rng); + for _ in 0..len { + vec.push(rng.gen()); + } + + len +} + +fn fill_rand_len<R: rand::Rng>(vec: &mut Vec<u8>, rng: &mut R, len: usize) { + for _ in 0..len { + vec.push(rng.gen()); + } +} + +fn prefixed_data<'i, 'd>( + input_with_prefix: &'i mut String, + prefix_len: usize, + data: &'d str, +) -> &'i str { + input_with_prefix.truncate(prefix_len); + input_with_prefix.push_str(data); + input_with_prefix.as_str() +} + +/// A wrapper to make using engines in rstest fixtures easier. +/// The functions don't need to be instance methods, but rstest does seem +/// to want an instance, so instances are passed to test functions and then ignored. +trait EngineWrapper { + type Engine: Engine; + + /// Return an engine configured for RFC standard base64 + fn standard() -> Self::Engine; + + /// Return an engine configured for RFC standard base64, except with no padding appended on + /// encode, and required no padding on decode. + fn standard_unpadded() -> Self::Engine; + + /// Return an engine configured for RFC standard alphabet with the provided encode and decode + /// pad settings + fn standard_with_pad_mode(encode_pad: bool, decode_pad_mode: DecodePaddingMode) + -> Self::Engine; + + /// Return an engine configured for RFC standard base64 that allows invalid trailing bits + fn standard_allow_trailing_bits() -> Self::Engine; + + /// Return an engine configured with a randomized alphabet and config + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine; + + /// Return an engine configured with the specified alphabet and randomized config + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine; +} + +struct GeneralPurposeWrapper {} + +impl EngineWrapper for GeneralPurposeWrapper { + type Engine = general_purpose::GeneralPurpose; + + fn standard() -> Self::Engine { + general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::PAD) + } + + fn standard_unpadded() -> Self::Engine { + general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::NO_PAD) + } + + fn standard_with_pad_mode( + encode_pad: bool, + decode_pad_mode: DecodePaddingMode, + ) -> Self::Engine { + general_purpose::GeneralPurpose::new( + &STANDARD, + general_purpose::GeneralPurposeConfig::new() + .with_encode_padding(encode_pad) + .with_decode_padding_mode(decode_pad_mode), + ) + } + + fn standard_allow_trailing_bits() -> Self::Engine { + general_purpose::GeneralPurpose::new( + &STANDARD, + general_purpose::GeneralPurposeConfig::new().with_decode_allow_trailing_bits(true), + ) + } + + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine { + let alphabet = random_alphabet(rng); + + Self::random_alphabet(rng, alphabet) + } + + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine { + general_purpose::GeneralPurpose::new(alphabet, random_config(rng)) + } +} + +struct NaiveWrapper {} + +impl EngineWrapper for NaiveWrapper { + type Engine = naive::Naive; + + fn standard() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: true, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + }, + ) + } + + fn standard_unpadded() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: false, + decode_allow_trailing_bits: false, + decode_padding_mode: DecodePaddingMode::RequireNone, + }, + ) + } + + fn standard_with_pad_mode( + encode_pad: bool, + decode_pad_mode: DecodePaddingMode, + ) -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: false, + decode_allow_trailing_bits: false, + decode_padding_mode: decode_pad_mode, + }, + ) + } + + fn standard_allow_trailing_bits() -> Self::Engine { + naive::Naive::new( + &STANDARD, + naive::NaiveConfig { + encode_padding: true, + decode_allow_trailing_bits: true, + decode_padding_mode: DecodePaddingMode::RequireCanonical, + }, + ) + } + + fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine { + let alphabet = random_alphabet(rng); + + Self::random_alphabet(rng, alphabet) + } + + fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine { + let mode = rng.gen(); + + let config = naive::NaiveConfig { + encode_padding: match mode { + DecodePaddingMode::Indifferent => rng.gen(), + DecodePaddingMode::RequireCanonical => true, + DecodePaddingMode::RequireNone => false, + }, + decode_allow_trailing_bits: rng.gen(), + decode_padding_mode: mode, + }; + + naive::Naive::new(alphabet, config) + } +} + +fn seeded_rng() -> impl rand::Rng { + rngs::SmallRng::from_entropy() +} + +fn all_pad_modes() -> Vec<DecodePaddingMode> { + vec![ + DecodePaddingMode::Indifferent, + DecodePaddingMode::RequireCanonical, + DecodePaddingMode::RequireNone, + ] +} + +fn assert_all_suffixes_ok<E: Engine>(engine: E, suffixes: Vec<&str>) { + for num_prefix_quads in 0..256 { + for &suffix in suffixes.iter() { + let mut encoded = "AAAA".repeat(num_prefix_quads); + encoded.push_str(suffix); + + let res = &engine.decode(&encoded); + assert!(res.is_ok()); + } + } +} diff --git a/third_party/rust/base64/src/lib.rs b/third_party/rust/base64/src/lib.rs new file mode 100644 index 0000000000..cc9d628df6 --- /dev/null +++ b/third_party/rust/base64/src/lib.rs @@ -0,0 +1,179 @@ +//! # Getting started +//! +//! 1. Perhaps one of the preconfigured engines in [engine::general_purpose] will suit, e.g. +//! [engine::general_purpose::STANDARD_NO_PAD]. +//! - These are re-exported in [prelude] with a `BASE64_` prefix for those who prefer to +//! `use base64::prelude::*` or equivalent, e.g. [prelude::BASE64_STANDARD_NO_PAD] +//! 1. If not, choose which alphabet you want. Most usage will want [alphabet::STANDARD] or [alphabet::URL_SAFE]. +//! 1. Choose which [Engine] implementation you want. For the moment there is only one: [engine::GeneralPurpose]. +//! 1. Configure the engine appropriately using the engine's `Config` type. +//! - This is where you'll select whether to add padding (when encoding) or expect it (when +//! decoding). If given the choice, prefer no padding. +//! 1. Build the engine using the selected alphabet and config. +//! +//! For more detail, see below. +//! +//! ## Alphabets +//! +//! An [alphabet::Alphabet] defines what ASCII symbols are used to encode to or decode from. +//! +//! Constants in [alphabet] like [alphabet::STANDARD] or [alphabet::URL_SAFE] provide commonly used +//! alphabets, but you can also build your own custom [alphabet::Alphabet] if needed. +//! +//! ## Engines +//! +//! Once you have an `Alphabet`, you can pick which `Engine` you want. A few parts of the public +//! API provide a default, but otherwise the user must provide an `Engine` to use. +//! +//! See [Engine] for more. +//! +//! ## Config +//! +//! In addition to an `Alphabet`, constructing an `Engine` also requires an [engine::Config]. Each +//! `Engine` has a corresponding `Config` implementation since different `Engine`s may offer different +//! levels of configurability. +//! +//! # Encoding +//! +//! Several different encoding methods on [Engine] are available to you depending on your desire for +//! convenience vs performance. +//! +//! | Method | Output | Allocates | +//! | ------------------------ | ---------------------------- | ------------------------------ | +//! | [Engine::encode] | Returns a new `String` | Always | +//! | [Engine::encode_string] | Appends to provided `String` | Only if `String` needs to grow | +//! | [Engine::encode_slice] | Writes to provided `&[u8]` | Never - fastest | +//! +//! All of the encoding methods will pad as per the engine's config. +//! +//! # Decoding +//! +//! Just as for encoding, there are different decoding methods available. +//! +//! | Method | Output | Allocates | +//! | ------------------------ | ----------------------------- | ------------------------------ | +//! | [Engine::decode] | Returns a new `Vec<u8>` | Always | +//! | [Engine::decode_vec] | Appends to provided `Vec<u8>` | Only if `Vec` needs to grow | +//! | [Engine::decode_slice] | Writes to provided `&[u8]` | Never - fastest | +//! +//! Unlike encoding, where all possible input is valid, decoding can fail (see [DecodeError]). +//! +//! Input can be invalid because it has invalid characters or invalid padding. The nature of how +//! padding is checked depends on the engine's config. +//! Whitespace in the input is invalid, just like any other non-base64 byte. +//! +//! # `Read` and `Write` +//! +//! To decode a [std::io::Read] of b64 bytes, wrap a reader (file, network socket, etc) with +//! [read::DecoderReader]. +//! +//! To write raw bytes and have them b64 encoded on the fly, wrap a [std::io::Write] with +//! [write::EncoderWriter]. +//! +//! There is some performance overhead (15% or so) because of the necessary buffer shuffling -- +//! still fast enough that almost nobody cares. Also, these implementations do not heap allocate. +//! +//! # `Display` +//! +//! See [display] for how to transparently base64 data via a `Display` implementation. +//! +//! # Examples +//! +//! ## Using predefined engines +//! +//! ``` +//! use base64::{Engine as _, engine::general_purpose}; +//! +//! let orig = b"data"; +//! let encoded: String = general_purpose::STANDARD_NO_PAD.encode(orig); +//! assert_eq!("ZGF0YQ", encoded); +//! assert_eq!(orig.as_slice(), &general_purpose::STANDARD_NO_PAD.decode(encoded).unwrap()); +//! +//! // or, URL-safe +//! let encoded_url = general_purpose::URL_SAFE_NO_PAD.encode(orig); +//! ``` +//! +//! ## Custom alphabet, config, and engine +//! +//! ``` +//! use base64::{engine, alphabet, Engine as _}; +//! +//! // bizarro-world base64: +/ as the first symbols instead of the last +//! let alphabet = +//! alphabet::Alphabet::new("+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") +//! .unwrap(); +//! +//! // a very weird config that encodes with padding but requires no padding when decoding...? +//! let crazy_config = engine::GeneralPurposeConfig::new() +//! .with_decode_allow_trailing_bits(true) +//! .with_encode_padding(true) +//! .with_decode_padding_mode(engine::DecodePaddingMode::RequireNone); +//! +//! let crazy_engine = engine::GeneralPurpose::new(&alphabet, crazy_config); +//! +//! let encoded = crazy_engine.encode(b"abc 123"); +//! +//! ``` +//! +//! # Panics +//! +//! If length calculations result in overflowing `usize`, a panic will result. + +#![cfg_attr(feature = "cargo-clippy", allow(clippy::cast_lossless))] +#![deny( + missing_docs, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + unused_results, + variant_size_differences, + warnings +)] +#![forbid(unsafe_code)] +// Allow globally until https://github.com/rust-lang/rust-clippy/issues/8768 is resolved. +// The desired state is to allow it only for the rstest_reuse import. +#![allow(clippy::single_component_path_imports)] +#![cfg_attr(not(any(feature = "std", test)), no_std)] + +#[cfg(all(feature = "alloc", not(any(feature = "std", test))))] +extern crate alloc; +#[cfg(any(feature = "std", test))] +extern crate std as alloc; + +// has to be included at top level because of the way rstest_reuse defines its macros +#[cfg(test)] +use rstest_reuse; + +mod chunked_encoder; +pub mod display; +#[cfg(any(feature = "std", test))] +pub mod read; +#[cfg(any(feature = "std", test))] +pub mod write; + +pub mod engine; +pub use engine::Engine; + +pub mod alphabet; + +mod encode; +#[allow(deprecated)] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub use crate::encode::{encode, encode_engine, encode_engine_string}; +#[allow(deprecated)] +pub use crate::encode::{encode_engine_slice, encoded_len, EncodeSliceError}; + +mod decode; +#[allow(deprecated)] +#[cfg(any(feature = "alloc", feature = "std", test))] +pub use crate::decode::{decode, decode_engine, decode_engine_vec}; +#[allow(deprecated)] +pub use crate::decode::{decode_engine_slice, decoded_len_estimate, DecodeError, DecodeSliceError}; + +pub mod prelude; + +#[cfg(test)] +mod tests; + +const PAD_BYTE: u8 = b'='; diff --git a/third_party/rust/base64/src/prelude.rs b/third_party/rust/base64/src/prelude.rs new file mode 100644 index 0000000000..fbeb5babc7 --- /dev/null +++ b/third_party/rust/base64/src/prelude.rs @@ -0,0 +1,19 @@ +//! Preconfigured engines for common use cases. +//! +//! These are re-exports of `const` engines in [crate::engine::general_purpose], renamed with a `BASE64_` +//! prefix for those who prefer to `use` the entire path to a name. +//! +//! # Examples +//! +//! ``` +//! use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; +//! +//! assert_eq!("c29tZSBieXRlcw", &BASE64_STANDARD_NO_PAD.encode(b"some bytes")); +//! ``` + +pub use crate::engine::Engine; + +pub use crate::engine::general_purpose::STANDARD as BASE64_STANDARD; +pub use crate::engine::general_purpose::STANDARD_NO_PAD as BASE64_STANDARD_NO_PAD; +pub use crate::engine::general_purpose::URL_SAFE as BASE64_URL_SAFE; +pub use crate::engine::general_purpose::URL_SAFE_NO_PAD as BASE64_URL_SAFE_NO_PAD; diff --git a/third_party/rust/base64/src/read/decoder.rs b/third_party/rust/base64/src/read/decoder.rs new file mode 100644 index 0000000000..4888c9c4e7 --- /dev/null +++ b/third_party/rust/base64/src/read/decoder.rs @@ -0,0 +1,295 @@ +use crate::{engine::Engine, DecodeError}; +use std::{cmp, fmt, io}; + +// This should be large, but it has to fit on the stack. +pub(crate) const BUF_SIZE: usize = 1024; + +// 4 bytes of base64 data encode 3 bytes of raw data (modulo padding). +const BASE64_CHUNK_SIZE: usize = 4; +const DECODED_CHUNK_SIZE: usize = 3; + +/// A `Read` implementation that decodes base64 data read from an underlying reader. +/// +/// # Examples +/// +/// ``` +/// use std::io::Read; +/// use std::io::Cursor; +/// use base64::engine::general_purpose; +/// +/// // use a cursor as the simplest possible `Read` -- in real code this is probably a file, etc. +/// let mut wrapped_reader = Cursor::new(b"YXNkZg=="); +/// let mut decoder = base64::read::DecoderReader::new( +/// &mut wrapped_reader, +/// &general_purpose::STANDARD); +/// +/// // handle errors as you normally would +/// let mut result = Vec::new(); +/// decoder.read_to_end(&mut result).unwrap(); +/// +/// assert_eq!(b"asdf", &result[..]); +/// +/// ``` +pub struct DecoderReader<'e, E: Engine, R: io::Read> { + engine: &'e E, + /// Where b64 data is read from + inner: R, + + // Holds b64 data read from the delegate reader. + b64_buffer: [u8; BUF_SIZE], + // The start of the pending buffered data in b64_buffer. + b64_offset: usize, + // The amount of buffered b64 data. + b64_len: usize, + // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a + // decoded chunk in to, we have to be able to hang on to a few decoded bytes. + // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to + // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest + // into here, which seems like a lot of complexity for 1 extra byte of storage. + decoded_buffer: [u8; 3], + // index of start of decoded data + decoded_offset: usize, + // length of decoded data + decoded_len: usize, + // used to provide accurate offsets in errors + total_b64_decoded: usize, +} + +impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("DecoderReader") + .field("b64_offset", &self.b64_offset) + .field("b64_len", &self.b64_len) + .field("decoded_buffer", &self.decoded_buffer) + .field("decoded_offset", &self.decoded_offset) + .field("decoded_len", &self.decoded_len) + .field("total_b64_decoded", &self.total_b64_decoded) + .finish() + } +} + +impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { + /// Create a new decoder that will read from the provided reader `r`. + pub fn new(reader: R, engine: &'e E) -> Self { + DecoderReader { + engine, + inner: reader, + b64_buffer: [0; BUF_SIZE], + b64_offset: 0, + b64_len: 0, + decoded_buffer: [0; DECODED_CHUNK_SIZE], + decoded_offset: 0, + decoded_len: 0, + total_b64_decoded: 0, + } + } + + /// Write as much as possible of the decoded buffer into the target buffer. + /// Must only be called when there is something to write and space to write into. + /// Returns a Result with the number of (decoded) bytes copied. + fn flush_decoded_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> { + debug_assert!(self.decoded_len > 0); + debug_assert!(!buf.is_empty()); + + let copy_len = cmp::min(self.decoded_len, buf.len()); + debug_assert!(copy_len > 0); + debug_assert!(copy_len <= self.decoded_len); + + buf[..copy_len].copy_from_slice( + &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len], + ); + + self.decoded_offset += copy_len; + self.decoded_len -= copy_len; + + debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE); + + Ok(copy_len) + } + + /// Read into the remaining space in the buffer after the current contents. + /// Must only be called when there is space to read into in the buffer. + /// Returns the number of bytes read. + fn read_from_delegate(&mut self) -> io::Result<usize> { + debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE); + + let read = self + .inner + .read(&mut self.b64_buffer[self.b64_offset + self.b64_len..])?; + self.b64_len += read; + + debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); + + Ok(read) + } + + /// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the + /// caller's responsibility to choose the number of b64 bytes to decode correctly. + /// + /// Returns a Result with the number of decoded bytes written to `buf`. + fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> { + debug_assert!(self.b64_len >= num_bytes); + debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); + debug_assert!(!buf.is_empty()); + + let decoded = self + .engine + .internal_decode( + &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes], + buf, + self.engine.internal_decoded_len_estimate(num_bytes), + ) + .map_err(|e| match e { + DecodeError::InvalidByte(offset, byte) => { + DecodeError::InvalidByte(self.total_b64_decoded + offset, byte) + } + DecodeError::InvalidLength => DecodeError::InvalidLength, + DecodeError::InvalidLastSymbol(offset, byte) => { + DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte) + } + DecodeError::InvalidPadding => DecodeError::InvalidPadding, + }) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + self.total_b64_decoded += num_bytes; + self.b64_offset += num_bytes; + self.b64_len -= num_bytes; + + debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); + + Ok(decoded) + } + + /// Unwraps this `DecoderReader`, returning the base reader which it reads base64 encoded + /// input from. + /// + /// Because `DecoderReader` performs internal buffering, the state of the inner reader is + /// unspecified. This function is mainly provided because the inner reader type may provide + /// additional functionality beyond the `Read` implementation which may still be useful. + pub fn into_inner(self) -> R { + self.inner + } +} + +impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> { + /// Decode input from the wrapped reader. + /// + /// Under non-error circumstances, this returns `Ok` with the value being the number of bytes + /// written in `buf`. + /// + /// Where possible, this function buffers base64 to minimize the number of read() calls to the + /// delegate reader. + /// + /// # Errors + /// + /// Any errors emitted by the delegate reader are returned. Decoding errors due to invalid + /// base64 are also possible, and will have `io::ErrorKind::InvalidData`. + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + if buf.is_empty() { + return Ok(0); + } + + // offset == BUF_SIZE when we copied it all last time + debug_assert!(self.b64_offset <= BUF_SIZE); + debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); + debug_assert!(if self.b64_offset == BUF_SIZE { + self.b64_len == 0 + } else { + self.b64_len <= BUF_SIZE + }); + + debug_assert!(if self.decoded_len == 0 { + // can be = when we were able to copy the complete chunk + self.decoded_offset <= DECODED_CHUNK_SIZE + } else { + self.decoded_offset < DECODED_CHUNK_SIZE + }); + + // We shouldn't ever decode into here when we can't immediately write at least one byte into + // the provided buf, so the effective length should only be 3 momentarily between when we + // decode and when we copy into the target buffer. + debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE); + debug_assert!(self.decoded_len + self.decoded_offset <= DECODED_CHUNK_SIZE); + + if self.decoded_len > 0 { + // we have a few leftover decoded bytes; flush that rather than pull in more b64 + self.flush_decoded_buf(buf) + } else { + let mut at_eof = false; + while self.b64_len < BASE64_CHUNK_SIZE { + // Work around lack of copy_within, which is only present in 1.37 + // Copy any bytes we have to the start of the buffer. + // We know we have < 1 chunk, so we can use a tiny tmp buffer. + let mut memmove_buf = [0_u8; BASE64_CHUNK_SIZE]; + memmove_buf[..self.b64_len].copy_from_slice( + &self.b64_buffer[self.b64_offset..self.b64_offset + self.b64_len], + ); + self.b64_buffer[0..self.b64_len].copy_from_slice(&memmove_buf[..self.b64_len]); + self.b64_offset = 0; + + // then fill in more data + let read = self.read_from_delegate()?; + if read == 0 { + // we never pass in an empty buf, so 0 => we've hit EOF + at_eof = true; + break; + } + } + + if self.b64_len == 0 { + debug_assert!(at_eof); + // we must be at EOF, and we have no data left to decode + return Ok(0); + }; + + debug_assert!(if at_eof { + // if we are at eof, we may not have a complete chunk + self.b64_len > 0 + } else { + // otherwise, we must have at least one chunk + self.b64_len >= BASE64_CHUNK_SIZE + }); + + debug_assert_eq!(0, self.decoded_len); + + if buf.len() < DECODED_CHUNK_SIZE { + // caller requested an annoyingly short read + // have to write to a tmp buf first to avoid double mutable borrow + let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE]; + // if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have + // to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64 + // tokens, not 1, since 1 token can't decode to 1 byte). + let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE); + + let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?; + self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); + + self.decoded_offset = 0; + self.decoded_len = decoded; + + // can be less than 3 on last block due to padding + debug_assert!(decoded <= 3); + + self.flush_decoded_buf(buf) + } else { + let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE) + .checked_mul(BASE64_CHUNK_SIZE) + .expect("too many chunks"); + debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE); + + let b64_bytes_available_to_decode = if at_eof { + self.b64_len + } else { + // only use complete chunks + self.b64_len - self.b64_len % 4 + }; + + let actual_decode_len = cmp::min( + b64_bytes_that_can_decode_into_buf, + b64_bytes_available_to_decode, + ); + self.decode_to_buf(actual_decode_len, buf) + } + } + } +} diff --git a/third_party/rust/base64/src/read/decoder_tests.rs b/third_party/rust/base64/src/read/decoder_tests.rs new file mode 100644 index 0000000000..65d58d8e3f --- /dev/null +++ b/third_party/rust/base64/src/read/decoder_tests.rs @@ -0,0 +1,346 @@ +use std::{ + cmp, + io::{self, Read as _}, + iter, +}; + +use rand::{Rng as _, RngCore as _}; + +use super::decoder::{DecoderReader, BUF_SIZE}; +use crate::{ + engine::{general_purpose::STANDARD, Engine, GeneralPurpose}, + tests::{random_alphabet, random_config, random_engine}, + DecodeError, +}; + +#[test] +fn simple() { + let tests: &[(&[u8], &[u8])] = &[ + (&b"0"[..], &b"MA=="[..]), + (b"01", b"MDE="), + (b"012", b"MDEy"), + (b"0123", b"MDEyMw=="), + (b"01234", b"MDEyMzQ="), + (b"012345", b"MDEyMzQ1"), + (b"0123456", b"MDEyMzQ1Ng=="), + (b"01234567", b"MDEyMzQ1Njc="), + (b"012345678", b"MDEyMzQ1Njc4"), + (b"0123456789", b"MDEyMzQ1Njc4OQ=="), + ][..]; + + for (text_expected, base64data) in tests.iter() { + // Read n bytes at a time. + for n in 1..base64data.len() + 1 { + let mut wrapped_reader = io::Cursor::new(base64data); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD); + + // handle errors as you normally would + let mut text_got = Vec::new(); + let mut buffer = vec![0u8; n]; + while let Ok(read) = decoder.read(&mut buffer[..]) { + if read == 0 { + break; + } + text_got.extend_from_slice(&buffer[..read]); + } + + assert_eq!( + text_got, + *text_expected, + "\nGot: {}\nExpected: {}", + String::from_utf8_lossy(&text_got[..]), + String::from_utf8_lossy(text_expected) + ); + } + } +} + +// Make sure we error out on trailing junk. +#[test] +fn trailing_junk() { + let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..]; + + for base64data in tests.iter() { + // Read n bytes at a time. + for n in 1..base64data.len() + 1 { + let mut wrapped_reader = io::Cursor::new(base64data); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD); + + // handle errors as you normally would + let mut buffer = vec![0u8; n]; + let mut saw_error = false; + loop { + match decoder.read(&mut buffer[..]) { + Err(_) => { + saw_error = true; + break; + } + Ok(read) if read == 0 => break, + Ok(_) => (), + } + } + + assert!(saw_error); + } + } +} + +#[test] +fn handles_short_read_from_delegate() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + bytes.truncate(size); + rng.fill_bytes(&mut bytes[..size]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + engine.encode_string(&bytes[..], &mut b64); + + let mut wrapped_reader = io::Cursor::new(b64.as_bytes()); + let mut short_reader = RandomShortRead { + delegate: &mut wrapped_reader, + rng: &mut rng, + }; + + let mut decoder = DecoderReader::new(&mut short_reader, &engine); + + let decoded_len = decoder.read_to_end(&mut decoded).unwrap(); + assert_eq!(size, decoded_len); + assert_eq!(&bytes[..], &decoded[..]); + } +} + +#[test] +fn read_in_short_increments() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + // leave room to play around with larger buffers + decoded.extend(iter::repeat(0).take(size * 3)); + + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + + let mut wrapped_reader = io::Cursor::new(&b64[..]); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut decoder); + } +} + +#[test] +fn read_in_short_increments_with_short_delegate_reads() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + // leave room to play around with larger buffers + decoded.extend(iter::repeat(0).take(size * 3)); + + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + + let mut base_reader = io::Cursor::new(&b64[..]); + let mut decoder = DecoderReader::new(&mut base_reader, &engine); + let mut short_reader = RandomShortRead { + delegate: &mut decoder, + rng: &mut rand::thread_rng(), + }; + + consume_with_short_reads_and_validate( + &mut rng, + &bytes[..], + &mut decoded, + &mut short_reader, + ); + } +} + +#[test] +fn reports_invalid_last_symbol_correctly() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut b64_bytes = Vec::new(); + let mut decoded = Vec::new(); + let mut bulk_decoded = Vec::new(); + + for _ in 0..1_000 { + bytes.clear(); + b64.clear(); + b64_bytes.clear(); + + let size = rng.gen_range(1..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + decoded.extend(iter::repeat(0).take(size)); + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let config = random_config(&mut rng); + let alphabet = random_alphabet(&mut rng); + // changing padding will cause invalid padding errors when we twiddle the last byte + let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false)); + engine.encode_string(&bytes[..], &mut b64); + b64_bytes.extend(b64.bytes()); + assert_eq!(b64_bytes.len(), b64.len()); + + // change the last character to every possible symbol. Should behave the same as bulk + // decoding whether invalid or valid. + for &s1 in alphabet.symbols.iter() { + decoded.clear(); + bulk_decoded.clear(); + + // replace the last + *b64_bytes.last_mut().unwrap() = s1; + let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded); + + let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| { + e.into_inner() + .and_then(|e| e.downcast::<DecodeError>().ok()) + }); + + assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res); + } + } +} + +#[test] +fn reports_invalid_byte_correctly() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(1..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + rng.fill_bytes(&mut bytes[..size]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + // replace one byte, somewhere, with '*', which is invalid + let bad_byte_pos = rng.gen_range(0..b64.len()); + let mut b64_bytes = b64.bytes().collect::<Vec<u8>>(); + b64_bytes[bad_byte_pos] = b'*'; + + let mut wrapped_reader = io::Cursor::new(b64_bytes.clone()); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + // some gymnastics to avoid double-moving the io::Error, which is not Copy + let read_decode_err = decoder + .read_to_end(&mut decoded) + .map_err(|e| { + let kind = e.kind(); + let inner = e + .into_inner() + .and_then(|e| e.downcast::<DecodeError>().ok()); + inner.map(|i| (*i, kind)) + }) + .err() + .and_then(|o| o); + + let mut bulk_buf = Vec::new(); + let bulk_decode_err = engine.decode_vec(&b64_bytes[..], &mut bulk_buf).err(); + + // it's tricky to predict where the invalid data's offset will be since if it's in the last + // chunk it will be reported at the first padding location because it's treated as invalid + // padding. So, we just check that it's the same as it is for decoding all at once. + assert_eq!( + bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)), + read_decode_err + ); + } +} + +fn consume_with_short_reads_and_validate<R: io::Read>( + rng: &mut rand::rngs::ThreadRng, + expected_bytes: &[u8], + decoded: &mut [u8], + short_reader: &mut R, +) { + let mut total_read = 0_usize; + loop { + assert!( + total_read <= expected_bytes.len(), + "tr {} size {}", + total_read, + expected_bytes.len() + ); + if total_read == expected_bytes.len() { + assert_eq!(expected_bytes, &decoded[..total_read]); + // should be done + assert_eq!(0, short_reader.read(&mut *decoded).unwrap()); + // didn't write anything + assert_eq!(expected_bytes, &decoded[..total_read]); + + break; + } + let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2)); + + let read = short_reader + .read(&mut decoded[total_read..total_read + decode_len]) + .unwrap(); + total_read += read; + } +} + +/// Limits how many bytes a reader will provide in each read call. +/// Useful for shaking out code that may work fine only with typical input sources that always fill +/// the buffer. +struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> { + delegate: &'b mut R, + rng: &'a mut N, +} + +impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> { + fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> { + // avoid 0 since it means EOF for non-empty buffers + let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len()); + + self.delegate.read(&mut buf[..effective_len]) + } +} diff --git a/third_party/rust/base64/src/read/mod.rs b/third_party/rust/base64/src/read/mod.rs new file mode 100644 index 0000000000..856064481c --- /dev/null +++ b/third_party/rust/base64/src/read/mod.rs @@ -0,0 +1,6 @@ +//! Implementations of `io::Read` to transparently decode base64. +mod decoder; +pub use self::decoder::DecoderReader; + +#[cfg(test)] +mod decoder_tests; diff --git a/third_party/rust/base64/src/tests.rs b/third_party/rust/base64/src/tests.rs new file mode 100644 index 0000000000..7083b5433f --- /dev/null +++ b/third_party/rust/base64/src/tests.rs @@ -0,0 +1,117 @@ +use std::str; + +use rand::{ + distributions, + distributions::{Distribution as _, Uniform}, + seq::SliceRandom, + Rng, SeedableRng, +}; + +use crate::{ + alphabet, + encode::encoded_len, + engine::{ + general_purpose::{GeneralPurpose, GeneralPurposeConfig}, + Config, DecodePaddingMode, Engine, + }, +}; + +#[test] +fn roundtrip_random_config_short() { + // exercise the slower encode/decode routines that operate on shorter buffers more vigorously + roundtrip_random_config(Uniform::new(0, 50), 10_000); +} + +#[test] +fn roundtrip_random_config_long() { + roundtrip_random_config(Uniform::new(0, 1000), 10_000); +} + +pub fn assert_encode_sanity(encoded: &str, padded: bool, input_len: usize) { + let input_rem = input_len % 3; + let expected_padding_len = if input_rem > 0 { + if padded { + 3 - input_rem + } else { + 0 + } + } else { + 0 + }; + + let expected_encoded_len = encoded_len(input_len, padded).unwrap(); + + assert_eq!(expected_encoded_len, encoded.len()); + + let padding_len = encoded.chars().filter(|&c| c == '=').count(); + + assert_eq!(expected_padding_len, padding_len); + + let _ = str::from_utf8(encoded.as_bytes()).expect("Base64 should be valid utf8"); +} + +fn roundtrip_random_config(input_len_range: Uniform<usize>, iterations: u32) { + let mut input_buf: Vec<u8> = Vec::new(); + let mut encoded_buf = String::new(); + let mut rng = rand::rngs::SmallRng::from_entropy(); + + for _ in 0..iterations { + input_buf.clear(); + encoded_buf.clear(); + + let input_len = input_len_range.sample(&mut rng); + + let engine = random_engine(&mut rng); + + for _ in 0..input_len { + input_buf.push(rng.gen()); + } + + engine.encode_string(&input_buf, &mut encoded_buf); + + assert_encode_sanity(&encoded_buf, engine.config().encode_padding(), input_len); + + assert_eq!(input_buf, engine.decode(&encoded_buf).unwrap()); + } +} + +pub fn random_config<R: Rng>(rng: &mut R) -> GeneralPurposeConfig { + let mode = rng.gen(); + GeneralPurposeConfig::new() + .with_encode_padding(match mode { + DecodePaddingMode::Indifferent => rng.gen(), + DecodePaddingMode::RequireCanonical => true, + DecodePaddingMode::RequireNone => false, + }) + .with_decode_padding_mode(mode) + .with_decode_allow_trailing_bits(rng.gen()) +} + +impl distributions::Distribution<DecodePaddingMode> for distributions::Standard { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> DecodePaddingMode { + match rng.gen_range(0..=2) { + 0 => DecodePaddingMode::Indifferent, + 1 => DecodePaddingMode::RequireCanonical, + _ => DecodePaddingMode::RequireNone, + } + } +} + +pub fn random_alphabet<R: Rng>(rng: &mut R) -> &'static alphabet::Alphabet { + ALPHABETS.choose(rng).unwrap() +} + +pub fn random_engine<R: Rng>(rng: &mut R) -> GeneralPurpose { + let alphabet = random_alphabet(rng); + let config = random_config(rng); + GeneralPurpose::new(alphabet, config) +} + +const ALPHABETS: &[alphabet::Alphabet] = &[ + alphabet::URL_SAFE, + alphabet::STANDARD, + alphabet::CRYPT, + alphabet::BCRYPT, + alphabet::IMAP_MUTF7, + alphabet::BIN_HEX, +]; diff --git a/third_party/rust/base64/src/write/encoder.rs b/third_party/rust/base64/src/write/encoder.rs new file mode 100644 index 0000000000..1c19bb42ab --- /dev/null +++ b/third_party/rust/base64/src/write/encoder.rs @@ -0,0 +1,407 @@ +use crate::engine::Engine; +use std::{ + cmp, fmt, io, + io::{ErrorKind, Result}, +}; + +pub(crate) const BUF_SIZE: usize = 1024; +/// The most bytes whose encoding will fit in `BUF_SIZE` +const MAX_INPUT_LEN: usize = BUF_SIZE / 4 * 3; +// 3 bytes of input = 4 bytes of base64, always (because we don't allow line wrapping) +const MIN_ENCODE_CHUNK_SIZE: usize = 3; + +/// A `Write` implementation that base64 encodes data before delegating to the wrapped writer. +/// +/// Because base64 has special handling for the end of the input data (padding, etc), there's a +/// `finish()` method on this type that encodes any leftover input bytes and adds padding if +/// appropriate. It's called automatically when deallocated (see the `Drop` implementation), but +/// any error that occurs when invoking the underlying writer will be suppressed. If you want to +/// handle such errors, call `finish()` yourself. +/// +/// # Examples +/// +/// ``` +/// use std::io::Write; +/// use base64::engine::general_purpose; +/// +/// // use a vec as the simplest possible `Write` -- in real code this is probably a file, etc. +/// let mut enc = base64::write::EncoderWriter::new(Vec::new(), &general_purpose::STANDARD); +/// +/// // handle errors as you normally would +/// enc.write_all(b"asdf").unwrap(); +/// +/// // could leave this out to be called by Drop, if you don't care +/// // about handling errors or getting the delegate writer back +/// let delegate = enc.finish().unwrap(); +/// +/// // base64 was written to the writer +/// assert_eq!(b"YXNkZg==", &delegate[..]); +/// +/// ``` +/// +/// # Panics +/// +/// Calling `write()` (or related methods) or `finish()` after `finish()` has completed without +/// error is invalid and will panic. +/// +/// # Errors +/// +/// Base64 encoding itself does not generate errors, but errors from the wrapped writer will be +/// returned as per the contract of `Write`. +/// +/// # Performance +/// +/// It has some minor performance loss compared to encoding slices (a couple percent). +/// It does not do any heap allocation. +/// +/// # Limitations +/// +/// Owing to the specification of the `write` and `flush` methods on the `Write` trait and their +/// implications for a buffering implementation, these methods may not behave as expected. In +/// particular, calling `write_all` on this interface may fail with `io::ErrorKind::WriteZero`. +/// See the documentation of the `Write` trait implementation for further details. +pub struct EncoderWriter<'e, E: Engine, W: io::Write> { + engine: &'e E, + /// Where encoded data is written to. It's an Option as it's None immediately before Drop is + /// called so that finish() can return the underlying writer. None implies that finish() has + /// been called successfully. + delegate: Option<W>, + /// Holds a partial chunk, if any, after the last `write()`, so that we may then fill the chunk + /// with the next `write()`, encode it, then proceed with the rest of the input normally. + extra_input: [u8; MIN_ENCODE_CHUNK_SIZE], + /// How much of `extra` is occupied, in `[0, MIN_ENCODE_CHUNK_SIZE]`. + extra_input_occupied_len: usize, + /// Buffer to encode into. May hold leftover encoded bytes from a previous write call that the underlying writer + /// did not write last time. + output: [u8; BUF_SIZE], + /// How much of `output` is occupied with encoded data that couldn't be written last time + output_occupied_len: usize, + /// panic safety: don't write again in destructor if writer panicked while we were writing to it + panicked: bool, +} + +impl<'e, E: Engine, W: io::Write> fmt::Debug for EncoderWriter<'e, E, W> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "extra_input: {:?} extra_input_occupied_len:{:?} output[..5]: {:?} output_occupied_len: {:?}", + self.extra_input, + self.extra_input_occupied_len, + &self.output[0..5], + self.output_occupied_len + ) + } +} + +impl<'e, E: Engine, W: io::Write> EncoderWriter<'e, E, W> { + /// Create a new encoder that will write to the provided delegate writer. + pub fn new(delegate: W, engine: &'e E) -> EncoderWriter<'e, E, W> { + EncoderWriter { + engine, + delegate: Some(delegate), + extra_input: [0u8; MIN_ENCODE_CHUNK_SIZE], + extra_input_occupied_len: 0, + output: [0u8; BUF_SIZE], + output_occupied_len: 0, + panicked: false, + } + } + + /// Encode all remaining buffered data and write it, including any trailing incomplete input + /// triples and associated padding. + /// + /// Once this succeeds, no further writes or calls to this method are allowed. + /// + /// This may write to the delegate writer multiple times if the delegate writer does not accept + /// all input provided to its `write` each invocation. + /// + /// If you don't care about error handling, it is not necessary to call this function, as the + /// equivalent finalization is done by the Drop impl. + /// + /// Returns the writer that this was constructed around. + /// + /// # Errors + /// + /// The first error that is not of `ErrorKind::Interrupted` will be returned. + pub fn finish(&mut self) -> Result<W> { + // If we could consume self in finish(), we wouldn't have to worry about this case, but + // finish() is retryable in the face of I/O errors, so we can't consume here. + if self.delegate.is_none() { + panic!("Encoder has already had finish() called"); + }; + + self.write_final_leftovers()?; + + let writer = self.delegate.take().expect("Writer must be present"); + + Ok(writer) + } + + /// Write any remaining buffered data to the delegate writer. + fn write_final_leftovers(&mut self) -> Result<()> { + if self.delegate.is_none() { + // finish() has already successfully called this, and we are now in drop() with a None + // writer, so just no-op + return Ok(()); + } + + self.write_all_encoded_output()?; + + if self.extra_input_occupied_len > 0 { + let encoded_len = self + .engine + .encode_slice( + &self.extra_input[..self.extra_input_occupied_len], + &mut self.output[..], + ) + .expect("buffer is large enough"); + + self.output_occupied_len = encoded_len; + + self.write_all_encoded_output()?; + + // write succeeded, do not write the encoding of extra again if finish() is retried + self.extra_input_occupied_len = 0; + } + + Ok(()) + } + + /// Write as much of the encoded output to the delegate writer as it will accept, and store the + /// leftovers to be attempted at the next write() call. Updates `self.output_occupied_len`. + /// + /// # Errors + /// + /// Errors from the delegate writer are returned. In the case of an error, + /// `self.output_occupied_len` will not be updated, as errors from `write` are specified to mean + /// that no write took place. + fn write_to_delegate(&mut self, current_output_len: usize) -> Result<()> { + self.panicked = true; + let res = self + .delegate + .as_mut() + .expect("Writer must be present") + .write(&self.output[..current_output_len]); + self.panicked = false; + + res.map(|consumed| { + debug_assert!(consumed <= current_output_len); + + if consumed < current_output_len { + self.output_occupied_len = current_output_len.checked_sub(consumed).unwrap(); + // If we're blocking on I/O, the minor inefficiency of copying bytes to the + // start of the buffer is the least of our concerns... + // TODO Rotate moves more than we need to; copy_within now stable. + self.output.rotate_left(consumed); + } else { + self.output_occupied_len = 0; + } + }) + } + + /// Write all buffered encoded output. If this returns `Ok`, `self.output_occupied_len` is `0`. + /// + /// This is basically write_all for the remaining buffered data but without the undesirable + /// abort-on-`Ok(0)` behavior. + /// + /// # Errors + /// + /// Any error emitted by the delegate writer abort the write loop and is returned, unless it's + /// `Interrupted`, in which case the error is ignored and writes will continue. + fn write_all_encoded_output(&mut self) -> Result<()> { + while self.output_occupied_len > 0 { + let remaining_len = self.output_occupied_len; + match self.write_to_delegate(remaining_len) { + // try again on interrupts ala write_all + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + // other errors return + Err(e) => return Err(e), + // success no-ops because remaining length is already updated + Ok(_) => {} + }; + } + + debug_assert_eq!(0, self.output_occupied_len); + Ok(()) + } + + /// Unwraps this `EncoderWriter`, returning the base writer it writes base64 encoded output + /// to. + /// + /// Normally this method should not be needed, since `finish()` returns the inner writer if + /// it completes successfully. That will also ensure all data has been flushed, which the + /// `into_inner()` function does *not* do. + /// + /// Calling this method after `finish()` has completed successfully will panic, since the + /// writer has already been returned. + /// + /// This method may be useful if the writer implements additional APIs beyond the `Write` + /// trait. Note that the inner writer might be in an error state or have an incomplete + /// base64 string written to it. + pub fn into_inner(mut self) -> W { + self.delegate + .take() + .expect("Encoder has already had finish() called") + } +} + +impl<'e, E: Engine, W: io::Write> io::Write for EncoderWriter<'e, E, W> { + /// Encode input and then write to the delegate writer. + /// + /// Under non-error circumstances, this returns `Ok` with the value being the number of bytes + /// of `input` consumed. The value may be `0`, which interacts poorly with `write_all`, which + /// interprets `Ok(0)` as an error, despite it being allowed by the contract of `write`. See + /// <https://github.com/rust-lang/rust/issues/56889> for more on that. + /// + /// If the previous call to `write` provided more (encoded) data than the delegate writer could + /// accept in a single call to its `write`, the remaining data is buffered. As long as buffered + /// data is present, subsequent calls to `write` will try to write the remaining buffered data + /// to the delegate and return either `Ok(0)` -- and therefore not consume any of `input` -- or + /// an error. + /// + /// # Errors + /// + /// Any errors emitted by the delegate writer are returned. + fn write(&mut self, input: &[u8]) -> Result<usize> { + if self.delegate.is_none() { + panic!("Cannot write more after calling finish()"); + } + + if input.is_empty() { + return Ok(0); + } + + // The contract of `Write::write` places some constraints on this implementation: + // - a call to `write()` represents at most one call to a wrapped `Write`, so we can't + // iterate over the input and encode multiple chunks. + // - Errors mean that "no bytes were written to this writer", so we need to reset the + // internal state to what it was before the error occurred + + // before reading any input, write any leftover encoded output from last time + if self.output_occupied_len > 0 { + let current_len = self.output_occupied_len; + return self + .write_to_delegate(current_len) + // did not read any input + .map(|_| 0); + } + + debug_assert_eq!(0, self.output_occupied_len); + + // how many bytes, if any, were read into `extra` to create a triple to encode + let mut extra_input_read_len = 0; + let mut input = input; + + let orig_extra_len = self.extra_input_occupied_len; + + let mut encoded_size = 0; + // always a multiple of MIN_ENCODE_CHUNK_SIZE + let mut max_input_len = MAX_INPUT_LEN; + + // process leftover un-encoded input from last write + if self.extra_input_occupied_len > 0 { + debug_assert!(self.extra_input_occupied_len < 3); + if input.len() + self.extra_input_occupied_len >= MIN_ENCODE_CHUNK_SIZE { + // Fill up `extra`, encode that into `output`, and consume as much of the rest of + // `input` as possible. + // We could write just the encoding of `extra` by itself but then we'd have to + // return after writing only 4 bytes, which is inefficient if the underlying writer + // would make a syscall. + extra_input_read_len = MIN_ENCODE_CHUNK_SIZE - self.extra_input_occupied_len; + debug_assert!(extra_input_read_len > 0); + // overwrite only bytes that weren't already used. If we need to rollback extra_len + // (when the subsequent write errors), the old leading bytes will still be there. + self.extra_input[self.extra_input_occupied_len..MIN_ENCODE_CHUNK_SIZE] + .copy_from_slice(&input[0..extra_input_read_len]); + + let len = self.engine.internal_encode( + &self.extra_input[0..MIN_ENCODE_CHUNK_SIZE], + &mut self.output[..], + ); + debug_assert_eq!(4, len); + + input = &input[extra_input_read_len..]; + + // consider extra to be used up, since we encoded it + self.extra_input_occupied_len = 0; + // don't clobber where we just encoded to + encoded_size = 4; + // and don't read more than can be encoded + max_input_len = MAX_INPUT_LEN - MIN_ENCODE_CHUNK_SIZE; + + // fall through to normal encoding + } else { + // `extra` and `input` are non empty, but `|extra| + |input| < 3`, so there must be + // 1 byte in each. + debug_assert_eq!(1, input.len()); + debug_assert_eq!(1, self.extra_input_occupied_len); + + self.extra_input[self.extra_input_occupied_len] = input[0]; + self.extra_input_occupied_len += 1; + return Ok(1); + }; + } else if input.len() < MIN_ENCODE_CHUNK_SIZE { + // `extra` is empty, and `input` fits inside it + self.extra_input[0..input.len()].copy_from_slice(input); + self.extra_input_occupied_len = input.len(); + return Ok(input.len()); + }; + + // either 0 or 1 complete chunks encoded from extra + debug_assert!(encoded_size == 0 || encoded_size == 4); + debug_assert!( + // didn't encode extra input + MAX_INPUT_LEN == max_input_len + // encoded one triple + || MAX_INPUT_LEN == max_input_len + MIN_ENCODE_CHUNK_SIZE + ); + + // encode complete triples only + let input_complete_chunks_len = input.len() - (input.len() % MIN_ENCODE_CHUNK_SIZE); + let input_chunks_to_encode_len = cmp::min(input_complete_chunks_len, max_input_len); + debug_assert_eq!(0, max_input_len % MIN_ENCODE_CHUNK_SIZE); + debug_assert_eq!(0, input_chunks_to_encode_len % MIN_ENCODE_CHUNK_SIZE); + + encoded_size += self.engine.internal_encode( + &input[..(input_chunks_to_encode_len)], + &mut self.output[encoded_size..], + ); + + // not updating `self.output_occupied_len` here because if the below write fails, it should + // "never take place" -- the buffer contents we encoded are ignored and perhaps retried + // later, if the consumer chooses. + + self.write_to_delegate(encoded_size) + // no matter whether we wrote the full encoded buffer or not, we consumed the same + // input + .map(|_| extra_input_read_len + input_chunks_to_encode_len) + .map_err(|e| { + // in case we filled and encoded `extra`, reset extra_len + self.extra_input_occupied_len = orig_extra_len; + + e + }) + } + + /// Because this is usually treated as OK to call multiple times, it will *not* flush any + /// incomplete chunks of input or write padding. + /// # Errors + /// + /// The first error that is not of [`ErrorKind::Interrupted`] will be returned. + fn flush(&mut self) -> Result<()> { + self.write_all_encoded_output()?; + self.delegate + .as_mut() + .expect("Writer must be present") + .flush() + } +} + +impl<'e, E: Engine, W: io::Write> Drop for EncoderWriter<'e, E, W> { + fn drop(&mut self) { + if !self.panicked { + // like `BufWriter`, ignore errors during drop + let _ = self.write_final_leftovers(); + } + } +} diff --git a/third_party/rust/base64/src/write/encoder_string_writer.rs b/third_party/rust/base64/src/write/encoder_string_writer.rs new file mode 100644 index 0000000000..9394dc9bf7 --- /dev/null +++ b/third_party/rust/base64/src/write/encoder_string_writer.rs @@ -0,0 +1,178 @@ +use super::encoder::EncoderWriter; +use crate::engine::Engine; +use std::io; + +/// A `Write` implementation that base64-encodes data using the provided config and accumulates the +/// resulting base64 utf8 `&str` in a [StrConsumer] implementation (typically `String`), which is +/// then exposed via `into_inner()`. +/// +/// # Examples +/// +/// Buffer base64 in a new String: +/// +/// ``` +/// use std::io::Write; +/// use base64::engine::general_purpose; +/// +/// let mut enc = base64::write::EncoderStringWriter::new(&general_purpose::STANDARD); +/// +/// enc.write_all(b"asdf").unwrap(); +/// +/// // get the resulting String +/// let b64_string = enc.into_inner(); +/// +/// assert_eq!("YXNkZg==", &b64_string); +/// ``` +/// +/// Or, append to an existing `String`, which implements `StrConsumer`: +/// +/// ``` +/// use std::io::Write; +/// use base64::engine::general_purpose; +/// +/// let mut buf = String::from("base64: "); +/// +/// let mut enc = base64::write::EncoderStringWriter::from_consumer( +/// &mut buf, +/// &general_purpose::STANDARD); +/// +/// enc.write_all(b"asdf").unwrap(); +/// +/// // release the &mut reference on buf +/// let _ = enc.into_inner(); +/// +/// assert_eq!("base64: YXNkZg==", &buf); +/// ``` +/// +/// # Panics +/// +/// Calling `write()` (or related methods) or `finish()` after `finish()` has completed without +/// error is invalid and will panic. +/// +/// # Performance +/// +/// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain +/// bytes to a `io::Write`. +pub struct EncoderStringWriter<'e, E: Engine, S: StrConsumer> { + encoder: EncoderWriter<'e, E, Utf8SingleCodeUnitWriter<S>>, +} + +impl<'e, E: Engine, S: StrConsumer> EncoderStringWriter<'e, E, S> { + /// Create a EncoderStringWriter that will append to the provided `StrConsumer`. + pub fn from_consumer(str_consumer: S, engine: &'e E) -> Self { + EncoderStringWriter { + encoder: EncoderWriter::new(Utf8SingleCodeUnitWriter { str_consumer }, engine), + } + } + + /// Encode all remaining buffered data, including any trailing incomplete input triples and + /// associated padding. + /// + /// Returns the base64-encoded form of the accumulated written data. + pub fn into_inner(mut self) -> S { + self.encoder + .finish() + .expect("Writing to a consumer should never fail") + .str_consumer + } +} + +impl<'e, E: Engine> EncoderStringWriter<'e, E, String> { + /// Create a EncoderStringWriter that will encode into a new `String` with the provided config. + pub fn new(engine: &'e E) -> Self { + EncoderStringWriter::from_consumer(String::new(), engine) + } +} + +impl<'e, E: Engine, S: StrConsumer> io::Write for EncoderStringWriter<'e, E, S> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.encoder.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.encoder.flush() + } +} + +/// An abstraction around consuming `str`s produced by base64 encoding. +pub trait StrConsumer { + /// Consume the base64 encoded data in `buf` + fn consume(&mut self, buf: &str); +} + +/// As for io::Write, `StrConsumer` is implemented automatically for `&mut S`. +impl<S: StrConsumer + ?Sized> StrConsumer for &mut S { + fn consume(&mut self, buf: &str) { + (**self).consume(buf); + } +} + +/// Pushes the str onto the end of the String +impl StrConsumer for String { + fn consume(&mut self, buf: &str) { + self.push_str(buf); + } +} + +/// A `Write` that only can handle bytes that are valid single-byte UTF-8 code units. +/// +/// This is safe because we only use it when writing base64, which is always valid UTF-8. +struct Utf8SingleCodeUnitWriter<S: StrConsumer> { + str_consumer: S, +} + +impl<S: StrConsumer> io::Write for Utf8SingleCodeUnitWriter<S> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + // Because we expect all input to be valid utf-8 individual bytes, we can encode any buffer + // length + let s = std::str::from_utf8(buf).expect("Input must be valid UTF-8"); + + self.str_consumer.consume(s); + + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + // no op + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + engine::Engine, tests::random_engine, write::encoder_string_writer::EncoderStringWriter, + }; + use rand::Rng; + use std::io::Write; + + #[test] + fn every_possible_split_of_input() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::<u8>::new(); + let mut normal_encoded = String::new(); + + let size = 5_000; + + for i in 0..size { + orig_data.clear(); + normal_encoded.clear(); + + for _ in 0..size { + orig_data.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + let mut stream_encoder = EncoderStringWriter::new(&engine); + // Write the first i bytes, then the rest + stream_encoder.write_all(&orig_data[0..i]).unwrap(); + stream_encoder.write_all(&orig_data[i..]).unwrap(); + + let stream_encoded = stream_encoder.into_inner(); + + assert_eq!(normal_encoded, stream_encoded); + } + } +} diff --git a/third_party/rust/base64/src/write/encoder_tests.rs b/third_party/rust/base64/src/write/encoder_tests.rs new file mode 100644 index 0000000000..ce76d631e5 --- /dev/null +++ b/third_party/rust/base64/src/write/encoder_tests.rs @@ -0,0 +1,554 @@ +use std::io::{Cursor, Write}; +use std::{cmp, io, str}; + +use rand::Rng; + +use crate::{ + alphabet::{STANDARD, URL_SAFE}, + engine::{ + general_purpose::{GeneralPurpose, NO_PAD, PAD}, + Engine, + }, + tests::random_engine, +}; + +use super::EncoderWriter; + +const URL_SAFE_ENGINE: GeneralPurpose = GeneralPurpose::new(&URL_SAFE, PAD); +const NO_PAD_ENGINE: GeneralPurpose = GeneralPurpose::new(&STANDARD, NO_PAD); + +#[test] +fn encode_three_bytes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + let sz = enc.write(b"abc").unwrap(); + assert_eq!(sz, 3); + } + assert_eq!(&c.get_ref()[..], URL_SAFE_ENGINE.encode("abc").as_bytes()); +} + +#[test] +fn encode_nine_bytes_two_writes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + let sz = enc.write(b"abcdef").unwrap(); + assert_eq!(sz, 6); + let sz = enc.write(b"ghi").unwrap(); + assert_eq!(sz, 3); + } + assert_eq!( + &c.get_ref()[..], + URL_SAFE_ENGINE.encode("abcdefghi").as_bytes() + ); +} + +#[test] +fn encode_one_then_two_bytes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + let sz = enc.write(b"a").unwrap(); + assert_eq!(sz, 1); + let sz = enc.write(b"bc").unwrap(); + assert_eq!(sz, 2); + } + assert_eq!(&c.get_ref()[..], URL_SAFE_ENGINE.encode("abc").as_bytes()); +} + +#[test] +fn encode_one_then_five_bytes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + let sz = enc.write(b"a").unwrap(); + assert_eq!(sz, 1); + let sz = enc.write(b"bcdef").unwrap(); + assert_eq!(sz, 5); + } + assert_eq!( + &c.get_ref()[..], + URL_SAFE_ENGINE.encode("abcdef").as_bytes() + ); +} + +#[test] +fn encode_1_2_3_bytes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + let sz = enc.write(b"a").unwrap(); + assert_eq!(sz, 1); + let sz = enc.write(b"bc").unwrap(); + assert_eq!(sz, 2); + let sz = enc.write(b"def").unwrap(); + assert_eq!(sz, 3); + } + assert_eq!( + &c.get_ref()[..], + URL_SAFE_ENGINE.encode("abcdef").as_bytes() + ); +} + +#[test] +fn encode_with_padding() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + enc.write_all(b"abcd").unwrap(); + + enc.flush().unwrap(); + } + assert_eq!(&c.get_ref()[..], URL_SAFE_ENGINE.encode("abcd").as_bytes()); +} + +#[test] +fn encode_with_padding_multiple_writes() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + assert_eq!(1, enc.write(b"a").unwrap()); + assert_eq!(2, enc.write(b"bc").unwrap()); + assert_eq!(3, enc.write(b"def").unwrap()); + assert_eq!(1, enc.write(b"g").unwrap()); + + enc.flush().unwrap(); + } + assert_eq!( + &c.get_ref()[..], + URL_SAFE_ENGINE.encode("abcdefg").as_bytes() + ); +} + +#[test] +fn finish_writes_extra_byte() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &URL_SAFE_ENGINE); + + assert_eq!(6, enc.write(b"abcdef").unwrap()); + + // will be in extra + assert_eq!(1, enc.write(b"g").unwrap()); + + // 1 trailing byte = 2 encoded chars + let _ = enc.finish().unwrap(); + } + assert_eq!( + &c.get_ref()[..], + URL_SAFE_ENGINE.encode("abcdefg").as_bytes() + ); +} + +#[test] +fn write_partial_chunk_encodes_partial_chunk() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + // nothing encoded yet + assert_eq!(2, enc.write(b"ab").unwrap()); + // encoded here + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("ab").as_bytes()); + assert_eq!(3, c.get_ref().len()); +} + +#[test] +fn write_1_chunk_encodes_complete_chunk() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + assert_eq!(3, enc.write(b"abc").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes()); + assert_eq!(4, c.get_ref().len()); +} + +#[test] +fn write_1_chunk_and_partial_encodes_only_complete_chunk() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + // "d" not consumed since it's not a full chunk + assert_eq!(3, enc.write(b"abcd").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes()); + assert_eq!(4, c.get_ref().len()); +} + +#[test] +fn write_2_partials_to_exactly_complete_chunk_encodes_complete_chunk() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + assert_eq!(1, enc.write(b"a").unwrap()); + assert_eq!(2, enc.write(b"bc").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes()); + assert_eq!(4, c.get_ref().len()); +} + +#[test] +fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining( +) { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + assert_eq!(1, enc.write(b"a").unwrap()); + // doesn't consume "d" + assert_eq!(2, enc.write(b"bcd").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes()); + assert_eq!(4, c.get_ref().len()); +} + +#[test] +fn write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complete_chunks() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + assert_eq!(1, enc.write(b"a").unwrap()); + // completes partial chunk, and another chunk + assert_eq!(5, enc.write(b"bcdef").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abcdef").as_bytes()); + assert_eq!(8, c.get_ref().len()); +} + +#[test] +fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks( +) { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + + assert_eq!(1, enc.write(b"a").unwrap()); + // completes partial chunk, and another chunk, with one more partial chunk that's not + // consumed + assert_eq!(5, enc.write(b"bcdefe").unwrap()); + let _ = enc.finish().unwrap(); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abcdef").as_bytes()); + assert_eq!(8, c.get_ref().len()); +} + +#[test] +fn drop_calls_finish_for_you() { + let mut c = Cursor::new(Vec::new()); + { + let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE); + assert_eq!(1, enc.write(b"a").unwrap()); + } + assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("a").as_bytes()); + assert_eq!(2, c.get_ref().len()); +} + +#[test] +fn every_possible_split_of_input() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::<u8>::new(); + let mut stream_encoded = Vec::<u8>::new(); + let mut normal_encoded = String::new(); + + let size = 5_000; + + for i in 0..size { + orig_data.clear(); + stream_encoded.clear(); + normal_encoded.clear(); + + for _ in 0..size { + orig_data.push(rng.gen()); + } + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + { + let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine); + // Write the first i bytes, then the rest + stream_encoder.write_all(&orig_data[0..i]).unwrap(); + stream_encoder.write_all(&orig_data[i..]).unwrap(); + } + + assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap()); + } +} + +#[test] +fn encode_random_config_matches_normal_encode_reasonable_input_len() { + // choose up to 2 * buf size, so ~half the time it'll use a full buffer + do_encode_random_config_matches_normal_encode(super::encoder::BUF_SIZE * 2); +} + +#[test] +fn encode_random_config_matches_normal_encode_tiny_input_len() { + do_encode_random_config_matches_normal_encode(10); +} + +#[test] +fn retrying_writes_that_error_with_interrupted_works() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::<u8>::new(); + let mut stream_encoded = Vec::<u8>::new(); + let mut normal_encoded = String::new(); + + for _ in 0..1_000 { + orig_data.clear(); + stream_encoded.clear(); + normal_encoded.clear(); + + let orig_len: usize = rng.gen_range(100..20_000); + for _ in 0..orig_len { + orig_data.push(rng.gen()); + } + + // encode the normal way + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + // encode via the stream encoder + { + let mut interrupt_rng = rand::thread_rng(); + let mut interrupting_writer = InterruptingWriter { + w: &mut stream_encoded, + rng: &mut interrupt_rng, + fraction: 0.8, + }; + + let mut stream_encoder = EncoderWriter::new(&mut interrupting_writer, &engine); + let mut bytes_consumed = 0; + while bytes_consumed < orig_len { + // use short inputs since we want to use `extra` a lot as that's what needs rollback + // when errors occur + let input_len: usize = cmp::min(rng.gen_range(0..10), orig_len - bytes_consumed); + + retry_interrupted_write_all( + &mut stream_encoder, + &orig_data[bytes_consumed..bytes_consumed + input_len], + ) + .unwrap(); + + bytes_consumed += input_len; + } + + loop { + let res = stream_encoder.finish(); + match res { + Ok(_) => break, + Err(e) => match e.kind() { + io::ErrorKind::Interrupted => continue, + _ => Err(e).unwrap(), // bail + }, + } + } + + assert_eq!(orig_len, bytes_consumed); + } + + assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap()); + } +} + +#[test] +fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_encoded_data() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::<u8>::new(); + let mut stream_encoded = Vec::<u8>::new(); + let mut normal_encoded = String::new(); + + for _ in 0..1_000 { + orig_data.clear(); + stream_encoded.clear(); + normal_encoded.clear(); + + let orig_len: usize = rng.gen_range(100..20_000); + for _ in 0..orig_len { + orig_data.push(rng.gen()); + } + + // encode the normal way + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + // encode via the stream encoder + { + let mut partial_rng = rand::thread_rng(); + let mut partial_writer = PartialInterruptingWriter { + w: &mut stream_encoded, + rng: &mut partial_rng, + full_input_fraction: 0.1, + no_interrupt_fraction: 0.1, + }; + + let mut stream_encoder = EncoderWriter::new(&mut partial_writer, &engine); + let mut bytes_consumed = 0; + while bytes_consumed < orig_len { + // use at most medium-length inputs to exercise retry logic more aggressively + let input_len: usize = cmp::min(rng.gen_range(0..100), orig_len - bytes_consumed); + + let res = + stream_encoder.write(&orig_data[bytes_consumed..bytes_consumed + input_len]); + + // retry on interrupt + match res { + Ok(len) => bytes_consumed += len, + Err(e) => match e.kind() { + io::ErrorKind::Interrupted => continue, + _ => { + panic!("should not see other errors"); + } + }, + } + } + + let _ = stream_encoder.finish().unwrap(); + + assert_eq!(orig_len, bytes_consumed); + } + + assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap()); + } +} + +/// Retry writes until all the data is written or an error that isn't Interrupted is returned. +fn retry_interrupted_write_all<W: Write>(w: &mut W, buf: &[u8]) -> io::Result<()> { + let mut bytes_consumed = 0; + + while bytes_consumed < buf.len() { + let res = w.write(&buf[bytes_consumed..]); + + match res { + Ok(len) => bytes_consumed += len, + Err(e) => match e.kind() { + io::ErrorKind::Interrupted => continue, + _ => return Err(e), + }, + } + } + + Ok(()) +} + +fn do_encode_random_config_matches_normal_encode(max_input_len: usize) { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::<u8>::new(); + let mut stream_encoded = Vec::<u8>::new(); + let mut normal_encoded = String::new(); + + for _ in 0..1_000 { + orig_data.clear(); + stream_encoded.clear(); + normal_encoded.clear(); + + let orig_len: usize = rng.gen_range(100..20_000); + for _ in 0..orig_len { + orig_data.push(rng.gen()); + } + + // encode the normal way + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + // encode via the stream encoder + { + let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine); + let mut bytes_consumed = 0; + while bytes_consumed < orig_len { + let input_len: usize = + cmp::min(rng.gen_range(0..max_input_len), orig_len - bytes_consumed); + + // write a little bit of the data + stream_encoder + .write_all(&orig_data[bytes_consumed..bytes_consumed + input_len]) + .unwrap(); + + bytes_consumed += input_len; + } + + let _ = stream_encoder.finish().unwrap(); + + assert_eq!(orig_len, bytes_consumed); + } + + assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap()); + } +} + +/// A `Write` implementation that returns Interrupted some fraction of the time, randomly. +struct InterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> { + w: &'a mut W, + rng: &'a mut R, + /// In [0, 1]. If a random number in [0, 1] is `<= threshold`, `Write` methods will return + /// an `Interrupted` error + fraction: f64, +} + +impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + if self.rng.gen_range(0.0..1.0) <= self.fraction { + return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")); + } + + self.w.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + if self.rng.gen_range(0.0..1.0) <= self.fraction { + return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")); + } + + self.w.flush() + } +} + +/// A `Write` implementation that sometimes will only write part of its input. +struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> { + w: &'a mut W, + rng: &'a mut R, + /// In [0, 1]. If a random number in [0, 1] is `<= threshold`, `write()` will write all its + /// input. Otherwise, it will write a random substring + full_input_fraction: f64, + no_interrupt_fraction: f64, +} + +impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + if self.rng.gen_range(0.0..1.0) > self.no_interrupt_fraction { + return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")); + } + + if self.rng.gen_range(0.0..1.0) <= self.full_input_fraction || buf.is_empty() { + // pass through the buf untouched + self.w.write(buf) + } else { + // only use a prefix of it + self.w + .write(&buf[0..(self.rng.gen_range(0..(buf.len() - 1)))]) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.w.flush() + } +} diff --git a/third_party/rust/base64/src/write/mod.rs b/third_party/rust/base64/src/write/mod.rs new file mode 100644 index 0000000000..2a617db9de --- /dev/null +++ b/third_party/rust/base64/src/write/mod.rs @@ -0,0 +1,11 @@ +//! Implementations of `io::Write` to transparently handle base64. +mod encoder; +mod encoder_string_writer; + +pub use self::{ + encoder::EncoderWriter, + encoder_string_writer::{EncoderStringWriter, StrConsumer}, +}; + +#[cfg(test)] +mod encoder_tests; |