From 8dd16259287f58f9273002717ec4d27e97127719 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 12 Jun 2024 07:43:14 +0200 Subject: Merging upstream version 127.0. Signed-off-by: Daniel Baumann --- third_party/rust/prio/src/vdaf/prio3.rs | 630 ++++++++++++++++++++------------ 1 file changed, 389 insertions(+), 241 deletions(-) (limited to 'third_party/rust/prio/src/vdaf/prio3.rs') diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs index 4a7cdefb84..084f87f411 100644 --- a/third_party/rust/prio/src/vdaf/prio3.rs +++ b/third_party/rust/prio/src/vdaf/prio3.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-07]]. +//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-08]]. //! //! **WARNING:** This code has not undergone significant security analysis. Use at your own risk. //! @@ -9,7 +9,7 @@ //! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication //! cost. The security of the construction was analyzed in [[DPRS23]]. //! -//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-07]] into +//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-08]] into //! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of //! which are instantiated here: //! @@ -20,14 +20,14 @@ //! //! Additional types can be constructed from [`Prio3`] as needed. //! -//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-07]]. +//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-08]]. //! //! [BBCG+19]: https://ia.cr/2019/188 //! [CGB17]: https://crypto.stanford.edu/prio/ //! [DPRS23]: https://ia.cr/2023/130 -//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/ -use super::xof::XofShake128; +use super::xof::XofTurboShake128; #[cfg(feature = "experimental")] use super::AggregatorWithNoise; use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; @@ -72,19 +72,19 @@ const DST_JOINT_RAND_SEED: u16 = 6; const DST_JOINT_RAND_PART: u16 = 7; /// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum. -pub type Prio3Count = Prio3, XofShake128, 16>; +pub type Prio3Count = Prio3, XofTurboShake128, 16>; impl Prio3Count { /// Construct an instance of Prio3Count with the given number of aggregators. pub fn new_count(num_aggregators: u8) -> Result { - Prio3::new(num_aggregators, Count::new()) + Prio3::new(num_aggregators, 1, 0x00000000, Count::new()) } } /// The count-vector type. Each measurement is a vector of integers in `[0,2^bits)` and the /// aggregate is the element-wise sum. pub type Prio3SumVec = - Prio3>>, XofShake128, 16>; + Prio3>>, XofTurboShake128, 16>; impl Prio3SumVec { /// Construct an instance of Prio3SumVec with the given number of aggregators. `bits` defines @@ -96,7 +96,12 @@ impl Prio3SumVec { len: usize, chunk_length: usize, ) -> Result { - Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?) + Prio3::new( + num_aggregators, + 1, + 0x00000002, + SumVec::new(bits, len, chunk_length)?, + ) } } @@ -104,8 +109,11 @@ impl Prio3SumVec { /// time. Note that the improvement is only noticeable for very large input lengths. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] -pub type Prio3SumVecMultithreaded = - Prio3>>, XofShake128, 16>; +pub type Prio3SumVecMultithreaded = Prio3< + SumVec>>, + XofTurboShake128, + 16, +>; #[cfg(feature = "multithreaded")] impl Prio3SumVecMultithreaded { @@ -118,13 +126,18 @@ impl Prio3SumVecMultithreaded { len: usize, chunk_length: usize, ) -> Result { - Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?) + Prio3::new( + num_aggregators, + 1, + 0x00000002, + SumVec::new(bits, len, chunk_length)?, + ) } } /// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the /// aggregate is the sum. -pub type Prio3Sum = Prio3, XofShake128, 16>; +pub type Prio3Sum = Prio3, XofTurboShake128, 16>; impl Prio3Sum { /// Construct an instance of Prio3Sum with the given number of aggregators and required bit @@ -136,7 +149,7 @@ impl Prio3Sum { ))); } - Prio3::new(num_aggregators, Sum::new(bits)?) + Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?) } } @@ -160,7 +173,7 @@ pub type Prio3FixedPointBoundedL2VecSum = Prio3< ParallelSum>, ParallelSum>, >, - XofShake128, + XofTurboShake128, 16, >; @@ -173,7 +186,12 @@ impl Prio3FixedPointBoundedL2VecSum { entries: usize, ) -> Result { check_num_aggregators(num_aggregators)?; - Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?) + Prio3::new( + num_aggregators, + 1, + 0xFFFF0000, + FixedPointBoundedL2VecSum::new(entries)?, + ) } } @@ -191,7 +209,7 @@ pub type Prio3FixedPointBoundedL2VecSumMultithreaded = Prio3< ParallelSumMultithreaded>, ParallelSumMultithreaded>, >, - XofShake128, + XofTurboShake128, 16, >; @@ -204,14 +222,19 @@ impl Prio3FixedPointBoundedL2VecSumMultithreaded Result { check_num_aggregators(num_aggregators)?; - Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?) + Prio3::new( + num_aggregators, + 1, + 0xFFFF0000, + FixedPointBoundedL2VecSum::new(entries)?, + ) } } /// The histogram type. Each measurement is an integer in `[0, length)` and the result is a /// histogram counting the number of occurrences of each measurement. pub type Prio3Histogram = - Prio3>>, XofShake128, 16>; + Prio3>>, XofTurboShake128, 16>; impl Prio3Histogram { /// Constructs an instance of Prio3Histogram with the given number of aggregators, @@ -221,7 +244,12 @@ impl Prio3Histogram { length: usize, chunk_length: usize, ) -> Result { - Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?) + Prio3::new( + num_aggregators, + 1, + 0x00000003, + Histogram::new(length, chunk_length)?, + ) } } @@ -229,8 +257,11 @@ impl Prio3Histogram { /// time. Note that this improvement is only noticeable for very large input lengths. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] -pub type Prio3HistogramMultithreaded = - Prio3>>, XofShake128, 16>; +pub type Prio3HistogramMultithreaded = Prio3< + Histogram>>, + XofTurboShake128, + 16, +>; #[cfg(feature = "multithreaded")] impl Prio3HistogramMultithreaded { @@ -241,13 +272,18 @@ impl Prio3HistogramMultithreaded { length: usize, chunk_length: usize, ) -> Result { - Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?) + Prio3::new( + num_aggregators, + 1, + 0x00000003, + Histogram::new(length, chunk_length)?, + ) } } /// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and /// the aggregate is the arithmetic average. -pub type Prio3Average = Prio3, XofShake128, 16>; +pub type Prio3Average = Prio3, XofTurboShake128, 16>; impl Prio3Average { /// Construct an instance of Prio3Average with the given number of aggregators and required bit @@ -263,6 +299,8 @@ impl Prio3Average { Ok(Prio3 { num_aggregators, + num_proofs: 1, + algorithm_id: 0xFFFF0000, typ: Average::new(bits)?, phantom: PhantomData, }) @@ -277,7 +315,7 @@ impl Prio3Average { /// - a [`Xof`] for deriving vectors of field elements from seeds. /// /// New instances can be defined by aliasing the base type. For example, [`Prio3Count`] is an alias -/// for `Prio3, XofShake128, 16>`. +/// for `Prio3, XofTurboShake128, 16>`. /// /// ``` /// use prio::vdaf::{ @@ -292,7 +330,7 @@ impl Prio3Average { /// let mut out_shares = vec![vec![]; num_shares.into()]; /// let mut rng = thread_rng(); /// let verify_key = rng.gen(); -/// let measurements = [0, 1, 1, 1, 0]; +/// let measurements = [false, true, true, true, false]; /// for measurement in measurements { /// // Shard /// let nonce = rng.gen::<[u8; 16]>(); @@ -316,7 +354,7 @@ impl Prio3Average { /// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap(); /// /// for (agg_id, state) in prep_states.into_iter().enumerate() { -/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { +/// let out_share = match vdaf.prepare_next(state, prep_msg.clone()).unwrap() { /// PrepareTransition::Finish(out_share) => out_share, /// _ => panic!("unexpected transition"), /// }; @@ -339,6 +377,8 @@ where P: Xof, { num_aggregators: u8, + num_proofs: u8, + algorithm_id: u32, typ: T, phantom: PhantomData

