diff options
Diffstat (limited to 'third_party/rust/prio/src/codec.rs')
-rw-r--r-- | third_party/rust/prio/src/codec.rs | 210 |
1 files changed, 146 insertions, 64 deletions
diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs index 71f4f8ce5f..98e6299abd 100644 --- a/third_party/rust/prio/src/codec.rs +++ b/third_party/rust/prio/src/codec.rs @@ -20,6 +20,7 @@ use std::{ /// An error that occurred during decoding. #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum CodecError { /// An I/O error. #[error("I/O error")] @@ -33,6 +34,10 @@ pub enum CodecError { #[error("length prefix of encoded vector overflows buffer: {0}")] LengthPrefixTooBig(usize), + /// The byte length of a vector exceeded the range of its length prefix. + #[error("vector length exceeded range of length prefix")] + LengthPrefixOverflow, + /// Custom errors from [`Decode`] implementations. #[error("other error: {0}")] Other(#[source] Box<dyn Error + 'static + Send + Sync>), @@ -97,10 +102,10 @@ impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D { /// Describes how to encode objects into a byte sequence. pub trait Encode { /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. - fn encode(&self, bytes: &mut Vec<u8>); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError>; /// Convenience method to encode a value into a new `Vec<u8>`. - fn get_encoded(&self) -> Vec<u8> { + fn get_encoded(&self) -> Result<Vec<u8>, CodecError> { self.get_encoded_with_param(&()) } @@ -116,17 +121,21 @@ pub trait ParameterizedEncode<P> { /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. /// `encoding_parameter` provides details of the wire encoding, used to control how the value /// is encoded. - fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>); + fn encode_with_param( + &self, + encoding_parameter: &P, + bytes: &mut Vec<u8>, + ) -> Result<(), CodecError>; /// Convenience method to encode a value into a new `Vec<u8>`. - fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> { + fn get_encoded_with_param(&self, encoding_parameter: &P) -> Result<Vec<u8>, CodecError> { let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) { Vec::with_capacity(length) } else { Vec::new() }; - self.encode_with_param(encoding_parameter, &mut ret); - ret + self.encode_with_param(encoding_parameter, &mut ret)?; + Ok(ret) } /// Returns an optional hint indicating how many bytes will be required to encode this value, or @@ -139,7 +148,11 @@ pub trait ParameterizedEncode<P> { /// Provide a blanket implementation so that any [`Encode`] can be used as a /// `ParameterizedEncode<T>` for any `T`. impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E { - fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) { + fn encode_with_param( + &self, + _encoding_parameter: &T, + bytes: &mut Vec<u8>, + ) -> Result<(), CodecError> { self.encode(bytes) } @@ -155,7 +168,9 @@ impl Decode for () { } impl Encode for () { - fn encode(&self, _bytes: &mut Vec<u8>) {} + fn encode(&self, _bytes: &mut Vec<u8>) -> Result<(), CodecError> { + Ok(()) + } fn encoded_len(&self) -> Option<usize> { Some(0) @@ -171,8 +186,9 @@ impl Decode for u8 { } impl Encode for u8 { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { bytes.push(*self); + Ok(()) } fn encoded_len(&self) -> Option<usize> { @@ -187,8 +203,9 @@ impl Decode for u16 { } impl Encode for u16 { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { bytes.extend_from_slice(&u16::to_be_bytes(*self)); + Ok(()) } fn encoded_len(&self) -> Option<usize> { @@ -208,9 +225,10 @@ impl Decode for U24 { } impl Encode for U24 { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { // Encode lower three bytes of the u32 as u24 bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]); + Ok(()) } fn encoded_len(&self) -> Option<usize> { @@ -225,8 +243,9 @@ impl Decode for u32 { } impl Encode for u32 { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { bytes.extend_from_slice(&u32::to_be_bytes(*self)); + Ok(()) } fn encoded_len(&self) -> Option<usize> { @@ -241,8 +260,9 @@ impl Decode for u64 { } impl Encode for u64 { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { bytes.extend_from_slice(&u64::to_be_bytes(*self)); + Ok(()) } fn encoded_len(&self) -> Option<usize> { @@ -257,18 +277,19 @@ pub fn encode_u8_items<P, E: ParameterizedEncode<P>>( bytes: &mut Vec<u8>, encoding_parameter: &P, items: &[E], -) { +) -> Result<(), CodecError> { // Reserve space to later write length let len_offset = bytes.len(); bytes.push(0); for item in items { - item.encode_with_param(encoding_parameter, bytes); + item.encode_with_param(encoding_parameter, bytes)?; } - let len = bytes.len() - len_offset - 1; - assert!(len <= usize::from(u8::MAX)); - bytes[len_offset] = len as u8; + let len = + u8::try_from(bytes.len() - len_offset - 1).map_err(|_| CodecError::LengthPrefixOverflow)?; + bytes[len_offset] = len; + Ok(()) } /// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of @@ -292,20 +313,19 @@ pub fn encode_u16_items<P, E: ParameterizedEncode<P>>( bytes: &mut Vec<u8>, encoding_parameter: &P, items: &[E], -) { +) -> Result<(), CodecError> { // Reserve space to later write length let len_offset = bytes.len(); - 0u16.encode(bytes); + 0u16.encode(bytes)?; for item in items { - item.encode_with_param(encoding_parameter, bytes); + item.encode_with_param(encoding_parameter, bytes)?; } - let len = bytes.len() - len_offset - 2; - assert!(len <= usize::from(u16::MAX)); - for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() { - bytes[len_offset + offset] = *byte; - } + let len = u16::try_from(bytes.len() - len_offset - 2) + .map_err(|_| CodecError::LengthPrefixOverflow)?; + bytes[len_offset..len_offset + 2].copy_from_slice(&len.to_be_bytes()); + Ok(()) } /// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of @@ -330,20 +350,22 @@ pub fn encode_u24_items<P, E: ParameterizedEncode<P>>( bytes: &mut Vec<u8>, encoding_parameter: &P, items: &[E], -) { +) -> Result<(), CodecError> { // Reserve space to later write length let len_offset = bytes.len(); - U24(0).encode(bytes); + U24(0).encode(bytes)?; for item in items { - item.encode_with_param(encoding_parameter, bytes); + item.encode_with_param(encoding_parameter, bytes)?; } - let len = bytes.len() - len_offset - 3; - assert!(len <= 0xffffff); - for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() { - bytes[len_offset + offset] = *byte; + let len = u32::try_from(bytes.len() - len_offset - 3) + .map_err(|_| CodecError::LengthPrefixOverflow)?; + if len > 0xffffff { + return Err(CodecError::LengthPrefixOverflow); } + bytes[len_offset..len_offset + 3].copy_from_slice(&len.to_be_bytes()[1..]); + Ok(()) } /// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of @@ -368,20 +390,19 @@ pub fn encode_u32_items<P, E: ParameterizedEncode<P>>( bytes: &mut Vec<u8>, encoding_parameter: &P, items: &[E], -) { +) -> Result<(), CodecError> { // Reserve space to later write length let len_offset = bytes.len(); - 0u32.encode(bytes); + 0u32.encode(bytes)?; for item in items { - item.encode_with_param(encoding_parameter, bytes); + item.encode_with_param(encoding_parameter, bytes)?; } - let len = bytes.len() - len_offset - 4; - let len: u32 = len.try_into().expect("Length too large"); - for (offset, byte) in len.to_be_bytes().iter().enumerate() { - bytes[len_offset + offset] = *byte; - } + let len = u32::try_from(bytes.len() - len_offset - 4) + .map_err(|_| CodecError::LengthPrefixOverflow)?; + bytes[len_offset..len_offset + 4].copy_from_slice(&len.to_be_bytes()); + Ok(()) } /// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of @@ -432,6 +453,7 @@ fn decode_items<P, D: ParameterizedDecode<P>>( #[cfg(test)] mod tests { + use std::io::ErrorKind; use super::*; use assert_matches::assert_matches; @@ -439,7 +461,7 @@ mod tests { #[test] fn encode_nothing() { let mut bytes = vec![]; - ().encode(&mut bytes); + ().encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 0); } @@ -448,7 +470,7 @@ mod tests { let value = 100u8; let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 1); let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap(); @@ -460,7 +482,7 @@ mod tests { let value = 1000u16; let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 2); // Check endianness of encoding assert_eq!(bytes, vec![3, 232]); @@ -474,7 +496,7 @@ mod tests { let value = U24(1_000_000u32); let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 3); // Check endianness of encoding assert_eq!(bytes, vec![15, 66, 64]); @@ -488,7 +510,7 @@ mod tests { let value = 134_217_728u32; let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 4); // Check endianness of encoding assert_eq!(bytes, vec![8, 0, 0, 0]); @@ -502,7 +524,7 @@ mod tests { let value = 137_438_953_472u64; let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), 8); // Check endianness of encoding assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]); @@ -521,12 +543,12 @@ mod tests { } impl Encode for TestMessage { - fn encode(&self, bytes: &mut Vec<u8>) { - self.field_u8.encode(bytes); - self.field_u16.encode(bytes); - self.field_u24.encode(bytes); - self.field_u32.encode(bytes); - self.field_u64.encode(bytes); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + self.field_u8.encode(bytes)?; + self.field_u16.encode(bytes)?; + self.field_u24.encode(bytes)?; + self.field_u32.encode(bytes)?; + self.field_u64.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -584,7 +606,7 @@ mod tests { }; let mut bytes = vec![]; - value.encode(&mut bytes); + value.encode(&mut bytes).unwrap(); assert_eq!(bytes.len(), TestMessage::encoded_length()); assert_eq!(value.encoded_len().unwrap(), TestMessage::encoded_length()); @@ -622,7 +644,7 @@ mod tests { fn roundtrip_variable_length_u8() { let values = messages_vec(); let mut bytes = vec![]; - encode_u8_items(&mut bytes, &(), &values); + encode_u8_items(&mut bytes, &(), &values).unwrap(); assert_eq!( bytes.len(), @@ -640,7 +662,7 @@ mod tests { fn roundtrip_variable_length_u16() { let values = messages_vec(); let mut bytes = vec![]; - encode_u16_items(&mut bytes, &(), &values); + encode_u16_items(&mut bytes, &(), &values).unwrap(); assert_eq!( bytes.len(), @@ -661,7 +683,7 @@ mod tests { fn roundtrip_variable_length_u24() { let values = messages_vec(); let mut bytes = vec![]; - encode_u24_items(&mut bytes, &(), &values); + encode_u24_items(&mut bytes, &(), &values).unwrap(); assert_eq!( bytes.len(), @@ -682,7 +704,7 @@ mod tests { fn roundtrip_variable_length_u32() { let values = messages_vec(); let mut bytes = Vec::new(); - encode_u32_items(&mut bytes, &(), &values); + encode_u32_items(&mut bytes, &(), &values).unwrap(); assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length()); @@ -697,6 +719,21 @@ mod tests { } #[test] + fn decode_too_short() { + let values = messages_vec(); + let mut bytes = Vec::new(); + encode_u32_items(&mut bytes, &(), &values).unwrap(); + + let error = + decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..3])).unwrap_err(); + assert_matches!(error, CodecError::Io(e) => assert_eq!(e.kind(), ErrorKind::UnexpectedEof)); + + let error = + decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..4])).unwrap_err(); + assert_matches!(error, CodecError::LengthPrefixTooBig(_)); + } + + #[test] fn decode_items_overflow() { let encoded = vec![1u8]; @@ -724,11 +761,56 @@ mod tests { #[test] fn length_hint_correctness() { - assert_eq!(().encoded_len().unwrap(), ().get_encoded().len()); - assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().len()); - assert_eq!(0u16.encoded_len().unwrap(), 0u16.get_encoded().len()); - assert_eq!(U24(0).encoded_len().unwrap(), U24(0).get_encoded().len()); - assert_eq!(0u32.encoded_len().unwrap(), 0u32.get_encoded().len()); - assert_eq!(0u64.encoded_len().unwrap(), 0u64.get_encoded().len()); + assert_eq!(().encoded_len().unwrap(), ().get_encoded().unwrap().len()); + assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().unwrap().len()); + assert_eq!( + 0u16.encoded_len().unwrap(), + 0u16.get_encoded().unwrap().len() + ); + assert_eq!( + U24(0).encoded_len().unwrap(), + U24(0).get_encoded().unwrap().len() + ); + assert_eq!( + 0u32.encoded_len().unwrap(), + 0u32.get_encoded().unwrap().len() + ); + assert_eq!( + 0u64.encoded_len().unwrap(), + 0u64.get_encoded().unwrap().len() + ); + } + + #[test] + fn get_decoded_leftover() { + let encoded_good = [1, 2, 3, 4]; + assert_matches!(u32::get_decoded(&encoded_good).unwrap(), 0x01020304u32); + + let encoded_bad = [1, 2, 3, 4, 5]; + let error = u32::get_decoded(&encoded_bad).unwrap_err(); + assert_matches!(error, CodecError::BytesLeftOver(1)); + } + + #[test] + fn encoded_len_backwards_compatibility() { + struct MyMessage; + + impl Encode for MyMessage { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + bytes.extend_from_slice(b"Hello, world"); + Ok(()) + } + } + + assert_eq!(MyMessage.encoded_len(), None); + + assert_eq!(MyMessage.get_encoded().unwrap(), b"Hello, world"); + } + + #[test] + fn encode_length_prefix_overflow() { + let mut bytes = Vec::new(); + let error = encode_u8_items(&mut bytes, &(), &[1u8; u8::MAX as usize + 1]).unwrap_err(); + assert_matches!(error, CodecError::LengthPrefixOverflow); } } |