summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/codec.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/codec.rs')
-rw-r--r--third_party/rust/prio/src/codec.rs210
1 files changed, 146 insertions, 64 deletions
diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs
index 71f4f8ce5f..98e6299abd 100644
--- a/third_party/rust/prio/src/codec.rs
+++ b/third_party/rust/prio/src/codec.rs
@@ -20,6 +20,7 @@ use std::{
/// An error that occurred during decoding.
#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
pub enum CodecError {
/// An I/O error.
#[error("I/O error")]
@@ -33,6 +34,10 @@ pub enum CodecError {
#[error("length prefix of encoded vector overflows buffer: {0}")]
LengthPrefixTooBig(usize),
+ /// The byte length of a vector exceeded the range of its length prefix.
+ #[error("vector length exceeded range of length prefix")]
+ LengthPrefixOverflow,
+
/// Custom errors from [`Decode`] implementations.
#[error("other error: {0}")]
Other(#[source] Box<dyn Error + 'static + Send + Sync>),
@@ -97,10 +102,10 @@ impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D {
/// Describes how to encode objects into a byte sequence.
pub trait Encode {
/// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
- fn encode(&self, bytes: &mut Vec<u8>);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError>;
/// Convenience method to encode a value into a new `Vec<u8>`.
- fn get_encoded(&self) -> Vec<u8> {
+ fn get_encoded(&self) -> Result<Vec<u8>, CodecError> {
self.get_encoded_with_param(&())
}
@@ -116,17 +121,21 @@ pub trait ParameterizedEncode<P> {
/// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
/// `encoding_parameter` provides details of the wire encoding, used to control how the value
/// is encoded.
- fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>);
+ fn encode_with_param(
+ &self,
+ encoding_parameter: &P,
+ bytes: &mut Vec<u8>,
+ ) -> Result<(), CodecError>;
/// Convenience method to encode a value into a new `Vec<u8>`.
- fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> {
+ fn get_encoded_with_param(&self, encoding_parameter: &P) -> Result<Vec<u8>, CodecError> {
let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) {
Vec::with_capacity(length)
} else {
Vec::new()
};
- self.encode_with_param(encoding_parameter, &mut ret);
- ret
+ self.encode_with_param(encoding_parameter, &mut ret)?;
+ Ok(ret)
}
/// Returns an optional hint indicating how many bytes will be required to encode this value, or
@@ -139,7 +148,11 @@ pub trait ParameterizedEncode<P> {
/// Provide a blanket implementation so that any [`Encode`] can be used as a
/// `ParameterizedEncode<T>` for any `T`.
impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
- fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) {
+ fn encode_with_param(
+ &self,
+ _encoding_parameter: &T,
+ bytes: &mut Vec<u8>,
+ ) -> Result<(), CodecError> {
self.encode(bytes)
}
@@ -155,7 +168,9 @@ impl Decode for () {
}
impl Encode for () {
- fn encode(&self, _bytes: &mut Vec<u8>) {}
+ fn encode(&self, _bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ Ok(())
+ }
fn encoded_len(&self) -> Option<usize> {
Some(0)
@@ -171,8 +186,9 @@ impl Decode for u8 {
}
impl Encode for u8 {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
bytes.push(*self);
+ Ok(())
}
fn encoded_len(&self) -> Option<usize> {
@@ -187,8 +203,9 @@ impl Decode for u16 {
}
impl Encode for u16 {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
bytes.extend_from_slice(&u16::to_be_bytes(*self));
+ Ok(())
}
fn encoded_len(&self) -> Option<usize> {
@@ -208,9 +225,10 @@ impl Decode for U24 {
}
impl Encode for U24 {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
// Encode lower three bytes of the u32 as u24
bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
+ Ok(())
}
fn encoded_len(&self) -> Option<usize> {
@@ -225,8 +243,9 @@ impl Decode for u32 {
}
impl Encode for u32 {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
bytes.extend_from_slice(&u32::to_be_bytes(*self));
+ Ok(())
}
fn encoded_len(&self) -> Option<usize> {
@@ -241,8 +260,9 @@ impl Decode for u64 {
}
impl Encode for u64 {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
bytes.extend_from_slice(&u64::to_be_bytes(*self));
+ Ok(())
}
fn encoded_len(&self) -> Option<usize> {
@@ -257,18 +277,19 @@ pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
bytes: &mut Vec<u8>,
encoding_parameter: &P,
items: &[E],
-) {
+) -> Result<(), CodecError> {
// Reserve space to later write length
let len_offset = bytes.len();
bytes.push(0);
for item in items {
- item.encode_with_param(encoding_parameter, bytes);
+ item.encode_with_param(encoding_parameter, bytes)?;
}
- let len = bytes.len() - len_offset - 1;
- assert!(len <= usize::from(u8::MAX));
- bytes[len_offset] = len as u8;
+ let len =
+ u8::try_from(bytes.len() - len_offset - 1).map_err(|_| CodecError::LengthPrefixOverflow)?;
+ bytes[len_offset] = len;
+ Ok(())
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
@@ -292,20 +313,19 @@ pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
bytes: &mut Vec<u8>,
encoding_parameter: &P,
items: &[E],
-) {
+) -> Result<(), CodecError> {
// Reserve space to later write length
let len_offset = bytes.len();
- 0u16.encode(bytes);
+ 0u16.encode(bytes)?;
for item in items {
- item.encode_with_param(encoding_parameter, bytes);
+ item.encode_with_param(encoding_parameter, bytes)?;
}
- let len = bytes.len() - len_offset - 2;
- assert!(len <= usize::from(u16::MAX));
- for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() {
- bytes[len_offset + offset] = *byte;
- }
+ let len = u16::try_from(bytes.len() - len_offset - 2)
+ .map_err(|_| CodecError::LengthPrefixOverflow)?;
+ bytes[len_offset..len_offset + 2].copy_from_slice(&len.to_be_bytes());
+ Ok(())
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
@@ -330,20 +350,22 @@ pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
bytes: &mut Vec<u8>,
encoding_parameter: &P,
items: &[E],
-) {
+) -> Result<(), CodecError> {
// Reserve space to later write length
let len_offset = bytes.len();
- U24(0).encode(bytes);
+ U24(0).encode(bytes)?;
for item in items {
- item.encode_with_param(encoding_parameter, bytes);
+ item.encode_with_param(encoding_parameter, bytes)?;
}
- let len = bytes.len() - len_offset - 3;
- assert!(len <= 0xffffff);
- for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() {
- bytes[len_offset + offset] = *byte;
+ let len = u32::try_from(bytes.len() - len_offset - 3)
+ .map_err(|_| CodecError::LengthPrefixOverflow)?;
+ if len > 0xffffff {
+ return Err(CodecError::LengthPrefixOverflow);
}
+ bytes[len_offset..len_offset + 3].copy_from_slice(&len.to_be_bytes()[1..]);
+ Ok(())
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
@@ -368,20 +390,19 @@ pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
bytes: &mut Vec<u8>,
encoding_parameter: &P,
items: &[E],
-) {
+) -> Result<(), CodecError> {
// Reserve space to later write length
let len_offset = bytes.len();
- 0u32.encode(bytes);
+ 0u32.encode(bytes)?;
for item in items {
- item.encode_with_param(encoding_parameter, bytes);
+ item.encode_with_param(encoding_parameter, bytes)?;
}
- let len = bytes.len() - len_offset - 4;
- let len: u32 = len.try_into().expect("Length too large");
- for (offset, byte) in len.to_be_bytes().iter().enumerate() {
- bytes[len_offset + offset] = *byte;
- }
+ let len = u32::try_from(bytes.len() - len_offset - 4)
+ .map_err(|_| CodecError::LengthPrefixOverflow)?;
+ bytes[len_offset..len_offset + 4].copy_from_slice(&len.to_be_bytes());
+ Ok(())
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
@@ -432,6 +453,7 @@ fn decode_items<P, D: ParameterizedDecode<P>>(
#[cfg(test)]
mod tests {
+ use std::io::ErrorKind;
use super::*;
use assert_matches::assert_matches;
@@ -439,7 +461,7 @@ mod tests {
#[test]
fn encode_nothing() {
let mut bytes = vec![];
- ().encode(&mut bytes);
+ ().encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 0);
}
@@ -448,7 +470,7 @@ mod tests {
let value = 100u8;
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 1);
let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
@@ -460,7 +482,7 @@ mod tests {
let value = 1000u16;
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 2);
// Check endianness of encoding
assert_eq!(bytes, vec![3, 232]);
@@ -474,7 +496,7 @@ mod tests {
let value = U24(1_000_000u32);
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 3);
// Check endianness of encoding
assert_eq!(bytes, vec![15, 66, 64]);
@@ -488,7 +510,7 @@ mod tests {
let value = 134_217_728u32;
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 4);
// Check endianness of encoding
assert_eq!(bytes, vec![8, 0, 0, 0]);
@@ -502,7 +524,7 @@ mod tests {
let value = 137_438_953_472u64;
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), 8);
// Check endianness of encoding
assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);
@@ -521,12 +543,12 @@ mod tests {
}
impl Encode for TestMessage {
- fn encode(&self, bytes: &mut Vec<u8>) {
- self.field_u8.encode(bytes);
- self.field_u16.encode(bytes);
- self.field_u24.encode(bytes);
- self.field_u32.encode(bytes);
- self.field_u64.encode(bytes);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ self.field_u8.encode(bytes)?;
+ self.field_u16.encode(bytes)?;
+ self.field_u24.encode(bytes)?;
+ self.field_u32.encode(bytes)?;
+ self.field_u64.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -584,7 +606,7 @@ mod tests {
};
let mut bytes = vec![];
- value.encode(&mut bytes);
+ value.encode(&mut bytes).unwrap();
assert_eq!(bytes.len(), TestMessage::encoded_length());
assert_eq!(value.encoded_len().unwrap(), TestMessage::encoded_length());
@@ -622,7 +644,7 @@ mod tests {
fn roundtrip_variable_length_u8() {
let values = messages_vec();
let mut bytes = vec![];
- encode_u8_items(&mut bytes, &(), &values);
+ encode_u8_items(&mut bytes, &(), &values).unwrap();
assert_eq!(
bytes.len(),
@@ -640,7 +662,7 @@ mod tests {
fn roundtrip_variable_length_u16() {
let values = messages_vec();
let mut bytes = vec![];
- encode_u16_items(&mut bytes, &(), &values);
+ encode_u16_items(&mut bytes, &(), &values).unwrap();
assert_eq!(
bytes.len(),
@@ -661,7 +683,7 @@ mod tests {
fn roundtrip_variable_length_u24() {
let values = messages_vec();
let mut bytes = vec![];
- encode_u24_items(&mut bytes, &(), &values);
+ encode_u24_items(&mut bytes, &(), &values).unwrap();
assert_eq!(
bytes.len(),
@@ -682,7 +704,7 @@ mod tests {
fn roundtrip_variable_length_u32() {
let values = messages_vec();
let mut bytes = Vec::new();
- encode_u32_items(&mut bytes, &(), &values);
+ encode_u32_items(&mut bytes, &(), &values).unwrap();
assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length());
@@ -697,6 +719,21 @@ mod tests {
}
#[test]
+ fn decode_too_short() {
+ let values = messages_vec();
+ let mut bytes = Vec::new();
+ encode_u32_items(&mut bytes, &(), &values).unwrap();
+
+ let error =
+ decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..3])).unwrap_err();
+ assert_matches!(error, CodecError::Io(e) => assert_eq!(e.kind(), ErrorKind::UnexpectedEof));
+
+ let error =
+ decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..4])).unwrap_err();
+ assert_matches!(error, CodecError::LengthPrefixTooBig(_));
+ }
+
+ #[test]
fn decode_items_overflow() {
let encoded = vec![1u8];
@@ -724,11 +761,56 @@ mod tests {
#[test]
fn length_hint_correctness() {
- assert_eq!(().encoded_len().unwrap(), ().get_encoded().len());
- assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().len());
- assert_eq!(0u16.encoded_len().unwrap(), 0u16.get_encoded().len());
- assert_eq!(U24(0).encoded_len().unwrap(), U24(0).get_encoded().len());
- assert_eq!(0u32.encoded_len().unwrap(), 0u32.get_encoded().len());
- assert_eq!(0u64.encoded_len().unwrap(), 0u64.get_encoded().len());
+ assert_eq!(().encoded_len().unwrap(), ().get_encoded().unwrap().len());
+ assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().unwrap().len());
+ assert_eq!(
+ 0u16.encoded_len().unwrap(),
+ 0u16.get_encoded().unwrap().len()
+ );
+ assert_eq!(
+ U24(0).encoded_len().unwrap(),
+ U24(0).get_encoded().unwrap().len()
+ );
+ assert_eq!(
+ 0u32.encoded_len().unwrap(),
+ 0u32.get_encoded().unwrap().len()
+ );
+ assert_eq!(
+ 0u64.encoded_len().unwrap(),
+ 0u64.get_encoded().unwrap().len()
+ );
+ }
+
+ #[test]
+ fn get_decoded_leftover() {
+ let encoded_good = [1, 2, 3, 4];
+ assert_matches!(u32::get_decoded(&encoded_good).unwrap(), 0x01020304u32);
+
+ let encoded_bad = [1, 2, 3, 4, 5];
+ let error = u32::get_decoded(&encoded_bad).unwrap_err();
+ assert_matches!(error, CodecError::BytesLeftOver(1));
+ }
+
+ #[test]
+ fn encoded_len_backwards_compatibility() {
+ struct MyMessage;
+
+ impl Encode for MyMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ bytes.extend_from_slice(b"Hello, world");
+ Ok(())
+ }
+ }
+
+ assert_eq!(MyMessage.encoded_len(), None);
+
+ assert_eq!(MyMessage.get_encoded().unwrap(), b"Hello, world");
+ }
+
+ #[test]
+ fn encode_length_prefix_overflow() {
+ let mut bytes = Vec::new();
+ let error = encode_u8_items(&mut bytes, &(), &[1u8; u8::MAX as usize + 1]).unwrap_err();
+ assert_matches!(error, CodecError::LengthPrefixOverflow);
}
}