summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/codec.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/codec.rs')
-rw-r--r--third_party/rust/prio/src/codec.rs658
1 files changed, 658 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs
new file mode 100644
index 0000000000..c3ee8e9db7
--- /dev/null
+++ b/third_party/rust/prio/src/codec.rs
@@ -0,0 +1,658 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Module `codec` provides support for encoding and decoding messages to or from the TLS wire
+//! encoding, as specified in [RFC 8446, Section 3][1]. It provides traits that can be implemented
+//! on values that need to be encoded or decoded, as well as utility functions for encoding
+//! sequences of values.
+//!
+//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3
+
+use byteorder::{BigEndian, ReadBytesExt};
+use std::{
+ convert::TryInto,
+ error::Error,
+ io::{Cursor, Read},
+ mem::size_of,
+ num::TryFromIntError,
+};
+
+#[allow(missing_docs)]
+#[derive(Debug, thiserror::Error)]
+pub enum CodecError {
+ #[error("I/O error")]
+ Io(#[from] std::io::Error),
+ #[error("{0} bytes left in buffer after decoding value")]
+ BytesLeftOver(usize),
+ #[error("length prefix of encoded vector overflows buffer: {0}")]
+ LengthPrefixTooBig(usize),
+ #[error("other error: {0}")]
+ Other(#[source] Box<dyn Error + 'static + Send + Sync>),
+ #[error("unexpected value")]
+ UnexpectedValue,
+}
+
+/// Describes how to decode an object from a byte sequence.
+pub trait Decode: Sized {
+ /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned
+ /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned
+ /// and no further attempt to read from `bytes` should be made.
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;
+
+ /// Convenience method to get decoded value. Returns an error if [`Self::decode`] fails, or if
+ /// there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
+ Self::get_decoded_with_param(&(), bytes)
+ }
+}
+
+/// Describes how to decode an object from a byte sequence, with a decoding parameter provided to
+/// provide additional data used in decoding.
+pub trait ParameterizedDecode<P>: Sized {
+ /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the
+ /// wire encoding such as lengths of different portions of the message. On success, the decoded
+ /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an
+ /// error is returned and no further attempt to read from `bytes` should be made.
+ fn decode_with_param(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError>;
+
+ /// Convenience method to get decoded value. Returns an error if [`Self::decode_with_param`]
+ /// fails, or if there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
+ let mut cursor = Cursor::new(bytes);
+ let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
+ if cursor.position() as usize != bytes.len() {
+ return Err(CodecError::BytesLeftOver(
+ bytes.len() - cursor.position() as usize,
+ ));
+ }
+
+ Ok(decoded)
+ }
+}
+
+// Provide a blanket implementation so that any Decode can be used as a ParameterizedDecode<T> for
+// any T.
+impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D {
+ fn decode_with_param(
+ _decoding_parameter: &T,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Self::decode(bytes)
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait Encode {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>);
+
+ /// Convenience method to get encoded value.
+ fn get_encoded(&self) -> Vec<u8> {
+ self.get_encoded_with_param(&())
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait ParameterizedEncode<P> {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ /// `encoding_parameter` provides details of the wire encoding, used to control how the value
+ /// is encoded.
+ fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>);
+
+ /// Convenience method to get encoded value.
+ fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> {
+ let mut ret = Vec::new();
+ self.encode_with_param(encoding_parameter, &mut ret);
+ ret
+ }
+}
+
+// Provide a blanket implementation so that any Encode can be used as a ParameterizedEncode<T> for
+// any T.
+impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
+ fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) {
+ self.encode(bytes)
+ }
+}
+
+impl Decode for () {
+ fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(())
+ }
+}
+
+impl Encode for () {
+ fn encode(&self, _bytes: &mut Vec<u8>) {}
+}
+
+impl Decode for u8 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; size_of::<u8>()];
+ bytes.read_exact(&mut value)?;
+ Ok(value[0])
+ }
+}
+
+impl Encode for u8 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.push(*self);
+ }
+}
+
+impl Decode for u16 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u16::<BigEndian>()?)
+ }
+}
+
+impl Encode for u16 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u16::to_be_bytes(*self));
+ }
+}
+
+/// 24 bit integer, per
+/// [RFC 8443, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3)
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+struct U24(pub u32);
+
+impl Decode for U24 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(U24(bytes.read_u24::<BigEndian>()?))
+ }
+}
+
+impl Encode for U24 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Encode lower three bytes of the u32 as u24
+ bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
+ }
+}
+
+impl Decode for u32 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u32::<BigEndian>()?)
+ }
+}
+
+impl Encode for u32 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u32::to_be_bytes(*self));
+ }
+}
+
+impl Decode for u64 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u64::<BigEndian>()?)
+ }
+}
+
+impl Encode for u64 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u64::to_be_bytes(*self));
+ }
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ bytes.push(0);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 1;
+ assert!(len <= u8::MAX.into());
+ bytes[len_offset] = len as u8;
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read one byte to get length of opaque byte vector
+ let length = usize::from(u8::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u16.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 2;
+ assert!(len <= u16::MAX.into());
+ for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read two bytes to get length of opaque byte vector
+ let length = usize::from(u16::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ U24(0).encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 3;
+ assert!(len <= 0xffffff);
+ for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read three bytes to get length of opaque byte vector
+ let length = U24::decode(bytes)?.0 as usize;
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u32.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 4;
+ let len: u32 = len.try_into().expect("Length too large");
+ for (offset, byte) in len.to_be_bytes().iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u32_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read four bytes to get length of opaque byte vector.
+ let len: usize = u32::decode(bytes)?
+ .try_into()
+ .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
+
+ decode_items(len, decoding_parameter, bytes)
+}
+
+/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible.
+fn decode_items<P, D: ParameterizedDecode<P>>(
+ length: usize,
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ let mut decoded = Vec::new();
+ let initial_position = bytes.position() as usize;
+
+ // Create cursor over specified portion of provided cursor to ensure we can't read past length.
+ let inner = bytes.get_ref();
+
+ // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
+ let (items_end, overflowed) = initial_position.overflowing_add(length);
+ if overflowed || items_end > inner.len() {
+ return Err(CodecError::LengthPrefixTooBig(length));
+ }
+
+ let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
+
+ while sub.position() < length as u64 {
+ decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
+ }
+
+ // Advance outer cursor by the amount read in the inner cursor
+ bytes.set_position(initial_position as u64 + sub.position());
+
+ Ok(decoded)
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn encode_nothing() {
+ let mut bytes = vec![];
+ ().encode(&mut bytes);
+ assert_eq!(bytes.len(), 0);
+ }
+
+ #[test]
+ fn roundtrip_u8() {
+ let value = 100u8;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 1);
+
+ let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u16() {
+ let value = 1000u16;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 2);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![3, 232]);
+
+ let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u24() {
+ let value = U24(1_000_000u32);
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 3);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![15, 66, 64]);
+
+ let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u32() {
+ let value = 134_217_728u32;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 4);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![8, 0, 0, 0]);
+
+ let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u64() {
+ let value = 137_438_953_472u64;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 8);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);
+
+ let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct TestMessage {
+ field_u8: u8,
+ field_u16: u16,
+ field_u24: U24,
+ field_u32: u32,
+ field_u64: u64,
+ }
+
+ impl Encode for TestMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.field_u8.encode(bytes);
+ self.field_u16.encode(bytes);
+ self.field_u24.encode(bytes);
+ self.field_u32.encode(bytes);
+ self.field_u64.encode(bytes);
+ }
+ }
+
+ impl Decode for TestMessage {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let field_u8 = u8::decode(bytes)?;
+ let field_u16 = u16::decode(bytes)?;
+ let field_u24 = U24::decode(bytes)?;
+ let field_u32 = u32::decode(bytes)?;
+ let field_u64 = u64::decode(bytes)?;
+
+ Ok(TestMessage {
+ field_u8,
+ field_u16,
+ field_u24,
+ field_u32,
+ field_u64,
+ })
+ }
+ }
+
+ impl TestMessage {
+ fn encoded_length() -> usize {
+ // u8 field
+ 1 +
+ // u16 field
+ 2 +
+ // u24 field
+ 3 +
+ // u32 field
+ 4 +
+ // u64 field
+ 8
+ }
+ }
+
+ #[test]
+ fn roundtrip_message() {
+ let value = TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ };
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), TestMessage::encoded_length());
+
+ let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ fn messages_vec() -> Vec<TestMessage> {
+ vec![
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ ]
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u8() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u8_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 1 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u16() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u16_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 2 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u24() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u24_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 3 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u32() {
+ let values = messages_vec();
+ let mut bytes = Vec::new();
+ encode_u32_items(&mut bytes, &(), &values);
+
+ assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length());
+
+ // Check endianness of encoded length.
+ assert_eq!(
+ bytes[0..4],
+ [0, 0, 0, 3 * TestMessage::encoded_length() as u8]
+ );
+
+ let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn decode_items_overflow() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(usize::MAX)
+ );
+ }
+
+ #[test]
+ fn decode_items_too_big() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(2)
+ );
+ }
+}