// SPDX-License-Identifier: MPL-2.0

//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]].
//!
//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone
//! significant security analysis. Use at your own risk.
//!
//! Prio3 is based on the Prio system designed by Dan Boneh and Henry Corrigan-Gibbs and presented
//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
//! cost.
//!
//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into
//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
//! which are instantiated here:
//!
//! - [`Prio3Aes128Count`] for aggregating a counter (*)
//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters
//! - [`Prio3Aes128Sum`] for computing the sum of integers (*)
//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*)
//!
//! Additional types can be constructed from [`Prio3`] as needed.
//!
//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]].
//!
//! [BBCG+19]: https://ia.cr/2019/188
//! [CGB17]: https://crypto.stanford.edu/prio/
//!
[draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ #[cfg(feature = "crypto-dependencies")] use super::prg::PrgAes128; use super::{DST_LEN, VERSION}; use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; use crate::field::FieldElement; #[cfg(feature = "crypto-dependencies")] use crate::field::{Field128, Field64}; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; #[cfg(feature = "crypto-dependencies")] use crate::flp::gadgets::{BlindPolyEval, ParallelSum}; #[cfg(feature = "crypto-dependencies")] use crate::flp::types::{Average, Count, CountVec, Histogram, Sum}; use crate::flp::Type; use crate::prng::Prng; use crate::vdaf::prg::{Prg, RandSource, Seed}; use crate::vdaf::{ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError, }; use std::convert::TryFrom; use std::fmt::Debug; use std::io::Cursor; use std::iter::IntoIterator; use std::marker::PhantomData; /// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum. #[cfg(feature = "crypto-dependencies")] pub type Prio3Aes128Count = Prio3, PrgAes128, 16>; #[cfg(feature = "crypto-dependencies")] impl Prio3Aes128Count { /// Construct an instance of Prio3Aes128Count with the given number of aggregators. pub fn new_aes128_count(num_aggregators: u8) -> Result { Prio3::new(num_aggregators, Count::new()) } } /// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is /// the element-wise sum. #[cfg(feature = "crypto-dependencies")] pub type Prio3Aes128CountVec = Prio3>>, PrgAes128, 16>; #[cfg(feature = "crypto-dependencies")] impl Prio3Aes128CountVec { /// Construct an instance of Prio3Aes1238CountVec with the given number of aggregators. `len` /// defines the length of each measurement. 
pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result { Prio3::new(num_aggregators, CountVec::new(len)) } } /// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and /// preparation time. Note that the improvement is only noticeable for very large input lengths, /// e.g., 201 and up. (Your system's mileage may vary.) #[cfg(feature = "multithreaded")] #[cfg(feature = "crypto-dependencies")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] pub type Prio3Aes128CountVecMultithreaded = Prio3< CountVec>>, PrgAes128, 16, >; #[cfg(feature = "multithreaded")] #[cfg(feature = "crypto-dependencies")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] impl Prio3Aes128CountVecMultithreaded { /// Construct an instance of Prio3Aes1238CountVecMultithreaded with the given number of /// aggregators. `len` defines the length of each measurement. pub fn new_aes128_count_vec_multithreaded( num_aggregators: u8, len: usize, ) -> Result { Prio3::new(num_aggregators, CountVec::new(len)) } } /// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the /// aggregate is the sum. #[cfg(feature = "crypto-dependencies")] pub type Prio3Aes128Sum = Prio3, PrgAes128, 16>; #[cfg(feature = "crypto-dependencies")] impl Prio3Aes128Sum { /// Construct an instance of Prio3Aes128Sum with the given number of aggregators and required /// bit length. The bit length must not exceed 64. pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result { if bits > 64 { return Err(VdafError::Uncategorized(format!( "bit length ({}) exceeds limit for aggregate type (64)", bits ))); } Prio3::new(num_aggregators, Sum::new(bits as usize)?) } } /// The histogram type. Each measurement is an unsigned integer and the result is a histogram /// representation of the distribution. The bucket boundaries are fixed in advance. 
#[cfg(feature = "crypto-dependencies")] pub type Prio3Aes128Histogram = Prio3, PrgAes128, 16>; #[cfg(feature = "crypto-dependencies")] impl Prio3Aes128Histogram { /// Constructs an instance of Prio3Aes128Histogram with the given number of aggregators and /// desired histogram bucket boundaries. pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result { let buckets = buckets.iter().map(|bucket| *bucket as u128).collect(); Prio3::new(num_aggregators, Histogram::new(buckets)?) } } /// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and /// the aggregate is the arithmetic average. #[cfg(feature = "crypto-dependencies")] pub type Prio3Aes128Average = Prio3, PrgAes128, 16>; #[cfg(feature = "crypto-dependencies")] impl Prio3Aes128Average { /// Construct an instance of Prio3Aes128Average with the given number of aggregators and /// required bit length. The bit length must not exceed 64. pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result { check_num_aggregators(num_aggregators)?; if bits > 64 { return Err(VdafError::Uncategorized(format!( "bit length ({}) exceeds limit for aggregate type (64)", bits ))); } Ok(Prio3 { num_aggregators, typ: Average::new(bits as usize)?, phantom: PhantomData, }) } } /// The base type for Prio3. /// /// An instance of Prio3 is determined by: /// /// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and /// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds. /// /// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an /// alias for `Prio3, PrgAes128, 16>`. 
/// /// ``` /// use prio::vdaf::{ /// Aggregator, Client, Collector, PrepareTransition, /// prio3::Prio3, /// }; /// use rand::prelude::*; /// /// let num_shares = 2; /// let vdaf = Prio3::new_aes128_count(num_shares).unwrap(); /// /// let mut out_shares = vec![vec![]; num_shares.into()]; /// let mut rng = thread_rng(); /// let verify_key = rng.gen(); /// let measurements = [0, 1, 1, 1, 0]; /// for measurement in measurements { /// // Shard /// let (public_share, input_shares) = vdaf.shard(&measurement).unwrap(); /// let mut nonce = [0; 16]; /// rng.fill(&mut nonce); /// /// // Prepare /// let mut prep_states = vec![]; /// let mut prep_shares = vec![]; /// for (agg_id, input_share) in input_shares.iter().enumerate() { /// let (state, share) = vdaf.prepare_init( /// &verify_key, /// agg_id, /// &(), /// &nonce, /// &public_share, /// input_share /// ).unwrap(); /// prep_states.push(state); /// prep_shares.push(share); /// } /// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); /// /// for (agg_id, state) in prep_states.into_iter().enumerate() { /// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { /// PrepareTransition::Finish(out_share) => out_share, /// _ => panic!("unexpected transition"), /// }; /// out_shares[agg_id].push(out_share); /// } /// } /// /// // Aggregate /// let agg_shares = out_shares.into_iter() /// .map(|o| vdaf.aggregate(&(), o).unwrap()); /// /// // Unshard /// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap(); /// assert_eq!(agg_res, 3); /// ``` /// /// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ #[derive(Clone, Debug)] pub struct Prio3 where T: Type, P: Prg, { num_aggregators: u8, typ: T, phantom: PhantomData