, } @@ -348,12 +388,25 @@ where T: Type, P: Xof, { - /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the - /// underlying type. - pub fn new(num_aggregators: u8, typ: T) -> Result { + /// Construct an instance of this Prio3 VDAF with the given number of aggregators, number of + /// proofs to generate and verify, the algorithm ID, and the underlying type. + pub fn new( + num_aggregators: u8, + num_proofs: u8, + algorithm_id: u32, + typ: T, + ) -> Result { check_num_aggregators(num_aggregators)?; + if num_proofs == 0 { + return Err(VdafError::Uncategorized( + "num_proofs must be at least 1".to_string(), + )); + } + Ok(Self { num_aggregators, + num_proofs, + algorithm_id, typ, phantom: PhantomData, }) @@ -369,19 +422,72 @@ where self.typ.verifier_len() } + #[inline] + fn num_proofs(&self) -> usize { + self.num_proofs.into() + } + + fn derive_prove_rands(&self, prove_rand_seed: &Seed) -> Vec { + P::seed_stream( + prove_rand_seed, + &self.domain_separation_tag(DST_PROVE_RANDOMNESS), + &[self.num_proofs], + ) + .into_field_vec(self.typ.prove_rand_len() * self.num_proofs()) + } + fn derive_joint_rand_seed<'a>( - parts: impl Iterator>, + &self, + joint_rand_parts: impl Iterator>, ) -> Seed { let mut xof = P::init( &[0; SEED_SIZE], - &Self::domain_separation_tag(DST_JOINT_RAND_SEED), + &self.domain_separation_tag(DST_JOINT_RAND_SEED), ); - for part in parts { + for part in joint_rand_parts { xof.update(part.as_ref()); } xof.into_seed() } + fn derive_joint_rands<'a>( + &self, + joint_rand_parts: impl Iterator>, + ) -> (Seed, Vec) { + let joint_rand_seed = self.derive_joint_rand_seed(joint_rand_parts); + let joint_rands = P::seed_stream( + &joint_rand_seed, + &self.domain_separation_tag(DST_JOINT_RANDOMNESS), + &[self.num_proofs], + ) + .into_field_vec(self.typ.joint_rand_len() * self.num_proofs()); + + (joint_rand_seed, joint_rands) + } + + fn derive_helper_proofs_share( + &self, + proofs_share_seed: &Seed, + agg_id: u8, + ) -> Prng { + Prng::from_seed_stream(P::seed_stream( + proofs_share_seed, + &self.domain_separation_tag(DST_PROOF_SHARE), + &[self.num_proofs, agg_id], + )) + } + + fn derive_query_rands(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { + let mut xof = P::init( + verify_key, + &self.domain_separation_tag(DST_QUERY_RANDOMNESS), + ); + xof.update(&[self.num_proofs]); + xof.update(nonce); + xof.into_seed_stream() + .into_field_vec(self.typ.query_rand_len() * self.num_proofs()) + } + fn random_size(&self) -> usize { if self.typ.joint_rand_len() == 0 { // Two seeds per helper for measurement and proof shares, plus one seed for proving @@ -438,42 +544,45 @@ where let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap(); let measurement_share_prng: Prng = Prng::from_seed_stream(P::seed_stream( &Seed(measurement_share_seed), - &Self::domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE), &[agg_id], )); - let joint_rand_blind = - if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() { - let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); - let mut joint_rand_part_xof = P::init( - &joint_rand_blind, - &Self::domain_separation_tag(DST_JOINT_RAND_PART), - ); - joint_rand_part_xof.update(&[agg_id]); // Aggregator ID - joint_rand_part_xof.update(nonce); - - let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); - for (x, y) in leader_measurement_share - .iter_mut() - .zip(measurement_share_prng) - { - *x -= y; - y.encode(&mut encoding_buffer); - joint_rand_part_xof.update(&encoding_buffer); - encoding_buffer.clear(); - } + let joint_rand_blind = if let Some(helper_joint_rand_parts) = + helper_joint_rand_parts.as_mut() + { + let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); + let mut joint_rand_part_xof = P::init( + &joint_rand_blind, + &self.domain_separation_tag(DST_JOINT_RAND_PART), + ); + joint_rand_part_xof.update(&[agg_id]); // Aggregator ID + joint_rand_part_xof.update(nonce); + + let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); + for (x, y) in leader_measurement_share + .iter_mut() + .zip(measurement_share_prng) + { + *x -= y; + y.encode(&mut encoding_buffer).map_err(|_| { + VdafError::Uncategorized("failed to encode measurement share".to_string()) + })?; + joint_rand_part_xof.update(&encoding_buffer); + encoding_buffer.clear(); + } - helper_joint_rand_parts.push(joint_rand_part_xof.into_seed()); + helper_joint_rand_parts.push(joint_rand_part_xof.into_seed()); - Some(joint_rand_blind) - } else { - for (x, y) in leader_measurement_share - .iter_mut() - .zip(measurement_share_prng) - { - *x -= y; - } - None - }; + Some(joint_rand_blind) + } else { + for (x, y) in leader_measurement_share + .iter_mut() + .zip(measurement_share_prng) + { + *x -= y; + } + None + }; let helper = HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind); helper_shares.push(helper); @@ -483,71 +592,75 @@ where let public_share = Prio3PublicShare { joint_rand_parts: helper_joint_rand_parts .as_ref() - .map(|helper_joint_rand_parts| { - let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap(); - let leader_blind = Seed::from_bytes(leader_blind_bytes); - - let mut joint_rand_part_xof = P::init( - leader_blind.as_ref(), - &Self::domain_separation_tag(DST_JOINT_RAND_PART), - ); - joint_rand_part_xof.update(&[0]); // Aggregator ID - joint_rand_part_xof.update(nonce); - let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); - for x in leader_measurement_share.iter() { - x.encode(&mut encoding_buffer); - joint_rand_part_xof.update(&encoding_buffer); - encoding_buffer.clear(); - } - leader_blind_opt = Some(leader_blind); - - let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed(); - - let mut vec = Vec::with_capacity(self.num_aggregators()); - vec.push(leader_joint_rand_seed_part); - vec.extend(helper_joint_rand_parts.iter().cloned()); - vec - }), + .map( + |helper_joint_rand_parts| -> Result>, VdafError> { + let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap(); + let leader_blind = Seed::from_bytes(leader_blind_bytes); + + let mut joint_rand_part_xof = P::init( + leader_blind.as_ref(), + &self.domain_separation_tag(DST_JOINT_RAND_PART), + ); + joint_rand_part_xof.update(&[0]); // Aggregator ID + joint_rand_part_xof.update(nonce); + let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); + for x in leader_measurement_share.iter() { + x.encode(&mut encoding_buffer).map_err(|_| { + VdafError::Uncategorized( + "failed to encode measurement share".to_string(), + ) + })?; + joint_rand_part_xof.update(&encoding_buffer); + encoding_buffer.clear(); + } + leader_blind_opt = Some(leader_blind); + + let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed(); + + let mut vec = Vec::with_capacity(self.num_aggregators()); + vec.push(leader_joint_rand_seed_part); + vec.extend(helper_joint_rand_parts.iter().cloned()); + Ok(vec) + }, + ) + .transpose()?, }; // Compute the joint randomness. - let joint_rand: Vec = public_share + let joint_rands = public_share .joint_rand_parts .as_ref() - .map(|joint_rand_parts| { - let joint_rand_seed = Self::derive_joint_rand_seed(joint_rand_parts.iter()); - P::seed_stream( - &joint_rand_seed, - &Self::domain_separation_tag(DST_JOINT_RANDOMNESS), - &[], - ) - .into_field_vec(self.typ.joint_rand_len()) - }) + .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1) .unwrap_or_default(); - // Run the proof-generation algorithm. - let prove_rand_seed = random_seeds.next().unwrap().try_into().unwrap(); - let prove_rand = P::seed_stream( - &Seed::from_bytes(prove_rand_seed), - &Self::domain_separation_tag(DST_PROVE_RANDOMNESS), - &[], - ) - .into_field_vec(self.typ.prove_rand_len()); - let mut leader_proof_share = - self.typ - .prove(&encoded_measurement, &prove_rand, &joint_rand)?; + // Generate the proofs. + let prove_rands = self.derive_prove_rands(&Seed::from_bytes( + random_seeds.next().unwrap().try_into().unwrap(), + )); + let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs()); + for p in 0..self.num_proofs() { + let prove_rand = + &prove_rands[p * self.typ.prove_rand_len()..(p + 1) * self.typ.prove_rand_len()]; + let joint_rand = + &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()]; + + leader_proofs_share.append(&mut self.typ.prove( + &encoded_measurement, + prove_rand, + joint_rand, + )?); + } // Generate the proof shares and distribute the joint randomness seed hints. for (j, helper) in helper_shares.iter_mut().enumerate() { - let proof_share_prng: Prng = Prng::from_seed_stream(P::seed_stream( - &helper.proof_share, - &Self::domain_separation_tag(DST_PROOF_SHARE), - &[j as u8 + 1], - )); - for (x, y) in leader_proof_share - .iter_mut() - .zip(proof_share_prng) - .take(self.typ.proof_len()) + for (x, y) in + leader_proofs_share + .iter_mut() + .zip(self.derive_helper_proofs_share( + &helper.proofs_share, + u8::try_from(j).unwrap() + 1, + )) + .take(self.typ.proof_len() * self.num_proofs()) { *x -= y; } @@ -557,14 +670,14 @@ where let mut out = Vec::with_capacity(num_aggregators as usize); out.push(Prio3InputShare { measurement_share: Share::Leader(leader_measurement_share), - proof_share: Share::Leader(leader_proof_share), + proofs_share: Share::Leader(leader_proofs_share), joint_rand_blind: leader_blind_opt, }); for helper in helper_shares.into_iter() { out.push(Prio3InputShare { measurement_share: Share::Helper(helper.measurement_share), - proof_share: Share::Helper(helper.proof_share), + proofs_share: Share::Helper(helper.proofs_share), joint_rand_blind: helper.joint_rand_blind, }); } @@ -585,7 +698,6 @@ where T: Type, P: Xof, { - const ID: u32 = T::ID; type Measurement = T::Measurement; type AggregateResult = T::AggregateResult; type AggregationParam = (); @@ -594,6 +706,10 @@ where type OutputShare = OutputShare; type AggregateShare = AggregateShare; + fn algorithm_id(&self) -> u32 { + self.algorithm_id + } + fn num_aggregators(&self) -> usize { self.num_aggregators as usize } @@ -607,12 +723,13 @@ pub struct Prio3PublicShare { } impl Encode for Prio3PublicShare { - fn encode(&self, bytes: &mut Vec) { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() { for part in joint_rand_parts.iter() { - part.encode(bytes); + part.encode(bytes)?; } } + Ok(()) } fn encoded_len(&self) -> Option { @@ -675,7 +792,7 @@ pub struct Prio3InputShare { measurement_share: Share, /// The proof share. - proof_share: Share, + proofs_share: Share, /// Blinding seed used by the Aggregator to compute the joint randomness. This field is optional /// because not every [`Type`] requires joint randomness. @@ -697,28 +814,29 @@ impl ConstantTimeEq for Prio3InputSha self.joint_rand_blind.as_ref(), other.joint_rand_blind.as_ref(), ) & self.measurement_share.ct_eq(&other.measurement_share) - & self.proof_share.ct_eq(&other.proof_share) + & self.proofs_share.ct_eq(&other.proofs_share) } } impl Encode for Prio3InputShare { - fn encode(&self, bytes: &mut Vec) { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { if matches!( - (&self.measurement_share, &self.proof_share), + (&self.measurement_share, &self.proofs_share), (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) ) { panic!("tried to encode input share with ambiguous encoding") } - self.measurement_share.encode(bytes); - self.proof_share.encode(bytes); + self.measurement_share.encode(bytes)?; + self.proofs_share.encode(bytes)?; if let Some(ref blind) = self.joint_rand_blind { - blind.encode(bytes); + blind.encode(bytes)?; } + Ok(()) } fn encoded_len(&self) -> Option { - let mut len = self.measurement_share.encoded_len()? + self.proof_share.encoded_len()?; + let mut len = self.measurement_share.encoded_len()? + self.proofs_share.encoded_len()?; if let Some(ref blind) = self.joint_rand_blind { len += blind.encoded_len()?; } @@ -742,7 +860,7 @@ where let (input_decoder, proof_decoder) = if agg_id == 0 { ( ShareDecodingParameter::Leader(prio3.typ.input_len()), - ShareDecodingParameter::Leader(prio3.typ.proof_len()), + ShareDecodingParameter::Leader(prio3.typ.proof_len() * prio3.num_proofs()), ) } else { ( @@ -752,7 +870,7 @@ where }; let measurement_share = Share::decode_with_param(&input_decoder, bytes)?; - let proof_share = Share::decode_with_param(&proof_decoder, bytes)?; + let proofs_share = Share::decode_with_param(&proof_decoder, bytes)?; let joint_rand_blind = if prio3.typ.joint_rand_len() > 0 { let blind = Seed::decode(bytes)?; Some(blind) @@ -762,7 +880,7 @@ where Ok(Prio3InputShare { measurement_share, - proof_share, + proofs_share, joint_rand_blind, }) } @@ -772,7 +890,7 @@ where /// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`].) - verifier: Vec, + verifiers: Vec, /// A part of the joint randomness seed. joint_rand_part: Option>, @@ -792,25 +910,26 @@ impl ConstantTimeEq for Prio3PrepareS option_ct_eq( self.joint_rand_part.as_ref(), other.joint_rand_part.as_ref(), - ) & self.verifier.ct_eq(&other.verifier) + ) & self.verifiers.ct_eq(&other.verifiers) } } impl Encode for Prio3PrepareShare { - fn encode(&self, bytes: &mut Vec) { - for x in &self.verifier { - x.encode(bytes); + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + for x in &self.verifiers { + x.encode(bytes)?; } if let Some(ref seed) = self.joint_rand_part { - seed.encode(bytes); + seed.encode(bytes)?; } + Ok(()) } fn encoded_len(&self) -> Option { // Each element of the verifier has the same size. - let mut len = F::ENCODED_SIZE * self.verifier.len(); + let mut len = F::ENCODED_SIZE * self.verifiers.len(); if let Some(ref seed) = self.joint_rand_part { len += seed.encoded_len()?; } @@ -825,9 +944,9 @@ impl decoding_parameter: &Prio3PrepareState, bytes: &mut Cursor<&[u8]>, ) -> Result { - let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len); - for _ in 0..decoding_parameter.verifier_len { - verifier.push(F::decode(bytes)?); + let mut verifiers = Vec::with_capacity(decoding_parameter.verifiers_len); + for _ in 0..decoding_parameter.verifiers_len { + verifiers.push(F::decode(bytes)?); } let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() { @@ -837,7 +956,7 @@ impl }; Ok(Prio3PrepareShare { - verifier, + verifiers, joint_rand_part, }) } @@ -869,10 +988,11 @@ impl ConstantTimeEq for Prio3PrepareMessage { } impl Encode for Prio3PrepareMessage { - fn encode(&self, bytes: &mut Vec) { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { if let Some(ref seed) = self.joint_rand_seed { - seed.encode(bytes); + seed.encode(bytes)?; } + Ok(()) } fn encoded_len(&self) -> Option { @@ -924,7 +1044,7 @@ pub struct Prio3PrepareState { measurement_share: Share, joint_rand_seed: Option>, agg_id: u8, - verifier_len: usize, + verifiers_len: usize, } impl PartialEq for Prio3PrepareState { @@ -939,7 +1059,7 @@ impl ConstantTimeEq for Prio3PrepareS fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as // the aggregator ID & verifier length parameters. - if self.agg_id != other.agg_id || self.verifier_len != other.verifier_len { + if self.agg_id != other.agg_id || self.verifiers_len != other.verifiers_len { return Choice::from(0); } @@ -962,7 +1082,7 @@ impl Debug for Prio3PrepareState { }, ) .field("agg_id", &self.agg_id) - .field("verifier_len", &self.verifier_len) + .field("verifiers_len", &self.verifiers_len) .finish() } } @@ -971,11 +1091,12 @@ impl Encode for Prio3PrepareState { /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. - fn encode(&self, bytes: &mut Vec) { - self.measurement_share.encode(bytes); + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + self.measurement_share.encode(bytes)?; if let Some(ref seed) = self.joint_rand_seed { - seed.encode(bytes); + seed.encode(bytes)?; } + Ok(()) } fn encoded_len(&self) -> Option { @@ -1018,7 +1139,7 @@ where measurement_share, joint_rand_seed, agg_id, - verifier_len: prio3.typ.verifier_len(), + verifiers_len: prio3.typ.verifier_len() * prio3.num_proofs(), }) } } @@ -1051,14 +1172,6 @@ where VdafError, > { let agg_id = self.role_try_from(agg_id)?; - let mut query_rand_xof = P::init( - verify_key, - &Self::domain_separation_tag(DST_QUERY_RANDOMNESS), - ); - query_rand_xof.update(nonce); - let query_rand = query_rand_xof - .into_seed_stream() - .into_field_vec(self.typ.query_rand_len()); // Create a reference to the (expanded) measurement share. let expanded_measurement_share: Option> = match msg.measurement_share { @@ -1066,7 +1179,7 @@ where Share::Helper(ref seed) => Some( P::seed_stream( seed, - &Self::domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE), &[agg_id], ) .into_field_vec(self.typ.input_len()), @@ -1078,33 +1191,32 @@ where }; // Create a reference to the (expanded) proof share. - let expanded_proof_share: Option> = match msg.proof_share { + let expanded_proofs_share: Option> = match msg.proofs_share { Share::Leader(_) => None, - Share::Helper(ref seed) => Some( - P::seed_stream( - seed, - &Self::domain_separation_tag(DST_PROOF_SHARE), - &[agg_id], - ) - .into_field_vec(self.typ.proof_len()), + Share::Helper(ref proof_shares_seed) => Some( + self.derive_helper_proofs_share(proof_shares_seed, agg_id) + .take(self.typ.proof_len() * self.num_proofs()) + .collect(), ), }; - let proof_share = match msg.proof_share { + let proofs_share = match msg.proofs_share { Share::Leader(ref data) => data, - Share::Helper(_) => expanded_proof_share.as_ref().unwrap(), + Share::Helper(_) => expanded_proofs_share.as_ref().unwrap(), }; // Compute the joint randomness. - let (joint_rand_seed, joint_rand_part, joint_rand) = if self.typ.joint_rand_len() > 0 { + let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 { let mut joint_rand_part_xof = P::init( msg.joint_rand_blind.as_ref().unwrap().as_ref(), - &Self::domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART), ); joint_rand_part_xof.update(&[agg_id]); joint_rand_part_xof.update(nonce); let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); for x in measurement_share { - x.encode(&mut encoding_buffer); + x.encode(&mut encoding_buffer).map_err(|_| { + VdafError::Uncategorized("failed to encode measurement share".to_string()) + })?; joint_rand_part_xof.update(&encoding_buffer); encoding_buffer.clear(); } @@ -1131,37 +1243,47 @@ where .skip(agg_id as usize + 1), ); - let joint_rand_seed = Self::derive_joint_rand_seed(corrected_joint_rand_parts); + let (joint_rand_seed, joint_rands) = + self.derive_joint_rands(corrected_joint_rand_parts); - let joint_rand = P::seed_stream( - &joint_rand_seed, - &Self::domain_separation_tag(DST_JOINT_RANDOMNESS), - &[], + ( + Some(joint_rand_seed), + Some(own_joint_rand_part), + joint_rands, ) - .into_field_vec(self.typ.joint_rand_len()); - (Some(joint_rand_seed), Some(own_joint_rand_part), joint_rand) } else { (None, None, Vec::new()) }; // Run the query-generation algorithm. - let verifier_share = self.typ.query( - measurement_share, - proof_share, - &query_rand, - &joint_rand, - self.num_aggregators as usize, - )?; + let query_rands = self.derive_query_rands(verify_key, nonce); + let mut verifiers_share = Vec::with_capacity(self.typ.verifier_len() * self.num_proofs()); + for p in 0..self.num_proofs() { + let query_rand = + &query_rands[p * self.typ.query_rand_len()..(p + 1) * self.typ.query_rand_len()]; + let joint_rand = + &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()]; + let proof_share = + &proofs_share[p * self.typ.proof_len()..(p + 1) * self.typ.proof_len()]; + + verifiers_share.append(&mut self.typ.query( + measurement_share, + proof_share, + query_rand, + joint_rand, + self.num_aggregators as usize, + )?); + } Ok(( Prio3PrepareState { measurement_share: msg.measurement_share.clone(), joint_rand_seed, agg_id, - verifier_len: verifier_share.len(), + verifiers_len: verifiers_share.len(), }, Prio3PrepareShare { - verifier: verifier_share, + verifiers: verifiers_share, joint_rand_part, }, )) @@ -1174,17 +1296,17 @@ where _: &Self::AggregationParam, inputs: M, ) -> Result, VdafError> { - let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; + let mut verifiers = vec![T::Field::zero(); self.typ.verifier_len() * self.num_proofs()]; let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators()); let mut count = 0; for share in inputs.into_iter() { count += 1; - if share.verifier.len() != verifier.len() { + if share.verifiers.len() != verifiers.len() { return Err(VdafError::Uncategorized(format!( "unexpected verifier share length: got {}; want {}", - share.verifier.len(), - verifier.len(), + share.verifiers.len(), + verifiers.len(), ))); } @@ -1193,7 +1315,7 @@ where joint_rand_parts.push(joint_rand_seed_part); } - for (x, y) in verifier.iter_mut().zip(share.verifier) { + for (x, y) in verifiers.iter_mut().zip(share.verifiers) { *x += y; } } @@ -1205,19 +1327,17 @@ where ))); } - // Check the proof verifier. - match self.typ.decide(&verifier) { - Ok(true) => (), - Ok(false) => { + // Check the proof verifiers. + for verifier in verifiers.chunks(self.typ.verifier_len()) { + if !self.typ.decide(verifier)? { return Err(VdafError::Uncategorized( "proof verifier check failed".into(), - )) + )); } - Err(err) => return Err(VdafError::from(err)), - }; + } let joint_rand_seed = if self.typ.joint_rand_len() > 0 { - Some(Self::derive_joint_rand_seed(joint_rand_parts.iter())) + Some(self.derive_joint_rand_seed(joint_rand_parts.iter())) } else { None }; @@ -1249,7 +1369,7 @@ where let measurement_share = match step.measurement_share { Share::Leader(data) => data, Share::Helper(seed) => { - let dst = Self::domain_separation_tag(DST_MEASUREMENT_SHARE); + let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE); P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len()) } }; @@ -1324,7 +1444,7 @@ where #[derive(Clone)] struct HelperShare { measurement_share: Seed, - proof_share: Seed, + proofs_share: Seed, joint_rand_blind: Option>, } @@ -1336,7 +1456,7 @@ impl HelperShare { ) -> Self { HelperShare { measurement_share: Seed::from_bytes(measurement_share), - proof_share: Seed::from_bytes(proof_share), + proofs_share: Seed::from_bytes(proof_share), joint_rand_blind: joint_rand_blind.map(Seed::from_bytes), } } @@ -1458,7 +1578,8 @@ mod tests { #[cfg(feature = "experimental")] use crate::flp::gadgets::ParallelSumGadget; use crate::vdaf::{ - equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare, + equality_comparison_test, fieldvec_roundtrip_test, + test_utils::{run_vdaf, run_vdaf_prepare}, }; use assert_matches::assert_matches; #[cfg(feature = "experimental")] @@ -1474,24 +1595,27 @@ mod tests { fn test_prio3_count() { let prio3 = Prio3::new_count(2).unwrap(); - assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3); + assert_eq!( + run_vdaf(&prio3, &(), [true, false, false, true, true]).unwrap(), + 3 + ); let mut nonce = [0; 16]; let mut verify_key = [0; 16]; thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); - let (public_share, input_shares) = prio3.shard(&0, &nonce).unwrap(); + let (public_share, input_shares) = prio3.shard(&false, &nonce).unwrap(); run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); - let (public_share, input_shares) = prio3.shard(&1, &nonce).unwrap(); + let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap(); run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); - test_serialization(&prio3, &1, &nonce).unwrap(); + test_serialization(&prio3, &true, &nonce).unwrap(); let prio3_extra_helper = Prio3::new_count(3).unwrap(); assert_eq!( - run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(), + run_vdaf(&prio3_extra_helper, &(), [true, false, false, true, true]).unwrap(), 3, ); } @@ -1522,7 +1646,7 @@ mod tests { assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); - assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); @@ -1549,6 +1673,30 @@ mod tests { ); } + #[test] + fn test_prio3_sum_vec_multiproof() { + let prio3 = Prio3::< + SumVec>>, + XofTurboShake128, + 16, + >::new(2, 2, 0xFFFF0000, SumVec::new(2, 20, 4).unwrap()) + .unwrap(); + + assert_eq!( + run_vdaf( + &prio3, + &(), + [ + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1], + vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0], + vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1], + ] + ) + .unwrap(), + vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2], + ); + } + #[test] #[cfg(feature = "multithreaded")] fn test_prio3_sum_vec_multithreaded() { @@ -1598,7 +1746,7 @@ mod tests { fn test_fixed_vec( fp_0: Fx, - prio3: Prio3, XofShake128, 16>, + prio3: Prio3, XofTurboShake128, 16>, ) where Fx: Fixed + CompatibleFloat + std::ops::Neg, PE: Eq + ParallelSumGadget> + Clone + 'static, @@ -1690,7 +1838,7 @@ mod tests { fp_4_inv: Fx, fp_8_inv: Fx, fp_16_inv: Fx, - prio3: Prio3, XofShake128, 16>, + prio3: Prio3, XofTurboShake128, 16>, ) where Fx: Fixed + CompatibleFloat + std::ops::Neg, PE: Eq + ParallelSumGadget> + Clone + 'static, @@ -1752,7 +1900,7 @@ mod tests { let (public_share, mut input_shares) = prio3 .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); - assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); let result = @@ -1823,7 +1971,7 @@ mod tests { } if let (Share::Helper(left), Share::Helper(right)) = - (&x.proof_share, &y.proof_share) + (&x.proofs_share, &y.proofs_share) { assert_ne!(left, right); } @@ -1847,7 +1995,7 @@ mod tests { thread_rng().fill(&mut verify_key[..]); let (public_share, input_shares) = prio3.shard(measurement, nonce)?; - let encoded_public_share = public_share.get_encoded(); + let encoded_public_share = public_share.get_encoded().unwrap(); let decoded_public_share = Prio3PublicShare::get_decoded_with_param(prio3, &encoded_public_share) .expect("failed to decode public share"); @@ -1858,7 +2006,7 @@ mod tests { ); for (agg_id, input_share) in input_shares.iter().enumerate() { - let encoded_input_share = input_share.get_encoded(); + let encoded_input_share = input_share.get_encoded().unwrap(); let decoded_input_share = Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), &encoded_input_share) .expect("failed to decode input share"); @@ -1875,7 +2023,7 @@ mod tests { let (prepare_state, prepare_share) = prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?; - let encoded_prepare_state = prepare_state.get_encoded(); + let encoded_prepare_state = prepare_state.get_encoded().unwrap(); let decoded_prepare_state = Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded_prepare_state) .expect("failed to decode prepare state"); @@ -1885,7 +2033,7 @@ mod tests { encoded_prepare_state.len() ); - let encoded_prepare_share = prepare_share.get_encoded(); + let encoded_prepare_share = prepare_share.get_encoded().unwrap(); let decoded_prepare_share = Prio3PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share) .expect("failed to decode prepare share"); @@ -1903,7 +2051,7 @@ mod tests { .prepare_shares_to_prepare_message(&(), prepare_shares) .unwrap(); - let encoded_prepare_message = prepare_message.get_encoded(); + let encoded_prepare_message = prepare_message.get_encoded().unwrap(); let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param( &last_prepare_state.unwrap(), &encoded_prepare_message, @@ -1967,31 +2115,31 @@ mod tests { // Default. Prio3InputShare { measurement_share: Share::Leader(Vec::from([0])), - proof_share: Share::Leader(Vec::from([1])), + proofs_share: Share::Leader(Vec::from([1])), joint_rand_blind: Some(Seed([2])), }, // Modified measurement share. Prio3InputShare { measurement_share: Share::Leader(Vec::from([100])), - proof_share: Share::Leader(Vec::from([1])), + proofs_share: Share::Leader(Vec::from([1])), joint_rand_blind: Some(Seed([2])), }, // Modified proof share. Prio3InputShare { measurement_share: Share::Leader(Vec::from([0])), - proof_share: Share::Leader(Vec::from([101])), + proofs_share: Share::Leader(Vec::from([101])), joint_rand_blind: Some(Seed([2])), }, // Modified joint_rand_blind. Prio3InputShare { measurement_share: Share::Leader(Vec::from([0])), - proof_share: Share::Leader(Vec::from([1])), + proofs_share: Share::Leader(Vec::from([1])), joint_rand_blind: Some(Seed([102])), }, // Missing joint_rand_blind. Prio3InputShare { measurement_share: Share::Leader(Vec::from([0])), - proof_share: Share::Leader(Vec::from([1])), + proofs_share: Share::Leader(Vec::from([1])), joint_rand_blind: None, }, ]) @@ -2002,22 +2150,22 @@ mod tests { equality_comparison_test(&[ // Default. Prio3PrepareShare { - verifier: Vec::from([0]), + verifiers: Vec::from([0]), joint_rand_part: Some(Seed([1])), }, // Modified verifier. Prio3PrepareShare { - verifier: Vec::from([100]), + verifiers: Vec::from([100]), joint_rand_part: Some(Seed([1])), }, // Modified joint_rand_part. Prio3PrepareShare { - verifier: Vec::from([0]), + verifiers: Vec::from([0]), joint_rand_part: Some(Seed([101])), }, // Missing joint_rand_part. Prio3PrepareShare { - verifier: Vec::from([0]), + verifiers: Vec::from([0]), joint_rand_part: None, }, ]) @@ -2049,42 +2197,42 @@ mod tests { measurement_share: Share::Leader(Vec::from([0])), joint_rand_seed: Some(Seed([1])), agg_id: 2, - verifier_len: 3, + verifiers_len: 3, }, // Modified measurement share. Prio3PrepareState { measurement_share: Share::Leader(Vec::from([100])), joint_rand_seed: Some(Seed([1])), agg_id: 2, - verifier_len: 3, + verifiers_len: 3, }, // Modified joint_rand_seed. Prio3PrepareState { measurement_share: Share::Leader(Vec::from([0])), joint_rand_seed: Some(Seed([101])), agg_id: 2, - verifier_len: 3, + verifiers_len: 3, }, // Missing joint_rand_seed. Prio3PrepareState { measurement_share: Share::Leader(Vec::from([0])), joint_rand_seed: None, agg_id: 2, - verifier_len: 3, + verifiers_len: 3, }, // Modified agg_id. Prio3PrepareState { measurement_share: Share::Leader(Vec::from([0])), joint_rand_seed: Some(Seed([1])), agg_id: 102, - verifier_len: 3, + verifiers_len: 3, }, // Modified verifier_len. Prio3PrepareState { measurement_share: Share::Leader(Vec::from([0])), joint_rand_seed: Some(Seed([1])), agg_id: 2, - verifier_len: 103, + verifiers_len: 103, }, ]) } -- cgit v1.2.3