summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio2
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio2')
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/client.rs16
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/server.rs50
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/test_vector.rs32
3 files changed, 65 insertions, 33 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio2/client.rs b/third_party/rust/prio/src/vdaf/prio2/client.rs
index dbce39ee3f..9515601d8a 100644
--- a/third_party/rust/prio/src/vdaf/prio2/client.rs
+++ b/third_party/rust/prio/src/vdaf/prio2/client.rs
@@ -4,10 +4,14 @@
//! Primitives for the Prio2 client.
use crate::{
- field::{FftFriendlyFieldElement, FieldError},
+ codec::CodecError,
+ field::FftFriendlyFieldElement,
polynomial::{poly_fft, PolyAuxMemory},
prng::{Prng, PrngError},
- vdaf::{xof::SeedStreamAes128, VdafError},
+ vdaf::{
+ xof::{Seed, SeedStreamAes128},
+ VdafError,
+ },
};
use std::convert::TryFrom;
@@ -32,9 +36,9 @@ pub enum SerializeError {
/// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length
#[error("serialized input has wrong length")]
UnpackInputSizeMismatch,
- /// Finite field operation error.
- #[error("finite field operation error")]
- Field(#[from] FieldError),
+ /// Codec error.
+ #[error(transparent)]
+ Codec(#[from] CodecError),
}
#[derive(Debug)]
@@ -63,7 +67,7 @@ impl<F: FftFriendlyFieldElement> ClientMemory<F> {
}
Ok(Self {
- prng: Prng::new()?,
+ prng: Prng::from_prio2_seed(Seed::<32>::generate()?.as_ref()),
points_f: vec![F::zero(); n],
points_g: vec![F::zero(); n],
evals_f: vec![F::zero(); 2 * n],
diff --git a/third_party/rust/prio/src/vdaf/prio2/server.rs b/third_party/rust/prio/src/vdaf/prio2/server.rs
index 11c161babf..9d2871c867 100644
--- a/third_party/rust/prio/src/vdaf/prio2/server.rs
+++ b/third_party/rust/prio/src/vdaf/prio2/server.rs
@@ -101,9 +101,13 @@ pub(crate) fn is_valid_share<F: FftFriendlyFieldElement>(
#[cfg(test)]
mod test_util {
use crate::{
+ codec::ParameterizedDecode,
field::{merge_vector, FftFriendlyFieldElement},
prng::Prng,
- vdaf::prio2::client::proof_length,
+ vdaf::{
+ prio2::client::{proof_length, SerializeError},
+ Share, ShareDecodingParameter,
+ },
};
use super::{generate_verification_message, is_valid_share, ServerError, VerificationMessage};
@@ -133,17 +137,17 @@ mod test_util {
/// Deserialize
fn deserialize_share(&self, share: &[u8]) -> Result<Vec<F>, ServerError> {
let len = proof_length(self.dimension);
- Ok(if self.is_first_server {
- F::byte_slice_into_vec(share)?
+ let decoding_parameter = if self.is_first_server {
+ ShareDecodingParameter::Leader(len)
} else {
- if share.len() != 32 {
- return Err(ServerError::ShareLength);
- }
-
- Prng::from_prio2_seed(&share.try_into().unwrap())
- .take(len)
- .collect()
- })
+ ShareDecodingParameter::Helper
+ };
+ let decoded_share = Share::get_decoded_with_param(&decoding_parameter, share)
+ .map_err(SerializeError::from)?;
+ match decoded_share {
+ Share::Leader(vec) => Ok(vec),
+ Share::Helper(seed) => Ok(Prng::from_prio2_seed(&seed.0).take(len).collect()),
+ }
}
/// Generate verification message from an encrypted share
@@ -194,14 +198,19 @@ mod test_util {
mod tests {
use super::*;
use crate::{
- codec::Encode,
+ codec::{Encode, ParameterizedDecode},
field::{FieldElement, FieldPrio2},
prng::Prng,
vdaf::{
- prio2::{client::unpack_proof_mut, server::test_util::Server, Prio2},
- Client,
+ prio2::{
+ client::{proof_length, unpack_proof_mut},
+ server::test_util::Server,
+ Prio2,
+ },
+ Client, Share, ShareDecodingParameter,
},
};
+ use assert_matches::assert_matches;
use rand::{random, Rng};
fn secret_share(share: &mut [FieldPrio2]) -> Vec<FieldPrio2> {
@@ -286,10 +295,13 @@ mod tests {
let vdaf = Prio2::new(dim).unwrap();
let (_, shares) = vdaf.shard(&data, &[0; 16]).unwrap();
- let share1_original = shares[0].get_encoded();
- let share2 = shares[1].get_encoded();
+ let share1_original = shares[0].get_encoded().unwrap();
+ let share2 = shares[1].get_encoded().unwrap();
- let mut share1_field = FieldPrio2::byte_slice_into_vec(&share1_original).unwrap();
+ let mut share1_field: Vec<FieldPrio2> = assert_matches!(
+ Share::get_decoded_with_param(&ShareDecodingParameter::<32>::Leader(proof_length(dim)), &share1_original),
+ Ok(Share::Leader(vec)) => vec
+ );
let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap();
let one = FieldPrio2::from(1);
@@ -304,7 +316,9 @@ mod tests {
};
// reserialize altered share1
- let share1_modified = FieldPrio2::slice_into_byte_vec(&share1_field);
+ let share1_modified = Share::<FieldPrio2, 32>::Leader(share1_field)
+ .get_encoded()
+ .unwrap();
let mut prng = Prng::from_prio2_seed(&random());
let eval_at = vdaf.choose_eval_at(&mut prng);
diff --git a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
index ae2b8b0f9d..114b437b55 100644
--- a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
+++ b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
@@ -48,9 +48,16 @@ mod base64 {
//! instead of an array of an array of integers when serializing to JSON.
//
// Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
- use crate::field::{FieldElement, FieldPrio2};
+ use crate::{
+ codec::ParameterizedDecode,
+ field::{encode_fieldvec, FieldElement, FieldPrio2},
+ vdaf::{Share, ShareDecodingParameter},
+ };
+ use assert_matches::assert_matches;
use base64::{engine::Engine, prelude::BASE64_STANDARD};
- use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
+ use serde::{
+ de::Error as _, ser::Error as _, Deserialize, Deserializer, Serialize, Serializer,
+ };
pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
let base64_vec = v
@@ -63,21 +70,28 @@ mod base64 {
pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
<Vec<String>>::deserialize(d)?
.iter()
- .map(|s| BASE64_STANDARD.decode(s.as_bytes()).map_err(Error::custom))
+ .map(|s| {
+ BASE64_STANDARD
+ .decode(s.as_bytes())
+ .map_err(D::Error::custom)
+ })
.collect()
}
pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
- String::serialize(
- &BASE64_STANDARD.encode(FieldPrio2::slice_into_byte_vec(v)),
- s,
- )
+ let mut bytes = Vec::new();
+ encode_fieldvec(v, &mut bytes).map_err(S::Error::custom)?;
+ String::serialize(&BASE64_STANDARD.encode(&bytes), s)
}
pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
let bytes = BASE64_STANDARD
.decode(String::deserialize(d)?.as_bytes())
- .map_err(Error::custom)?;
- FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
+ .map_err(D::Error::custom)?;
+ let decoding_parameter =
+ ShareDecodingParameter::<32>::Leader(bytes.len() / FieldPrio2::ENCODED_SIZE);
+ let share = Share::<FieldPrio2, 32>::get_decoded_with_param(&decoding_parameter, &bytes)
+ .map_err(D::Error::custom)?;
+ assert_matches!(share, Share::Leader(vec) => Ok(vec))
}
}