From 36d22d82aa202bb199967e9512281e9a53db42c9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 21:33:14 +0200 Subject: Adding upstream version 115.7.0esr. Signed-off-by: Daniel Baumann --- third_party/rust/prost/src/encoding.rs | 1770 ++++++++++++++++++++++++++++++++ 1 file changed, 1770 insertions(+) create mode 100644 third_party/rust/prost/src/encoding.rs (limited to 'third_party/rust/prost/src/encoding.rs') diff --git a/third_party/rust/prost/src/encoding.rs b/third_party/rust/prost/src/encoding.rs new file mode 100644 index 0000000000..252358685c --- /dev/null +++ b/third_party/rust/prost/src/encoding.rs @@ -0,0 +1,1770 @@ +//! Utility functions and types for encoding and decoding Protobuf types. +//! +//! Meant to be used only from `Message` implementations. + +#![allow(clippy::implicit_hasher, clippy::ptr_arg)] + +use alloc::collections::BTreeMap; +use alloc::format; +use alloc::string::String; +use alloc::vec::Vec; +use core::cmp::min; +use core::convert::TryFrom; +use core::mem; +use core::str; +use core::u32; +use core::usize; + +use ::bytes::{Buf, BufMut, Bytes}; + +use crate::DecodeError; +use crate::Message; + +/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. +/// The buffer must have enough remaining space (maximum 10 bytes). +#[inline] +pub fn encode_varint(mut value: u64, buf: &mut B) +where + B: BufMut, +{ + // Safety notes: + // + // - ptr::write is an unsafe raw pointer write. The use here is safe since the length of the + // uninit slice is checked. + // - advance_mut is unsafe because it could cause uninitialized memory to be advanced over. The + // use here is safe since each byte which is advanced over has been written to in the + // previous loop iteration. + unsafe { + let mut i; + 'outer: loop { + i = 0; + + let uninit_slice = buf.chunk_mut(); + for offset in 0..uninit_slice.len() { + i += 1; + let ptr = uninit_slice.as_mut_ptr().add(offset); + if value < 0x80 { + ptr.write(value as u8); + break 'outer; + } else { + ptr.write(((value & 0x7F) | 0x80) as u8); + value >>= 7; + } + } + + buf.advance_mut(i); + debug_assert!(buf.has_remaining_mut()); + } + + buf.advance_mut(i); + } +} + +/// Decodes a LEB128-encoded variable length integer from the buffer. +pub fn decode_varint(buf: &mut B) -> Result +where + B: Buf, +{ + let bytes = buf.chunk(); + let len = bytes.len(); + if len == 0 { + return Err(DecodeError::new("invalid varint")); + } + + let byte = unsafe { *bytes.get_unchecked(0) }; + if byte < 0x80 { + buf.advance(1); + Ok(u64::from(byte)) + } else if len > 10 || bytes[len - 1] < 0x80 { + let (value, advance) = unsafe { decode_varint_slice(bytes) }?; + buf.advance(advance); + Ok(value) + } else { + decode_varint_slow(buf) + } +} + +/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the +/// number of bytes read. +/// +/// Based loosely on [`ReadVarint64FromArray`][1]. +/// +/// ## Safety +/// +/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last +/// element in bytes is < `0x80`. +/// +/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 +#[inline] +unsafe fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> { + // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance. + + let mut b: u8; + let mut part0: u32; + b = *bytes.get_unchecked(0); + part0 = u32::from(b); + if b < 0x80 { + return Ok((u64::from(part0), 1)); + }; + part0 -= 0x80; + b = *bytes.get_unchecked(1); + part0 += u32::from(b) << 7; + if b < 0x80 { + return Ok((u64::from(part0), 2)); + }; + part0 -= 0x80 << 7; + b = *bytes.get_unchecked(2); + part0 += u32::from(b) << 14; + if b < 0x80 { + return Ok((u64::from(part0), 3)); + }; + part0 -= 0x80 << 14; + b = *bytes.get_unchecked(3); + part0 += u32::from(b) << 21; + if b < 0x80 { + return Ok((u64::from(part0), 4)); + }; + part0 -= 0x80 << 21; + let value = u64::from(part0); + + let mut part1: u32; + b = *bytes.get_unchecked(4); + part1 = u32::from(b); + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 5)); + }; + part1 -= 0x80; + b = *bytes.get_unchecked(5); + part1 += u32::from(b) << 7; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 6)); + }; + part1 -= 0x80 << 7; + b = *bytes.get_unchecked(6); + part1 += u32::from(b) << 14; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 7)); + }; + part1 -= 0x80 << 14; + b = *bytes.get_unchecked(7); + part1 += u32::from(b) << 21; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 8)); + }; + part1 -= 0x80 << 21; + let value = value + ((u64::from(part1)) << 28); + + let mut part2: u32; + b = *bytes.get_unchecked(8); + part2 = u32::from(b); + if b < 0x80 { + return Ok((value + (u64::from(part2) << 56), 9)); + }; + part2 -= 0x80; + b = *bytes.get_unchecked(9); + part2 += u32::from(b) << 7; + if b < 0x80 { + return Ok((value + (u64::from(part2) << 56), 10)); + }; + + // We have overrun the maximum size of a varint (10 bytes). Assume the data is corrupt. + Err(DecodeError::new("invalid varint")) +} + +/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as +/// necessary. +#[inline(never)] +fn decode_varint_slow(buf: &mut B) -> Result +where + B: Buf, +{ + let mut value = 0; + for count in 0..min(10, buf.remaining()) { + let byte = buf.get_u8(); + value |= u64::from(byte & 0x7F) << (count * 7); + if byte <= 0x7F { + return Ok(value); + } + } + + Err(DecodeError::new("invalid varint")) +} + +/// Additional information passed to every decode/merge function. +/// +/// The context should be passed by value and can be freely cloned. When passing +/// to a function which is decoding a nested object, then use `enter_recursion`. +#[derive(Clone, Debug)] +pub struct DecodeContext { + /// How many times we can recurse in the current decode stack before we hit + /// the recursion limit. + /// + /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be + /// customized. The recursion limit can be ignored by building the Prost + /// crate with the `no-recursion-limit` feature. + #[cfg(not(feature = "no-recursion-limit"))] + recurse_count: u32, +} + +impl Default for DecodeContext { + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + fn default() -> DecodeContext { + DecodeContext { + recurse_count: crate::RECURSION_LIMIT, + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + fn default() -> DecodeContext { + DecodeContext {} + } +} + +impl DecodeContext { + /// Call this function before recursively decoding. + /// + /// There is no `exit` function since this function creates a new `DecodeContext` + /// to be used at the next level of recursion. Continue to use the old context + // at the previous level of recursion. + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { + DecodeContext { + recurse_count: self.recurse_count - 1, + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { + DecodeContext {} + } + + /// Checks whether the recursion limit has been reached in the stack of + /// decodes described by the `DecodeContext` at `self.ctx`. + /// + /// Returns `Ok<()>` if it is ok to continue recursing. + /// Returns `Err` if the recursion limit has been reached. + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { + if self.recurse_count == 0 { + Err(DecodeError::new("recursion limit reached")) + } else { + Ok(()) + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + #[allow(clippy::unnecessary_wraps)] // needed in other features + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { + Ok(()) + } +} + +/// Returns the encoded length of the value in LEB128 variable length format. +/// The returned value will be between 1 and 10, inclusive. +#[inline] +pub fn encoded_len_varint(value: u64) -> usize { + // Based on [VarintSize64][1]. + // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309 + ((((value | 1).leading_zeros() ^ 63) * 9 + 73) / 64) as usize +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum WireType { + Varint = 0, + SixtyFourBit = 1, + LengthDelimited = 2, + StartGroup = 3, + EndGroup = 4, + ThirtyTwoBit = 5, +} + +pub const MIN_TAG: u32 = 1; +pub const MAX_TAG: u32 = (1 << 29) - 1; + +impl TryFrom for WireType { + type Error = DecodeError; + + #[inline] + fn try_from(value: u64) -> Result { + match value { + 0 => Ok(WireType::Varint), + 1 => Ok(WireType::SixtyFourBit), + 2 => Ok(WireType::LengthDelimited), + 3 => Ok(WireType::StartGroup), + 4 => Ok(WireType::EndGroup), + 5 => Ok(WireType::ThirtyTwoBit), + _ => Err(DecodeError::new(format!( + "invalid wire type value: {}", + value + ))), + } + } +} + +/// Encodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline] +pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut B) +where + B: BufMut, +{ + debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag)); + let key = (tag << 3) | wire_type as u32; + encode_varint(u64::from(key), buf); +} + +/// Decodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline(always)] +pub fn decode_key(buf: &mut B) -> Result<(u32, WireType), DecodeError> +where + B: Buf, +{ + let key = decode_varint(buf)?; + if key > u64::from(u32::MAX) { + return Err(DecodeError::new(format!("invalid key value: {}", key))); + } + let wire_type = WireType::try_from(key & 0x07)?; + let tag = key as u32 >> 3; + + if tag < MIN_TAG { + return Err(DecodeError::new("invalid tag value: 0")); + } + + Ok((tag, wire_type)) +} + +/// Returns the width of an encoded Protobuf field key with the given tag. +/// The returned width will be between 1 and 5 bytes (inclusive). +#[inline] +pub fn key_len(tag: u32) -> usize { + encoded_len_varint(u64::from(tag << 3)) +} + +/// Checks that the expected wire type matches the actual wire type, +/// or returns an error result. +#[inline] +pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> { + if expected != actual { + return Err(DecodeError::new(format!( + "invalid wire type: {:?} (expected {:?})", + actual, expected + ))); + } + Ok(()) +} + +/// Helper function which abstracts reading a length delimiter prefix followed +/// by decoding values until the length of bytes is exhausted. +pub fn merge_loop( + value: &mut T, + buf: &mut B, + ctx: DecodeContext, + mut merge: M, +) -> Result<(), DecodeError> +where + M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>, + B: Buf, +{ + let len = decode_varint(buf)?; + let remaining = buf.remaining(); + if len > remaining as u64 { + return Err(DecodeError::new("buffer underflow")); + } + + let limit = remaining - len as usize; + while buf.remaining() > limit { + merge(value, buf, ctx.clone())?; + } + + if buf.remaining() != limit { + return Err(DecodeError::new("delimited length exceeded")); + } + Ok(()) +} + +pub fn skip_field( + wire_type: WireType, + tag: u32, + buf: &mut B, + ctx: DecodeContext, +) -> Result<(), DecodeError> +where + B: Buf, +{ + ctx.limit_reached()?; + let len = match wire_type { + WireType::Varint => decode_varint(buf).map(|_| 0)?, + WireType::ThirtyTwoBit => 4, + WireType::SixtyFourBit => 8, + WireType::LengthDelimited => decode_varint(buf)?, + WireType::StartGroup => loop { + let (inner_tag, inner_wire_type) = decode_key(buf)?; + match inner_wire_type { + WireType::EndGroup => { + if inner_tag != tag { + return Err(DecodeError::new("unexpected end group tag")); + } + break 0; + } + _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?, + } + }, + WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")), + }; + + if len > buf.remaining() as u64 { + return Err(DecodeError::new("buffer underflow")); + } + + buf.advance(len as usize); + Ok(()) +} + +/// Helper macro which emits an `encode_repeated` function for the type. +macro_rules! encode_repeated { + ($ty:ty) => { + pub fn encode_repeated(tag: u32, values: &[$ty], buf: &mut B) + where + B: BufMut, + { + for value in values { + encode(tag, value, buf); + } + } + }; +} + +/// Helper macro which emits a `merge_repeated` function for the numeric type. +macro_rules! merge_repeated_numeric { + ($ty:ty, + $wire_type:expr, + $merge:ident, + $merge_repeated:ident) => { + pub fn $merge_repeated( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if wire_type == WireType::LengthDelimited { + // Packed. + merge_loop(values, buf, ctx, |values, buf, ctx| { + let mut value = Default::default(); + $merge($wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + }) + } else { + // Unpacked. + check_wire_type($wire_type, wire_type)?; + let mut value = Default::default(); + $merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + } + }; +} + +/// Macro which emits a module containing a set of encoding functions for a +/// variable width numeric type. +macro_rules! varint { + ($ty:ty, + $proto_ty:ident) => ( + varint!($ty, + $proto_ty, + to_uint64(value) { *value as u64 }, + from_uint64(value) { value as $ty }); + ); + + ($ty:ty, + $proto_ty:ident, + to_uint64($to_uint64_value:ident) $to_uint64:expr, + from_uint64($from_uint64_value:ident) $from_uint64:expr) => ( + + pub mod $proto_ty { + use crate::encoding::*; + + pub fn encode(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut { + encode_key(tag, WireType::Varint, buf); + encode_varint($to_uint64, buf); + } + + pub fn merge(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf { + check_wire_type(WireType::Varint, wire_type)?; + let $from_uint64_value = decode_varint(buf)?; + *value = $from_uint64; + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut { + if values.is_empty() { return; } + + encode_key(tag, WireType::LengthDelimited, buf); + let len: usize = values.iter().map(|$to_uint64_value| { + encoded_len_varint($to_uint64) + }).sum(); + encode_varint(len as u64, buf); + + for $to_uint64_value in values { + encode_varint($to_uint64, buf); + } + } + + merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated); + + #[inline] + pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize { + key_len(tag) + encoded_len_varint($to_uint64) + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| { + encoded_len_varint($to_uint64) + }).sum::() + } + + #[inline] + pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = values.iter() + .map(|$to_uint64_value| encoded_len_varint($to_uint64)) + .sum::(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + } + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use crate::encoding::$proto_ty::*; + use crate::encoding::test::{ + check_collection_type, + check_type, + }; + + proptest! { + #[test] + fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::Varint, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(value, tag, WireType::Varint, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + #[test] + fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::LengthDelimited, + encode_packed, merge_repeated, + encoded_len_packed)?; + } + } + } + } + + ); +} +varint!(bool, bool, + to_uint64(value) if *value { 1u64 } else { 0u64 }, + from_uint64(value) value != 0); +varint!(i32, int32); +varint!(i64, int64); +varint!(u32, uint32); +varint!(u64, uint64); +varint!(i32, sint32, +to_uint64(value) { + ((value << 1) ^ (value >> 31)) as u32 as u64 +}, +from_uint64(value) { + let value = value as u32; + ((value >> 1) as i32) ^ (-((value & 1) as i32)) +}); +varint!(i64, sint64, +to_uint64(value) { + ((value << 1) ^ (value >> 63)) as u64 +}, +from_uint64(value) { + ((value >> 1) as i64) ^ (-((value & 1) as i64)) +}); + +/// Macro which emits a module containing a set of encoding functions for a +/// fixed width numeric type. +macro_rules! fixed_width { + ($ty:ty, + $width:expr, + $wire_type:expr, + $proto_ty:ident, + $put:ident, + $get:ident) => { + pub mod $proto_ty { + use crate::encoding::*; + + pub fn encode(tag: u32, value: &$ty, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, $wire_type, buf); + buf.$put(*value); + } + + pub fn merge( + wire_type: WireType, + value: &mut $ty, + buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + check_wire_type($wire_type, wire_type)?; + if buf.remaining() < $width { + return Err(DecodeError::new("buffer underflow")); + } + *value = buf.$get(); + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut B) + where + B: BufMut, + { + if values.is_empty() { + return; + } + + encode_key(tag, WireType::LengthDelimited, buf); + let len = values.len() as u64 * $width; + encode_varint(len as u64, buf); + + for value in values { + buf.$put(*value); + } + } + + merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated); + + #[inline] + pub fn encoded_len(tag: u32, _: &$ty) -> usize { + key_len(tag) + $width + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + (key_len(tag) + $width) * values.len() + } + + #[inline] + pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = $width * values.len(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + } + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, $wire_type, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(value, tag, $wire_type, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + #[test] + fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::LengthDelimited, + encode_packed, merge_repeated, + encoded_len_packed)?; + } + } + } + } + }; +} +fixed_width!( + f32, + 4, + WireType::ThirtyTwoBit, + float, + put_f32_le, + get_f32_le +); +fixed_width!( + f64, + 8, + WireType::SixtyFourBit, + double, + put_f64_le, + get_f64_le +); +fixed_width!( + u32, + 4, + WireType::ThirtyTwoBit, + fixed32, + put_u32_le, + get_u32_le +); +fixed_width!( + u64, + 8, + WireType::SixtyFourBit, + fixed64, + put_u64_le, + get_u64_le +); +fixed_width!( + i32, + 4, + WireType::ThirtyTwoBit, + sfixed32, + put_i32_le, + get_i32_le +); +fixed_width!( + i64, + 8, + WireType::SixtyFourBit, + sfixed64, + put_i64_le, + get_i64_le +); + +/// Macro which emits encoding functions for a length-delimited type. +macro_rules! length_delimited { + ($ty:ty) => { + encode_repeated!($ty); + + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut value = Default::default(); + merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, value: &$ty) -> usize { + key_len(tag) + encoded_len_varint(value.len() as u64) + value.len() + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + key_len(tag) * values.len() + + values + .iter() + .map(|value| encoded_len_varint(value.len() as u64) + value.len()) + .sum::() + } + }; +} + +pub mod string { + use super::*; + + pub fn encode(tag: u32, value: &String, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + buf.put_slice(value.as_bytes()); + } + pub fn merge( + wire_type: WireType, + value: &mut String, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + // ## Unsafety + // + // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 + // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the + // string is cleared, so as to avoid leaking a string field with invalid data. + // + // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe + // alternative of temporarily swapping an empty `String` into the field, because it results + // in up to 10% better performance on the protobuf message decoding benchmarks. + // + // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into + // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or + // in the buf implementation, a drop guard is used. + unsafe { + struct DropGuard<'a>(&'a mut Vec); + impl<'a> Drop for DropGuard<'a> { + #[inline] + fn drop(&mut self) { + self.0.clear(); + } + } + + let drop_guard = DropGuard(value.as_mut_vec()); + bytes::merge(wire_type, drop_guard.0, buf, ctx)?; + match str::from_utf8(drop_guard.0) { + Ok(_) => { + // Success; do not clear the bytes. + mem::forget(drop_guard); + Ok(()) + } + Err(_) => Err(DecodeError::new( + "invalid string value: data is not UTF-8 encoded", + )), + } + } + } + + length_delimited!(String); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check(value: String, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub trait BytesAdapter: sealed::BytesAdapter {} + +mod sealed { + use super::{Buf, BufMut}; + + pub trait BytesAdapter: Default + Sized + 'static { + fn len(&self) -> usize; + + /// Replace contents of this buffer with the contents of another buffer. + fn replace_with(&mut self, buf: B) + where + B: Buf; + + /// Appends this buffer to the (contents of) other buffer. + fn append_to(&self, buf: &mut B) + where + B: BufMut; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + } +} + +impl BytesAdapter for Bytes {} + +impl sealed::BytesAdapter for Bytes { + fn len(&self) -> usize { + Buf::remaining(self) + } + + fn replace_with(&mut self, mut buf: B) + where + B: Buf, + { + *self = buf.copy_to_bytes(buf.remaining()); + } + + fn append_to(&self, buf: &mut B) + where + B: BufMut, + { + buf.put(self.clone()) + } +} + +impl BytesAdapter for Vec {} + +impl sealed::BytesAdapter for Vec { + fn len(&self) -> usize { + Vec::len(self) + } + + fn replace_with(&mut self, buf: B) + where + B: Buf, + { + self.clear(); + self.reserve(buf.remaining()); + self.put(buf); + } + + fn append_to(&self, buf: &mut B) + where + B: BufMut, + { + buf.put(self.as_slice()) + } +} + +pub mod bytes { + use super::*; + + pub fn encode(tag: u32, value: &A, buf: &mut B) + where + A: BytesAdapter, + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + value.append_to(buf); + } + + pub fn merge( + wire_type: WireType, + value: &mut A, + buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + A: BytesAdapter, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let len = decode_varint(buf)?; + if len > buf.remaining() as u64 { + return Err(DecodeError::new("buffer underflow")); + } + let len = len as usize; + + // Clear the existing value. This follows from the following rule in the encoding guide[1]: + // + // > Normally, an encoded message would never have more than one instance of a non-repeated + // > field. However, parsers are expected to handle the case in which they do. For numeric + // > types and strings, if the same field appears multiple times, the parser accepts the + // > last value it sees. + // + // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional + value.replace_with(buf.copy_to_bytes(len)); + Ok(()) + } + + length_delimited!(impl BytesAdapter); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check_vec(value: Vec, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type::, Vec>(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_bytes(value: Vec, tag in MIN_TAG..=MAX_TAG) { + let value = Bytes::from(value); + super::test::check_type::(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_repeated_vec(value: Vec>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + + #[test] + fn check_repeated_bytes(value: Vec>, tag in MIN_TAG..=MAX_TAG) { + let value = value.into_iter().map(Bytes::from).collect(); + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub mod message { + use super::*; + + pub fn encode(tag: u32, msg: &M, buf: &mut B) + where + M: Message, + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(msg.encoded_len() as u64, buf); + msg.encode_raw(buf); + } + + pub fn merge( + wire_type: WireType, + msg: &mut M, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + ctx.limit_reached()?; + merge_loop( + msg, + buf, + ctx.enter_recursion(), + |msg: &mut M, buf: &mut B, ctx| { + let (tag, wire_type) = decode_key(buf)?; + msg.merge_field(tag, wire_type, buf, ctx) + }, + ) + } + + pub fn encode_repeated(tag: u32, messages: &[M], buf: &mut B) + where + M: Message, + B: BufMut, + { + for msg in messages { + encode(tag, msg, buf); + } + } + + pub fn merge_repeated( + wire_type: WireType, + messages: &mut Vec, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut msg = M::default(); + merge(WireType::LengthDelimited, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, msg: &M) -> usize + where + M: Message, + { + let len = msg.encoded_len(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, messages: &[M]) -> usize + where + M: Message, + { + key_len(tag) * messages.len() + + messages + .iter() + .map(Message::encoded_len) + .map(|len| len + encoded_len_varint(len as u64)) + .sum::() + } +} + +pub mod group { + use super::*; + + pub fn encode(tag: u32, msg: &M, buf: &mut B) + where + M: Message, + B: BufMut, + { + encode_key(tag, WireType::StartGroup, buf); + msg.encode_raw(buf); + encode_key(tag, WireType::EndGroup, buf); + } + + pub fn merge( + tag: u32, + wire_type: WireType, + msg: &mut M, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + B: Buf, + { + check_wire_type(WireType::StartGroup, wire_type)?; + + ctx.limit_reached()?; + loop { + let (field_tag, field_wire_type) = decode_key(buf)?; + if field_wire_type == WireType::EndGroup { + if field_tag != tag { + return Err(DecodeError::new("unexpected end group tag")); + } + return Ok(()); + } + + M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?; + } + } + + pub fn encode_repeated(tag: u32, messages: &[M], buf: &mut B) + where + M: Message, + B: BufMut, + { + for msg in messages { + encode(tag, msg, buf); + } + } + + pub fn merge_repeated( + tag: u32, + wire_type: WireType, + messages: &mut Vec, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + B: Buf, + { + check_wire_type(WireType::StartGroup, wire_type)?; + let mut msg = M::default(); + merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, msg: &M) -> usize + where + M: Message, + { + 2 * key_len(tag) + msg.encoded_len() + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, messages: &[M]) -> usize + where + M: Message, + { + 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::() + } +} + +/// Rust doesn't have a `Map` trait, so macros are currently the best way to be +/// generic over `HashMap` and `BTreeMap`. +macro_rules! map { + ($map_ty:ident) => { + use crate::encoding::*; + use core::hash::Hash; + + /// Generic protobuf map encode function. + pub fn encode( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + tag: u32, + values: &$map_ty, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + encode_with_default( + key_encode, + key_encoded_len, + val_encode, + val_encoded_len, + &V::default(), + tag, + values, + buf, + ) + } + + /// Generic protobuf map merge function. + pub fn merge( + key_merge: KM, + val_merge: VM, + values: &mut $map_ty, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + V: Default, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) + } + + /// Generic protobuf map encode function. + pub fn encoded_len( + key_encoded_len: KL, + val_encoded_len: VL, + tag: u32, + values: &$map_ty, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values) + } + + /// Generic protobuf map encode function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encode_with_default( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &$map_ty, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + for (key, val) in values.iter() { + let skip_key = key == &K::default(); + let skip_val = val == val_default; + + let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) + + (if skip_val { 0 } else { val_encoded_len(2, val) }); + + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(len as u64, buf); + if !skip_key { + key_encode(1, key, buf); + } + if !skip_val { + val_encode(2, val, buf); + } + } + } + + /// Generic protobuf map merge function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn merge_with_default( + key_merge: KM, + val_merge: VM, + val_default: V, + values: &mut $map_ty, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + let mut key = Default::default(); + let mut val = val_default; + ctx.limit_reached()?; + merge_loop( + &mut (&mut key, &mut val), + buf, + ctx.enter_recursion(), + |&mut (ref mut key, ref mut val), buf, ctx| { + let (tag, wire_type) = decode_key(buf)?; + match tag { + 1 => key_merge(wire_type, key, buf, ctx), + 2 => val_merge(wire_type, val, buf, ctx), + _ => skip_field(wire_type, tag, buf, ctx), + } + }, + )?; + values.insert(key, val); + + Ok(()) + } + + /// Generic protobuf map encode function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encoded_len_with_default( + key_encoded_len: KL, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &$map_ty, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + key_len(tag) * values.len() + + values + .iter() + .map(|(key, val)| { + let len = (if key == &K::default() { + 0 + } else { + key_encoded_len(1, key) + }) + (if val == val_default { + 0 + } else { + val_encoded_len(2, val) + }); + encoded_len_varint(len as u64) + len + }) + .sum::() + } + }; +} + +#[cfg(feature = "std")] +pub mod hash_map { + use std::collections::HashMap; + map!(HashMap); +} + +pub mod btree_map { + map!(BTreeMap); +} + +#[cfg(test)] +mod test { + use alloc::string::ToString; + use core::borrow::Borrow; + use core::fmt::Debug; + use core::u64; + + use ::bytes::{Bytes, BytesMut}; + use proptest::{prelude::*, test_runner::TestCaseResult}; + + use crate::encoding::*; + + pub fn check_type( + value: T, + tag: u32, + wire_type: WireType, + encode: fn(u32, &B, &mut BytesMut), + merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + encoded_len: fn(u32, &B) -> usize, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow, + B: ?Sized, + { + prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + if !buf.has_remaining() { + // Short circuit for empty packed values. + return Ok(()); + } + + let (decoded_tag, decoded_wire_type) = + decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + prop_assert_eq!( + tag, + decoded_tag, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_tag + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type, + ); + + match wire_type { + WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!( + "64bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!( + "32bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + _ => Ok(()), + }?; + + let mut roundtrip_value = T::default(); + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert!( + !buf.has_remaining(), + "expected buffer to be empty, remaining: {}", + buf.remaining() + ); + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + pub fn check_collection_type( + value: T, + tag: u32, + wire_type: WireType, + encode: E, + mut merge: M, + encoded_len: L, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow, + B: ?Sized, + E: FnOnce(u32, &B, &mut BytesMut), + M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + L: FnOnce(u32, &B) -> usize, + { + prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + let mut roundtrip_value = Default::default(); + while buf.has_remaining() { + let (decoded_tag, decoded_wire_type) = + decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert_eq!( + tag, + decoded_tag, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_tag + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type + ); + + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + } + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + #[test] + fn string_merge_invalid_utf8() { + let mut s = String::new(); + let buf = b"\x02\x80\x80"; + + let r = string::merge( + WireType::LengthDelimited, + &mut s, + &mut &buf[..], + DecodeContext::default(), + ); + r.expect_err("must be an error"); + assert!(s.is_empty()); + } + + #[test] + fn varint() { + fn check(value: u64, mut encoded: &[u8]) { + // TODO(rust-lang/rust-clippy#5494) + #![allow(clippy::clone_double_ref)] + + // Small buffer. + let mut buf = Vec::with_capacity(1); + encode_varint(value, &mut buf); + assert_eq!(buf, encoded); + + // Large buffer. + let mut buf = Vec::with_capacity(100); + encode_varint(value, &mut buf); + assert_eq!(buf, encoded); + + assert_eq!(encoded_len_varint(value), encoded.len()); + + let roundtrip_value = decode_varint(&mut encoded.clone()).expect("decoding failed"); + assert_eq!(value, roundtrip_value); + + let roundtrip_value = decode_varint_slow(&mut encoded).expect("slow decoding failed"); + assert_eq!(value, roundtrip_value); + } + + check(2u64.pow(0) - 1, &[0x00]); + check(2u64.pow(0), &[0x01]); + + check(2u64.pow(7) - 1, &[0x7F]); + check(2u64.pow(7), &[0x80, 0x01]); + check(300, &[0xAC, 0x02]); + + check(2u64.pow(14) - 1, &[0xFF, 0x7F]); + check(2u64.pow(14), &[0x80, 0x80, 0x01]); + + check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]); + check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); + + check( + 2u64.pow(49) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(49), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + 2u64.pow(56) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(56), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + 2u64.pow(63) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(63), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + u64::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], + ); + } + + /// This big bowl o' macro soup generates an encoding property test for each combination of map + /// type, scalar map key, and value type. + /// TODO: these tests take a long time to compile, can this be improved? + #[cfg(feature = "std")] + macro_rules! map_tests { + (keys: $keys:tt, + vals: $vals:tt) => { + mod hash_map { + map_tests!(@private HashMap, hash_map, $keys, $vals); + } + mod btree_map { + map_tests!(@private BTreeMap, btree_map, $keys, $vals); + } + }; + + (@private $map_type:ident, + $mod_name:ident, + [$(($key_ty:ty, $key_proto:ident)),*], + $vals:tt) => { + $( + mod $key_proto { + use std::collections::$map_type; + + use proptest::prelude::*; + + use crate::encoding::*; + use crate::encoding::test::check_collection_type; + + map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals); + } + )* + }; + + (@private $map_type:ident, + $mod_name:ident, + ($key_ty:ty, $key_proto:ident), + [$(($val_ty:ty, $val_proto:ident)),*]) => { + $( + proptest! { + #[test] + fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(values, tag, WireType::LengthDelimited, + |tag, values, buf| { + $mod_name::encode($key_proto::encode, + $key_proto::encoded_len, + $val_proto::encode, + $val_proto::encoded_len, + tag, + values, + buf) + }, + |wire_type, values, buf, ctx| { + check_wire_type(WireType::LengthDelimited, wire_type)?; + $mod_name::merge($key_proto::merge, + $val_proto::merge, + values, + buf, + ctx) + }, + |tag, values| { + $mod_name::encoded_len($key_proto::encoded_len, + $val_proto::encoded_len, + tag, + values) + })?; + } + } + )* + }; + } + + #[cfg(feature = "std")] + map_tests!(keys: [ + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string) + ], + vals: [ + (f32, float), + (f64, double), + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string), + (Vec, bytes) + ]); +} -- cgit v1.2.3