summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio2/test_vector.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/test_vector.rs32
1 files changed, 23 insertions, 9 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
index ae2b8b0f9d..114b437b55 100644
--- a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
+++ b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
@@ -48,9 +48,16 @@ mod base64 {
//! instead of an array of an array of integers when serializing to JSON.
//
// Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
- use crate::field::{FieldElement, FieldPrio2};
+ use crate::{
+ codec::ParameterizedDecode,
+ field::{encode_fieldvec, FieldElement, FieldPrio2},
+ vdaf::{Share, ShareDecodingParameter},
+ };
+ use assert_matches::assert_matches;
use base64::{engine::Engine, prelude::BASE64_STANDARD};
- use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
+ use serde::{
+ de::Error as _, ser::Error as _, Deserialize, Deserializer, Serialize, Serializer,
+ };
pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
let base64_vec = v
@@ -63,21 +70,28 @@ mod base64 {
pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
<Vec<String>>::deserialize(d)?
.iter()
- .map(|s| BASE64_STANDARD.decode(s.as_bytes()).map_err(Error::custom))
+ .map(|s| {
+ BASE64_STANDARD
+ .decode(s.as_bytes())
+ .map_err(D::Error::custom)
+ })
.collect()
}
pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
- String::serialize(
- &BASE64_STANDARD.encode(FieldPrio2::slice_into_byte_vec(v)),
- s,
- )
+ let mut bytes = Vec::new();
+ encode_fieldvec(v, &mut bytes).map_err(S::Error::custom)?;
+ String::serialize(&BASE64_STANDARD.encode(&bytes), s)
}
pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
let bytes = BASE64_STANDARD
.decode(String::deserialize(d)?.as_bytes())
- .map_err(Error::custom)?;
- FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
+ .map_err(D::Error::custom)?;
+ let decoding_parameter =
+ ShareDecodingParameter::<32>::Leader(bytes.len() / FieldPrio2::ENCODED_SIZE);
+ let share = Share::<FieldPrio2, 32>::get_decoded_with_param(&decoding_parameter, &bytes)
+ .map_err(D::Error::custom)?;
+ assert_matches!(share, Share::Leader(vec) => Ok(vec))
}
}