// SPDX-License-Identifier: MPL-2.0 //! Backwards-compatible port of the ENPA Prio system to a VDAF. use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::{ decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldPrio2, }, prng::Prng, vdaf::{ prio2::{ client::{self as v2_client, proof_length}, server as v2_server, }, xof::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError, }, }; use hmac::{Hmac, Mac}; use rand_core::RngCore; use sha2::Sha256; use std::{convert::TryFrom, io::Cursor}; use subtle::{Choice, ConstantTimeEq}; mod client; mod server; #[cfg(test)] mod test_vector; /// The Prio2 VDAF. It supports the same measurement type as /// [`Prio3SumVec`](crate::vdaf::prio3::Prio3SumVec) with `bits == 1` but uses the proof system and /// finite field deployed in ENPA. #[derive(Clone, Debug)] pub struct Prio2 { input_len: usize, } impl Prio2 { /// Returns an instance of the VDAF for the given input length. pub fn new(input_len: usize) -> Result { let n = (input_len + 1).next_power_of_two(); if let Ok(size) = u32::try_from(2 * n) { if size > FieldPrio2::generator_order() { return Err(VdafError::Uncategorized( "input size exceeds field capacity".into(), )); } } else { return Err(VdafError::Uncategorized( "input size exceeds memory capacity".into(), )); } Ok(Prio2 { input_len }) } /// Prepare an input share for aggregation using the given field element `query_rand` to /// compute the verifier share. /// /// In the [`Aggregator`] trait implementation for [`Prio2`], the query randomness is computed /// jointly by the Aggregators. This method is designed to be used in applications, like ENPA, /// in which the query randomness is instead chosen by a third-party. pub fn prepare_init_with_query_rand( &self, query_rand: FieldPrio2, input_share: &Share, is_leader: bool, ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { let expanded_data: Option> = match input_share { Share::Leader(_) => None, Share::Helper(ref seed) => { let prng = Prng::from_prio2_seed(seed.as_ref()); Some(prng.take(proof_length(self.input_len)).collect()) } }; let data = match input_share { Share::Leader(ref data) => data, Share::Helper(_) => expanded_data.as_ref().unwrap(), }; let verifier_share = v2_server::generate_verification_message( self.input_len, query_rand, data, // Combined input and proof shares is_leader, ) .map_err(|e| VdafError::Uncategorized(e.to_string()))?; let truncated_share = match input_share { Share::Leader(data) => Share::Leader(data[..self.input_len].to_vec()), Share::Helper(seed) => Share::Helper(seed.clone()), }; Ok(( Prio2PrepareState(truncated_share), Prio2PrepareShare(verifier_share), )) } /// Choose a random point for polynomial evaluation. /// /// The point returned is not one of the roots used for polynomial interpolation. pub(crate) fn choose_eval_at(&self, prng: &mut Prng) -> FieldPrio2 where S: RngCore, { // Make sure the query randomness isn't a root of unity. Evaluating the proof at any of // these points would be a privacy violation, since these points were used by the prover to // construct the wire polynomials. let n = (self.input_len + 1).next_power_of_two(); let proof_length = 2 * n; loop { let eval_at: FieldPrio2 = prng.get(); // Unwrap safety: the constructor checks that this conversion succeeds. if eval_at.pow(u32::try_from(proof_length).unwrap()) != FieldPrio2::one() { return eval_at; } } } } impl Vdaf for Prio2 { type Measurement = Vec; type AggregateResult = Vec; type AggregationParam = (); type PublicShare = (); type InputShare = Share; type OutputShare = OutputShare; type AggregateShare = AggregateShare; fn algorithm_id(&self) -> u32 { 0xFFFF0000 } fn num_aggregators(&self) -> usize { // Prio2 can easily be extended to support more than two Aggregators. 2 } } impl Client<16> for Prio2 { fn shard( &self, measurement: &Vec, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { if measurement.len() != self.input_len { return Err(VdafError::Uncategorized("incorrect input length".into())); } let mut input: Vec = Vec::with_capacity(measurement.len()); for int in measurement { input.push((*int).into()); } let mut mem = v2_client::ClientMemory::new(self.input_len)?; let copy_data = |share_data: &mut [FieldPrio2]| { share_data[..].clone_from_slice(&input); }; let mut leader_data = mem.prove_with(self.input_len, copy_data); let helper_seed = Seed::generate()?; let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref()); for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) { *s1 -= d; } Ok(( (), vec![Share::Leader(leader_data), Share::Helper(helper_seed)], )) } } /// State of each [`Aggregator`] during the Preparation phase. #[derive(Clone, Debug)] pub struct Prio2PrepareState(Share); impl PartialEq for Prio2PrepareState { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() } } impl Eq for Prio2PrepareState {} impl ConstantTimeEq for Prio2PrepareState { fn ct_eq(&self, other: &Self) -> Choice { self.0.ct_eq(&other.0) } } impl Encode for Prio2PrepareState { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { self.0.encode(bytes) } fn encoded_len(&self) -> Option { self.0.encoded_len() } } impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState { fn decode_with_param( (prio2, agg_id): &(&'a Prio2, usize), bytes: &mut Cursor<&[u8]>, ) -> Result { let share_decoder = if *agg_id == 0 { ShareDecodingParameter::Leader(prio2.input_len) } else { ShareDecodingParameter::Helper }; let out_share = Share::decode_with_param(&share_decoder, bytes)?; Ok(Self(out_share)) } } /// Message emitted by each [`Aggregator`] during the Preparation phase. #[derive(Clone, Debug)] pub struct Prio2PrepareShare(v2_server::VerificationMessage); impl Encode for Prio2PrepareShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { self.0.f_r.encode(bytes)?; self.0.g_r.encode(bytes)?; self.0.h_r.encode(bytes) } fn encoded_len(&self) -> Option { Some(FieldPrio2::ENCODED_SIZE * 3) } } impl ParameterizedDecode for Prio2PrepareShare { fn decode_with_param( _state: &Prio2PrepareState, bytes: &mut Cursor<&[u8]>, ) -> Result { Ok(Self(v2_server::VerificationMessage { f_r: FieldPrio2::decode(bytes)?, g_r: FieldPrio2::decode(bytes)?, h_r: FieldPrio2::decode(bytes)?, })) } } impl Aggregator<32, 16> for Prio2 { type PrepareState = Prio2PrepareState; type PrepareShare = Prio2PrepareShare; type PrepareMessage = (); fn prepare_init( &self, agg_key: &[u8; 32], agg_id: usize, _agg_param: &Self::AggregationParam, nonce: &[u8; 16], _public_share: &Self::PublicShare, input_share: &Share, ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { let is_leader = role_try_from(agg_id)?; // In the ENPA Prio system, the query randomness is generated by a third party and // distributed to the Aggregators after they receive their input shares. In a VDAF, shared // randomness is derived from a nonce selected by the client. For Prio2 we compute the // query using HMAC-SHA256 evaluated over the nonce. // // Unwrap safety: new_from_slice() is infallible for Hmac. let mut mac = Hmac::::new_from_slice(agg_key).unwrap(); mac.update(nonce); let hmac_tag = mac.finalize(); let mut prng = Prng::from_prio2_seed(&hmac_tag.into_bytes().into()); let query_rand = self.choose_eval_at(&mut prng); self.prepare_init_with_query_rand(query_rand, input_share, is_leader) } fn prepare_shares_to_prepare_message>( &self, _: &Self::AggregationParam, inputs: M, ) -> Result<(), VdafError> { let verifier_shares: Vec> = inputs.into_iter().map(|msg| msg.0).collect(); if verifier_shares.len() != 2 { return Err(VdafError::Uncategorized( "wrong number of verifier shares".into(), )); } if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) { return Err(VdafError::Uncategorized( "proof verifier check failed".into(), )); } Ok(()) } fn prepare_next( &self, state: Prio2PrepareState, _input: (), ) -> Result, VdafError> { let data = match state.0 { Share::Leader(data) => data, Share::Helper(seed) => { let prng = Prng::from_prio2_seed(seed.as_ref()); prng.take(self.input_len).collect() } }; Ok(PrepareTransition::Finish(OutputShare::from(data))) } fn aggregate>>( &self, _agg_param: &Self::AggregationParam, out_shares: M, ) -> Result, VdafError> { let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); for out_share in out_shares.into_iter() { agg_share.accumulate(&out_share)?; } Ok(agg_share) } } impl Collector for Prio2 { fn unshard>>( &self, _agg_param: &Self::AggregationParam, agg_shares: M, _num_measurements: usize, ) -> Result, VdafError> { let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); for agg_share in agg_shares.into_iter() { agg.merge(&agg_share)?; } Ok(agg.0.into_iter().map(u32::from).collect()) } } impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share { fn decode_with_param( (prio2, agg_id): &(&'a Prio2, usize), bytes: &mut Cursor<&[u8]>, ) -> Result { let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; let decoder = if is_leader { ShareDecodingParameter::Leader(proof_length(prio2.input_len)) } else { ShareDecodingParameter::Helper }; Share::decode_with_param(&decoder, bytes) } } impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for OutputShare where F: FieldElement, { fn decode_with_param( (prio2, _): &(&'a Prio2, &'a ()), bytes: &mut Cursor<&[u8]>, ) -> Result { decode_fieldvec(prio2.input_len, bytes).map(Self) } } impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for AggregateShare where F: FieldElement, { fn decode_with_param( (prio2, _): &(&'a Prio2, &'a ()), bytes: &mut Cursor<&[u8]>, ) -> Result { decode_fieldvec(prio2.input_len, bytes).map(Self) } } fn role_try_from(agg_id: usize) -> Result { match agg_id { 0 => Ok(true), 1 => Ok(false), _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), } } #[cfg(test)] mod tests { use super::*; use crate::vdaf::{ equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, test_utils::run_vdaf, }; use assert_matches::assert_matches; use rand::prelude::*; #[test] fn run_prio2() { let prio2 = Prio2::new(6).unwrap(); assert_eq!( run_vdaf( &prio2, &(), [ vec![0, 0, 0, 0, 1, 0], vec![0, 1, 0, 0, 0, 0], vec![0, 1, 1, 0, 0, 0], vec![1, 1, 1, 0, 0, 0], vec![0, 0, 0, 0, 1, 1], ] ) .unwrap(), vec![1, 3, 2, 0, 2, 1], ); } #[test] fn prepare_state_serialization() { let mut rng = thread_rng(); let verify_key = rng.gen::<[u8; 32]>(); let nonce = rng.gen::<[u8; 16]>(); let data = vec![0, 0, 1, 1, 0]; let prio2 = Prio2::new(data.len()).unwrap(); let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap(); for (agg_id, input_share) in input_shares.iter().enumerate() { let (prepare_state, prepare_share) = prio2 .prepare_init( &verify_key, agg_id, &(), &[0; 16], &public_share, input_share, ) .unwrap(); let encoded_prepare_state = prepare_state.get_encoded().unwrap(); let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param( &(&prio2, agg_id), &encoded_prepare_state, ) .expect("failed to decode prepare state"); assert_eq!(decoded_prepare_state, prepare_state); assert_eq!( prepare_state.encoded_len().unwrap(), encoded_prepare_state.len() ); let encoded_prepare_share = prepare_share.get_encoded().unwrap(); let decoded_prepare_share = Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share) .expect("failed to decode prepare share"); assert_eq!(decoded_prepare_share.0.f_r, prepare_share.0.f_r); assert_eq!(decoded_prepare_share.0.g_r, prepare_share.0.g_r); assert_eq!(decoded_prepare_share.0.h_r, prepare_share.0.h_r); assert_eq!( prepare_share.encoded_len().unwrap(), encoded_prepare_share.len() ); } } #[test] fn roundtrip_output_share() { let vdaf = Prio2::new(31).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 31); } #[test] fn roundtrip_aggregate_share() { let vdaf = Prio2::new(31).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 31); } #[test] fn priov2_backward_compatibility() { let test_vector: Priov2TestVector = serde_json::from_str(include_str!("test_vec/prio2/fieldpriov2.json")).unwrap(); let vdaf = Prio2::new(test_vector.dimension).unwrap(); let mut leader_output_shares = Vec::new(); let mut helper_output_shares = Vec::new(); for (server_1_share, server_2_share) in test_vector .server_1_decrypted_shares .iter() .zip(&test_vector.server_2_decrypted_shares) { let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap(); let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap(); let (prepare_state_1, prepare_share_1) = vdaf .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1) .unwrap(); let (prepare_state_2, prepare_share_2) = vdaf .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2) .unwrap(); vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2]) .unwrap(); let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap(); let output_share_1 = assert_matches!(transition_1, PrepareTransition::Finish(out) => out); let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap(); let output_share_2 = assert_matches!(transition_2, PrepareTransition::Finish(out) => out); leader_output_shares.push(output_share_1); helper_output_shares.push(output_share_2); } let leader_aggregate_share = vdaf.aggregate(&(), leader_output_shares).unwrap(); let helper_aggregate_share = vdaf.aggregate(&(), helper_output_shares).unwrap(); let aggregate_result = vdaf .unshard( &(), [leader_aggregate_share, helper_aggregate_share], test_vector.server_1_decrypted_shares.len(), ) .unwrap(); let reconstructed = aggregate_result .into_iter() .map(FieldPrio2::from) .collect::>(); assert_eq!(reconstructed, test_vector.reference_sum); } #[test] fn prepare_state_equality_test() { equality_comparison_test(&[ Prio2PrepareState(Share::Leader(Vec::from([ FieldPrio2::from(0), FieldPrio2::from(1), ]))), Prio2PrepareState(Share::Leader(Vec::from([ FieldPrio2::from(1), FieldPrio2::from(0), ]))), Prio2PrepareState(Share::Helper(Seed( (0..32).collect::>().try_into().unwrap(), ))), Prio2PrepareState(Share::Helper(Seed( (1..33).collect::>().try_into().unwrap(), ))), ]) } }