diff options
Diffstat (limited to 'third_party/rust/base64/src/read/decoder_tests.rs')
-rw-r--r-- | third_party/rust/base64/src/read/decoder_tests.rs | 346 |
1 files changed, 346 insertions, 0 deletions
diff --git a/third_party/rust/base64/src/read/decoder_tests.rs b/third_party/rust/base64/src/read/decoder_tests.rs new file mode 100644 index 0000000000..65d58d8e3f --- /dev/null +++ b/third_party/rust/base64/src/read/decoder_tests.rs @@ -0,0 +1,346 @@ +use std::{ + cmp, + io::{self, Read as _}, + iter, +}; + +use rand::{Rng as _, RngCore as _}; + +use super::decoder::{DecoderReader, BUF_SIZE}; +use crate::{ + engine::{general_purpose::STANDARD, Engine, GeneralPurpose}, + tests::{random_alphabet, random_config, random_engine}, + DecodeError, +}; + +#[test] +fn simple() { + let tests: &[(&[u8], &[u8])] = &[ + (&b"0"[..], &b"MA=="[..]), + (b"01", b"MDE="), + (b"012", b"MDEy"), + (b"0123", b"MDEyMw=="), + (b"01234", b"MDEyMzQ="), + (b"012345", b"MDEyMzQ1"), + (b"0123456", b"MDEyMzQ1Ng=="), + (b"01234567", b"MDEyMzQ1Njc="), + (b"012345678", b"MDEyMzQ1Njc4"), + (b"0123456789", b"MDEyMzQ1Njc4OQ=="), + ][..]; + + for (text_expected, base64data) in tests.iter() { + // Read n bytes at a time. + for n in 1..base64data.len() + 1 { + let mut wrapped_reader = io::Cursor::new(base64data); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD); + + // handle errors as you normally would + let mut text_got = Vec::new(); + let mut buffer = vec![0u8; n]; + while let Ok(read) = decoder.read(&mut buffer[..]) { + if read == 0 { + break; + } + text_got.extend_from_slice(&buffer[..read]); + } + + assert_eq!( + text_got, + *text_expected, + "\nGot: {}\nExpected: {}", + String::from_utf8_lossy(&text_got[..]), + String::from_utf8_lossy(text_expected) + ); + } + } +} + +// Make sure we error out on trailing junk. +#[test] +fn trailing_junk() { + let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..]; + + for base64data in tests.iter() { + // Read n bytes at a time. + for n in 1..base64data.len() + 1 { + let mut wrapped_reader = io::Cursor::new(base64data); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD); + + // handle errors as you normally would + let mut buffer = vec![0u8; n]; + let mut saw_error = false; + loop { + match decoder.read(&mut buffer[..]) { + Err(_) => { + saw_error = true; + break; + } + Ok(read) if read == 0 => break, + Ok(_) => (), + } + } + + assert!(saw_error); + } + } +} + +#[test] +fn handles_short_read_from_delegate() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + bytes.truncate(size); + rng.fill_bytes(&mut bytes[..size]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + engine.encode_string(&bytes[..], &mut b64); + + let mut wrapped_reader = io::Cursor::new(b64.as_bytes()); + let mut short_reader = RandomShortRead { + delegate: &mut wrapped_reader, + rng: &mut rng, + }; + + let mut decoder = DecoderReader::new(&mut short_reader, &engine); + + let decoded_len = decoder.read_to_end(&mut decoded).unwrap(); + assert_eq!(size, decoded_len); + assert_eq!(&bytes[..], &decoded[..]); + } +} + +#[test] +fn read_in_short_increments() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + // leave room to play around with larger buffers + decoded.extend(iter::repeat(0).take(size * 3)); + + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + + let mut wrapped_reader = io::Cursor::new(&b64[..]); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut decoder); + } +} + +#[test] +fn read_in_short_increments_with_short_delegate_reads() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(0..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + // leave room to play around with larger buffers + decoded.extend(iter::repeat(0).take(size * 3)); + + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + + let mut base_reader = io::Cursor::new(&b64[..]); + let mut decoder = DecoderReader::new(&mut base_reader, &engine); + let mut short_reader = RandomShortRead { + delegate: &mut decoder, + rng: &mut rand::thread_rng(), + }; + + consume_with_short_reads_and_validate( + &mut rng, + &bytes[..], + &mut decoded, + &mut short_reader, + ); + } +} + +#[test] +fn reports_invalid_last_symbol_correctly() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut b64_bytes = Vec::new(); + let mut decoded = Vec::new(); + let mut bulk_decoded = Vec::new(); + + for _ in 0..1_000 { + bytes.clear(); + b64.clear(); + b64_bytes.clear(); + + let size = rng.gen_range(1..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + decoded.extend(iter::repeat(0).take(size)); + rng.fill_bytes(&mut bytes[..]); + assert_eq!(size, bytes.len()); + + let config = random_config(&mut rng); + let alphabet = random_alphabet(&mut rng); + // changing padding will cause invalid padding errors when we twiddle the last byte + let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false)); + engine.encode_string(&bytes[..], &mut b64); + b64_bytes.extend(b64.bytes()); + assert_eq!(b64_bytes.len(), b64.len()); + + // change the last character to every possible symbol. Should behave the same as bulk + // decoding whether invalid or valid. + for &s1 in alphabet.symbols.iter() { + decoded.clear(); + bulk_decoded.clear(); + + // replace the last + *b64_bytes.last_mut().unwrap() = s1; + let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded); + + let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| { + e.into_inner() + .and_then(|e| e.downcast::<DecodeError>().ok()) + }); + + assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res); + } + } +} + +#[test] +fn reports_invalid_byte_correctly() { + let mut rng = rand::thread_rng(); + let mut bytes = Vec::new(); + let mut b64 = String::new(); + let mut decoded = Vec::new(); + + for _ in 0..10_000 { + bytes.clear(); + b64.clear(); + decoded.clear(); + + let size = rng.gen_range(1..(10 * BUF_SIZE)); + bytes.extend(iter::repeat(0).take(size)); + rng.fill_bytes(&mut bytes[..size]); + assert_eq!(size, bytes.len()); + + let engine = random_engine(&mut rng); + + engine.encode_string(&bytes[..], &mut b64); + // replace one byte, somewhere, with '*', which is invalid + let bad_byte_pos = rng.gen_range(0..b64.len()); + let mut b64_bytes = b64.bytes().collect::<Vec<u8>>(); + b64_bytes[bad_byte_pos] = b'*'; + + let mut wrapped_reader = io::Cursor::new(b64_bytes.clone()); + let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine); + + // some gymnastics to avoid double-moving the io::Error, which is not Copy + let read_decode_err = decoder + .read_to_end(&mut decoded) + .map_err(|e| { + let kind = e.kind(); + let inner = e + .into_inner() + .and_then(|e| e.downcast::<DecodeError>().ok()); + inner.map(|i| (*i, kind)) + }) + .err() + .and_then(|o| o); + + let mut bulk_buf = Vec::new(); + let bulk_decode_err = engine.decode_vec(&b64_bytes[..], &mut bulk_buf).err(); + + // it's tricky to predict where the invalid data's offset will be since if it's in the last + // chunk it will be reported at the first padding location because it's treated as invalid + // padding. So, we just check that it's the same as it is for decoding all at once. + assert_eq!( + bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)), + read_decode_err + ); + } +} + +fn consume_with_short_reads_and_validate<R: io::Read>( + rng: &mut rand::rngs::ThreadRng, + expected_bytes: &[u8], + decoded: &mut [u8], + short_reader: &mut R, +) { + let mut total_read = 0_usize; + loop { + assert!( + total_read <= expected_bytes.len(), + "tr {} size {}", + total_read, + expected_bytes.len() + ); + if total_read == expected_bytes.len() { + assert_eq!(expected_bytes, &decoded[..total_read]); + // should be done + assert_eq!(0, short_reader.read(&mut *decoded).unwrap()); + // didn't write anything + assert_eq!(expected_bytes, &decoded[..total_read]); + + break; + } + let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2)); + + let read = short_reader + .read(&mut decoded[total_read..total_read + decode_len]) + .unwrap(); + total_read += read; + } +} + +/// Limits how many bytes a reader will provide in each read call. +/// Useful for shaking out code that may work fine only with typical input sources that always fill +/// the buffer. +struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> { + delegate: &'b mut R, + rng: &'a mut N, +} + +impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> { + fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> { + // avoid 0 since it means EOF for non-empty buffers + let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len()); + + self.delegate.read(&mut buf[..effective_len]) + } +} |