, } impl Prio3 where T: Type, P: Prg, { /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the /// underlying type. pub fn new(num_aggregators: u8, typ: T) -> Result { check_num_aggregators(num_aggregators)?; Ok(Self { num_aggregators, typ, phantom: PhantomData, }) } /// The output length of the underlying FLP. pub fn output_len(&self) -> usize { self.typ.output_len() } /// The verifier length of the underlying FLP. pub fn verifier_len(&self) -> usize { self.typ.verifier_len() } fn derive_joint_randomness<'a>(parts: impl Iterator>) -> Seed { let mut info = [0; VERSION.len() + 5]; info[..VERSION.len()].copy_from_slice(VERSION); info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes()); info[VERSION.len() + 4] = 255; let mut deriver = P::init(&[0; L]); deriver.update(&info); for part in parts { deriver.update(part.as_ref()); } deriver.into_seed() } fn shard_with_rand_source( &self, measurement: &T::Measurement, rand_source: RandSource, ) -> Result>, VdafError> { let mut info = [0; DST_LEN + 1]; info[..VERSION.len()].copy_from_slice(VERSION); info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); let num_aggregators = self.num_aggregators; let input = self.typ.encode_measurement(measurement)?; // Generate the input shares and compute the joint randomness. 
let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1); let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 { Some(Vec::with_capacity(num_aggregators as usize - 1)) } else { None }; let mut leader_input_share = input.clone(); for agg_id in 1..num_aggregators { let helper = HelperShare::from_rand_source(rand_source)?; let mut deriver = P::init(helper.joint_rand_param.blind.as_ref()); info[DST_LEN] = agg_id; deriver.update(&info); let prng: Prng = Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info)); for (x, y) in leader_input_share .iter_mut() .zip(prng) .take(self.typ.input_len()) { *x -= y; deriver.update(&y.into()); } if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() { helper_joint_rand_parts.push(deriver.into_seed()); } helper_shares.push(helper); } let leader_blind = Seed::from_rand_source(rand_source)?; info[DST_LEN] = 0; // ID of the leader let mut deriver = P::init(leader_blind.as_ref()); deriver.update(&info); for x in leader_input_share.iter() { deriver.update(&(*x).into()); } let leader_joint_rand_seed_part = deriver.into_seed(); // Compute the joint randomness seed. let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| { Self::derive_joint_randomness( std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()), ) }); // Run the proof-generation algorithm. 
let domain_separation_tag = &info[..DST_LEN]; let joint_rand: Vec = joint_rand_seed .map(|joint_rand_seed| { let prng: Prng = Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); prng.take(self.typ.joint_rand_len()).collect() }) .unwrap_or_default(); let prng: Prng = Prng::from_seed_stream(P::seed_stream( &Seed::from_rand_source(rand_source)?, domain_separation_tag, )); let prove_rand: Vec = prng.take(self.typ.prove_rand_len()).collect(); let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?; // Generate the proof shares and distribute the joint randomness seed hints. for (j, helper) in helper_shares.iter_mut().enumerate() { info[DST_LEN] = j as u8 + 1; let prng: Prng = Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info)); for (x, y) in leader_proof_share .iter_mut() .zip(prng) .take(self.typ.proof_len()) { *x -= y; } if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() { let mut hint = Vec::with_capacity(num_aggregators as usize - 1); hint.push(leader_joint_rand_seed_part.clone()); hint.extend(helper_joint_rand_parts[..j].iter().cloned()); hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned()); helper.joint_rand_param.seed_hint = hint; } } let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 { Some(JointRandParam { seed_hint: helper_joint_rand_parts.unwrap_or_default(), blind: leader_blind, }) } else { None }; // Prep the output messages. 
let mut out = Vec::with_capacity(num_aggregators as usize); out.push(Prio3InputShare { input_share: Share::Leader(leader_input_share), proof_share: Share::Leader(leader_proof_share), joint_rand_param: leader_joint_rand_param, }); for helper in helper_shares.into_iter() { let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 { Some(helper.joint_rand_param) } else { None }; out.push(Prio3InputShare { input_share: Share::Helper(helper.input_share), proof_share: Share::Helper(helper.proof_share), joint_rand_param: helper_joint_rand_param, }); } Ok(out) } /// Shard measurement with constant randomness of repeated bytes. /// This method is not secure. It is used for running test vectors for Prio3. #[cfg(feature = "test-util")] pub fn test_vec_shard( &self, measurement: &T::Measurement, ) -> Result>, VdafError> { self.shard_with_rand_source(measurement, |buf| { buf.fill(1); Ok(()) }) } fn role_try_from(&self, agg_id: usize) -> Result { if agg_id >= self.num_aggregators as usize { return Err(VdafError::Uncategorized("unexpected aggregator id".into())); } Ok(u8::try_from(agg_id).unwrap()) } } impl Vdaf for Prio3 where T: Type, P: Prg, { const ID: u32 = T::ID; type Measurement = T::Measurement; type AggregateResult = T::AggregateResult; type AggregationParam = (); type PublicShare = (); type InputShare = Prio3InputShare; type OutputShare = OutputShare; type AggregateShare = AggregateShare; fn num_aggregators(&self) -> usize { self.num_aggregators as usize } } /// Message sent by the [`Client`](crate::vdaf::Client) to each /// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Prio3InputShare { /// The input share. input_share: Share, /// The proof share. proof_share: Share, /// Parameters used by the Aggregator to compute the joint randomness. This field is optional /// because not every [`Type`](`crate::flp::Type`) requires joint randomness. 
joint_rand_param: Option>, } impl Encode for Prio3InputShare { fn encode(&self, bytes: &mut Vec) { if matches!( (&self.input_share, &self.proof_share), (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) ) { panic!("tried to encode input share with ambiguous encoding") } self.input_share.encode(bytes); self.proof_share.encode(bytes); if let Some(ref param) = self.joint_rand_param { param.blind.encode(bytes); for part in param.seed_hint.iter() { part.encode(bytes); } } } } impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3, usize)> for Prio3InputShare where T: Type, P: Prg, { fn decode_with_param( (prio3, agg_id): &(&'a Prio3, usize), bytes: &mut Cursor<&[u8]>, ) -> Result { let agg_id = prio3 .role_try_from(*agg_id) .map_err(|e| CodecError::Other(Box::new(e)))?; let (input_decoder, proof_decoder) = if agg_id == 0 { ( ShareDecodingParameter::Leader(prio3.typ.input_len()), ShareDecodingParameter::Leader(prio3.typ.proof_len()), ) } else { ( ShareDecodingParameter::Helper, ShareDecodingParameter::Helper, ) }; let input_share = Share::decode_with_param(&input_decoder, bytes)?; let proof_share = Share::decode_with_param(&proof_decoder, bytes)?; let joint_rand_param = if prio3.typ.joint_rand_len() > 0 { let num_aggregators = prio3.num_aggregators(); let blind = Seed::decode(bytes)?; let seed_hint = std::iter::repeat_with(|| Seed::decode(bytes)) .take(num_aggregators - 1) .collect::, _>>()?; Some(JointRandParam { blind, seed_hint }) } else { None }; Ok(Prio3InputShare { input_share, proof_share, joint_rand_param, }) } } #[derive(Clone, Debug, Eq, PartialEq)] /// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the /// Preparation phase. pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).) verifier: Vec, /// A part of the joint randomness seed. 
joint_rand_part: Option>, } impl Encode for Prio3PrepareShare { fn encode(&self, bytes: &mut Vec) { for x in &self.verifier { x.encode(bytes); } if let Some(ref seed) = self.joint_rand_part { seed.encode(bytes); } } } impl ParameterizedDecode> for Prio3PrepareShare { fn decode_with_param( decoding_parameter: &Prio3PrepareState, bytes: &mut Cursor<&[u8]>, ) -> Result { let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len); for _ in 0..decoding_parameter.verifier_len { verifier.push(F::decode(bytes)?); } let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() { Some(Seed::decode(bytes)?) } else { None }; Ok(Prio3PrepareShare { verifier, joint_rand_part, }) } } #[derive(Clone, Debug, Eq, PartialEq)] /// Result of combining a round of [`Prio3PrepareShare`] messages. pub struct Prio3PrepareMessage { /// The joint randomness seed computed by the Aggregators. joint_rand_seed: Option>, } impl Encode for Prio3PrepareMessage { fn encode(&self, bytes: &mut Vec) { if let Some(ref seed) = self.joint_rand_seed { seed.encode(bytes); } } } impl ParameterizedDecode> for Prio3PrepareMessage { fn decode_with_param( decoding_parameter: &Prio3PrepareState, bytes: &mut Cursor<&[u8]>, ) -> Result { let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() { Some(Seed::decode(bytes)?) } else { None }; Ok(Prio3PrepareMessage { joint_rand_seed }) } } impl Client for Prio3 where T: Type, P: Prg, { #[allow(clippy::type_complexity)] fn shard( &self, measurement: &T::Measurement, ) -> Result<((), Vec>), VdafError> { self.shard_with_rand_source(measurement, getrandom::getrandom) .map(|input_shares| ((), input_shares)) } } /// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. 
#[derive(Clone, Debug, Eq, PartialEq)] pub struct Prio3PrepareState { input_share: Share, joint_rand_seed: Option>, agg_id: u8, verifier_len: usize, } impl Encode for Prio3PrepareState { /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. fn encode(&self, bytes: &mut Vec) { self.input_share.encode(bytes); if let Some(ref seed) = self.joint_rand_seed { seed.encode(bytes); } } } impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3, usize)> for Prio3PrepareState where T: Type, P: Prg, { fn decode_with_param( (prio3, agg_id): &(&'a Prio3, usize), bytes: &mut Cursor<&[u8]>, ) -> Result { let agg_id = prio3 .role_try_from(*agg_id) .map_err(|e| CodecError::Other(Box::new(e)))?; let share_decoder = if agg_id == 0 { ShareDecodingParameter::Leader(prio3.typ.input_len()) } else { ShareDecodingParameter::Helper }; let input_share = Share::decode_with_param(&share_decoder, bytes)?; let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 { Some(Seed::decode(bytes)?) } else { None }; Ok(Self { input_share, joint_rand_seed, agg_id, verifier_len: prio3.typ.verifier_len(), }) } } impl Aggregator for Prio3 where T: Type, P: Prg, { type PrepareState = Prio3PrepareState; type PrepareShare = Prio3PrepareShare; type PrepareMessage = Prio3PrepareMessage; /// Begins the Prep process with the other aggregators. The result of this process is /// the aggregator's output share. 
#[allow(clippy::type_complexity)] fn prepare_init( &self, verify_key: &[u8; L], agg_id: usize, _agg_param: &(), nonce: &[u8], _public_share: &(), msg: &Prio3InputShare, ) -> Result< ( Prio3PrepareState, Prio3PrepareShare, ), VdafError, > { let agg_id = self.role_try_from(agg_id)?; let mut info = [0; DST_LEN + 1]; info[..VERSION.len()].copy_from_slice(VERSION); info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); info[DST_LEN] = agg_id; let domain_separation_tag = &info[..DST_LEN]; let mut deriver = P::init(verify_key); deriver.update(domain_separation_tag); deriver.update(&[255]); deriver.update(nonce); let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream()); // Create a reference to the (expanded) input share. let expanded_input_share: Option> = match msg.input_share { Share::Leader(_) => None, Share::Helper(ref seed) => { let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); Some(prng.take(self.typ.input_len()).collect()) } }; let input_share = match msg.input_share { Share::Leader(ref data) => data, Share::Helper(_) => expanded_input_share.as_ref().unwrap(), }; // Create a reference to the (expanded) proof share. let expanded_proof_share: Option> = match msg.proof_share { Share::Leader(_) => None, Share::Helper(ref seed) => { let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); Some(prng.take(self.typ.proof_len()).collect()) } }; let proof_share = match msg.proof_share { Share::Leader(ref data) => data, Share::Helper(_) => expanded_proof_share.as_ref().unwrap(), }; // Compute the joint randomness. 
let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 { let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref()); deriver.update(&info); for x in input_share { deriver.update(&(*x).into()); } let joint_rand_seed_part = deriver.into_seed(); let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint; let joint_rand_seed = Self::derive_joint_randomness( hints[..agg_id as usize] .iter() .chain(std::iter::once(&joint_rand_seed_part)) .chain(hints[agg_id as usize..].iter()), ); let prng: Prng = Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); ( Some(joint_rand_seed), Some(joint_rand_seed_part), prng.take(self.typ.joint_rand_len()).collect(), ) } else { (None, None, Vec::new()) }; // Compute the query randomness. let query_rand: Vec = query_rand_prng.take(self.typ.query_rand_len()).collect(); // Run the query-generation algorithm. let verifier_share = self.typ.query( input_share, proof_share, &query_rand, &joint_rand, self.num_aggregators as usize, )?; Ok(( Prio3PrepareState { input_share: msg.input_share.clone(), joint_rand_seed, agg_id, verifier_len: verifier_share.len(), }, Prio3PrepareShare { verifier: verifier_share, joint_rand_part: joint_rand_seed_part, }, )) } fn prepare_preprocess>>( &self, inputs: M, ) -> Result, VdafError> { let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators()); let mut count = 0; for share in inputs.into_iter() { count += 1; if share.verifier.len() != verifier.len() { return Err(VdafError::Uncategorized(format!( "unexpected verifier share length: got {}; want {}", share.verifier.len(), verifier.len(), ))); } if self.typ.joint_rand_len() > 0 { let joint_rand_seed_part = share.joint_rand_part.unwrap(); joint_rand_parts.push(joint_rand_seed_part); } for (x, y) in verifier.iter_mut().zip(share.verifier) { *x += y; } } if count != self.num_aggregators { return 
Err(VdafError::Uncategorized(format!( "unexpected message count: got {}; want {}", count, self.num_aggregators, ))); } // Check the proof verifier. match self.typ.decide(&verifier) { Ok(true) => (), Ok(false) => { return Err(VdafError::Uncategorized( "proof verifier check failed".into(), )) } Err(err) => return Err(VdafError::from(err)), }; let joint_rand_seed = if self.typ.joint_rand_len() > 0 { Some(Self::derive_joint_randomness(joint_rand_parts.iter())) } else { None }; Ok(Prio3PrepareMessage { joint_rand_seed }) } fn prepare_step( &self, step: Prio3PrepareState, msg: Prio3PrepareMessage, ) -> Result, VdafError> { if self.typ.joint_rand_len() > 0 { // Check that the joint randomness was correct. if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() { return Err(VdafError::Uncategorized( "joint randomness mismatch".to_string(), )); } } // Compute the output share. let input_share = match step.input_share { Share::Leader(data) => data, Share::Helper(seed) => { let mut info = [0; DST_LEN + 1]; info[..VERSION.len()].copy_from_slice(VERSION); info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); info[DST_LEN] = step.agg_id; let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info)); prng.take(self.typ.input_len()).collect() } }; let output_share = match self.typ.truncate(input_share) { Ok(data) => OutputShare(data), Err(err) => { return Err(VdafError::from(err)); } }; Ok(PrepareTransition::Finish(output_share)) } /// Aggregates a sequence of output shares into an aggregate share. fn aggregate>>( &self, _agg_param: &(), output_shares: It, ) -> Result, VdafError> { let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); for output_share in output_shares.into_iter() { agg_share.accumulate(&output_share)?; } Ok(agg_share) } } impl Collector for Prio3 where T: Type, P: Prg, { /// Combines aggregate shares into the aggregate result. 
fn unshard>>( &self, _agg_param: &(), agg_shares: It, num_measurements: usize, ) -> Result { let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); for agg_share in agg_shares.into_iter() { agg.merge(&agg_share)?; } Ok(self.typ.decode_result(&agg.0, num_measurements)?) } } #[derive(Clone, Debug, Eq, PartialEq)] struct JointRandParam { /// The joint randomness seed parts corresponding to the other Aggregators' shares. seed_hint: Vec>, /// The blinding factor, used to derive the aggregator's joint randomness seed part. blind: Seed, } #[derive(Clone)] struct HelperShare { input_share: Seed, proof_share: Seed, joint_rand_param: JointRandParam, } impl HelperShare { fn from_rand_source(rand_source: RandSource) -> Result { Ok(HelperShare { input_share: Seed::from_rand_source(rand_source)?, proof_share: Seed::from_rand_source(rand_source)?, joint_rand_param: JointRandParam { seed_hint: Vec::new(), blind: Seed::from_rand_source(rand_source)?, }, }) } } fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> { if num_aggregators == 0 { return Err(VdafError::Uncategorized(format!( "at least one aggregator is required; got {}", num_aggregators ))); } else if num_aggregators > 254 { return Err(VdafError::Uncategorized(format!( "number of aggregators must not exceed 254; got {}", num_aggregators ))); } Ok(()) } #[cfg(test)] mod tests { use super::*; use crate::vdaf::{run_vdaf, run_vdaf_prepare}; use assert_matches::assert_matches; use rand::prelude::*; #[test] fn test_prio3_count() { let prio3 = Prio3::new_aes128_count(2).unwrap(); assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3); let mut verify_key = [0; 16]; thread_rng().fill(&mut verify_key[..]); let nonce = b"This is a good nonce."; let (public_share, input_shares) = prio3.shard(&0).unwrap(); run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap(); let (public_share, input_shares) = prio3.shard(&1).unwrap(); run_vdaf_prepare(&prio3, 
&verify_key, &(), nonce, public_share, input_shares).unwrap(); test_prepare_state_serialization(&prio3, &1).unwrap(); let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap(); assert_eq!( run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(), 3, ); } #[test] fn test_prio3_sum() { let prio3 = Prio3::new_aes128_sum(3, 16).unwrap(); assert_eq!( run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), (1 << 16) + 1 ); let mut verify_key = [0; 16]; thread_rng().fill(&mut verify_key[..]); let nonce = b"This is a good nonce."; let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255; let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255; let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); assert_matches!(result, Err(VdafError::Uncategorized(_))); test_prepare_state_serialization(&prio3, &1).unwrap(); } #[test] fn test_prio3_countvec() { let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap(); assert_eq!( run_vdaf( &prio3, &(), [vec![ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 
0, 1, 0, 1, 1, 1, 0, 1, ]] ) .unwrap(), vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] ); } #[test] #[cfg(feature = "multithreaded")] fn test_prio3_countvec_multithreaded() { let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap(); assert_eq!( run_vdaf( &prio3, &(), [vec![ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, ]] ) .unwrap(), vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] ); } #[test] fn test_prio3_histogram() { let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap(); assert_eq!( run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(), vec![1, 1, 1, 1] ); assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]); assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]); assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]); assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]); assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]); test_prepare_state_serialization(&prio3, &23).unwrap(); } #[test] fn test_prio3_average() { let prio3 = Prio3::new_aes128_average(2, 64).unwrap(); assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); assert_eq!( run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), 207.5f64 ); } #[test] fn test_prio3_input_share() { let prio3 = Prio3::new_aes128_sum(5, 16).unwrap(); let (_public_share, input_shares) = prio3.shard(&1).unwrap(); // Check that seed shares are distinct. 
for (i, x) in input_shares.iter().enumerate() { for (j, y) in input_shares.iter().enumerate() { if i != j { if let (Share::Helper(left), Share::Helper(right)) = (&x.input_share, &y.input_share) { assert_ne!(left, right); } if let (Share::Helper(left), Share::Helper(right)) = (&x.proof_share, &y.proof_share) { assert_ne!(left, right); } assert_ne!(x.joint_rand_param, y.joint_rand_param); } } } } fn test_prepare_state_serialization( prio3: &Prio3, measurement: &T::Measurement, ) -> Result<(), VdafError> where T: Type, P: Prg, { let mut verify_key = [0; L]; thread_rng().fill(&mut verify_key[..]); let (public_share, input_shares) = prio3.shard(measurement)?; for (agg_id, input_share) in input_shares.iter().enumerate() { let (want, _msg) = prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?; let got = Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded()) .expect("failed to decode prepare step"); assert_eq!(got, want); } Ok(()) } }