diff options
Diffstat (limited to 'third_party/rust/prio/src/vdaf/dummy.rs')
-rw-r--r-- | third_party/rust/prio/src/vdaf/dummy.rs | 147 |
1 files changed, 126 insertions, 21 deletions
diff --git a/third_party/rust/prio/src/vdaf/dummy.rs b/third_party/rust/prio/src/vdaf/dummy.rs index 507e7916bb..2bb0f96b8a 100644 --- a/third_party/rust/prio/src/vdaf/dummy.rs +++ b/third_party/rust/prio/src/vdaf/dummy.rs @@ -12,6 +12,9 @@ use crate::{ use rand::random; use std::{fmt::Debug, io::Cursor, sync::Arc}; +/// The Dummy VDAF does summation modulus 256 so we can predict aggregation results. +const MODULUS: u64 = u8::MAX as u64 + 1; + type ArcPrepInitFn = Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>; type ArcPrepStepFn = Arc< @@ -49,7 +52,9 @@ impl Vdaf { move |state| -> Result<PrepareTransition<Self, 0, 16>, VdafError> { let new_round = state.current_round + 1; if new_round == rounds { - Ok(PrepareTransition::Finish(OutputShare(state.input_share))) + Ok(PrepareTransition::Finish(OutputShare(u64::from( + state.input_share, + )))) } else { Ok(PrepareTransition::Continue( PrepareState { @@ -76,7 +81,7 @@ impl Vdaf { self } - /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_step`]. + /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_next`]. pub fn with_prep_step_fn< F: Fn(&PrepareState) -> Result<PrepareTransition<Self, 0, 16>, VdafError>, >( @@ -98,16 +103,18 @@ impl Default for Vdaf { } impl vdaf::Vdaf for Vdaf { - const ID: u32 = 0xFFFF0000; - type Measurement = u8; - type AggregateResult = u8; + type AggregateResult = u64; type AggregationParam = AggregationParam; type PublicShare = (); type InputShare = InputShare; type OutputShare = OutputShare; type AggregateShare = AggregateShare; + fn algorithm_id(&self) -> u32 { + 0xFFFF0000 + } + fn num_aggregators(&self) -> usize { 2 } @@ -155,7 +162,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( &self, - _: &Self::AggregationParam, + _aggregation_param: &Self::AggregationParam, output_shares: M, ) -> Result<Self::AggregateShare, VdafError> { let mut aggregate_share = AggregateShare(0); @@ -184,12 +191,28 @@ impl vdaf::Client<16> for Vdaf { } } +impl vdaf::Collector for Vdaf { + fn unshard<M: IntoIterator<Item = Self::AggregateShare>>( + &self, + aggregation_param: &Self::AggregationParam, + agg_shares: M, + _num_measurements: usize, + ) -> Result<Self::AggregateResult, VdafError> { + Ok(agg_shares + .into_iter() + .fold(0, |acc, share| (acc + share.0) % MODULUS) + // Sum in the aggregation parameter so that collections over the same measurements with + // varying parameters will yield predictable but distinct results. + + u64::from(aggregation_param.0)) + } +} + /// A dummy input share. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] pub struct InputShare(pub u8); impl Encode for InputShare { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { self.0.encode(bytes) } @@ -209,7 +232,7 @@ impl Decode for InputShare { pub struct AggregationParam(pub u8); impl Encode for AggregationParam { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { self.0.encode(bytes) } @@ -226,17 +249,17 @@ impl Decode for AggregationParam { /// Dummy output share. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct OutputShare(pub u8); +pub struct OutputShare(pub u64); impl Decode for OutputShare { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { - Ok(Self(u8::decode(bytes)?)) + Ok(Self(u64::decode(bytes)?)) } } impl Encode for OutputShare { - fn encode(&self, bytes: &mut Vec<u8>) { - self.0.encode(bytes); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + self.0.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -252,9 +275,9 @@ pub struct PrepareState { } impl Encode for PrepareState { - fn encode(&self, bytes: &mut Vec<u8>) { - self.input_share.encode(bytes); - self.current_round.encode(bytes); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + self.input_share.encode(bytes)?; + self.current_round.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -282,31 +305,30 @@ impl Aggregatable for AggregateShare { type OutputShare = OutputShare; fn merge(&mut self, other: &Self) -> Result<(), VdafError> { - self.0 += other.0; + self.0 = (self.0 + other.0) % MODULUS; Ok(()) } fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> { - self.0 += u64::from(out_share.0); + self.0 = (self.0 + out_share.0) % MODULUS; Ok(()) } } impl From<OutputShare> for AggregateShare { fn from(out_share: OutputShare) -> Self { - Self(u64::from(out_share.0)) + Self(out_share.0) } } impl Decode for AggregateShare { fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { - let val = u64::decode(bytes)?; - Ok(Self(val)) + Ok(Self(u64::decode(bytes)?)) } } impl Encode for AggregateShare { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { self.0.encode(bytes) } @@ -314,3 +336,86 @@ impl Encode for AggregateShare { self.0.encoded_len() } } + +/// Returns the aggregate result that the dummy VDAF would compute over the provided measurements, +/// for the provided aggregation parameter. +pub fn expected_aggregate_result<M>(aggregation_parameter: u8, measurements: M) -> u64 +where + M: IntoIterator<Item = u8>, +{ + (measurements.into_iter().map(u64::from).sum::<u64>()) % MODULUS + + u64::from(aggregation_parameter) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vdaf::{test_utils::run_vdaf_sharded, Client}; + use rand::prelude::*; + + fn run_test(rounds: u32, aggregation_parameter: u8) { + let vdaf = Vdaf::new(rounds); + let mut verify_key = [0; 0]; + thread_rng().fill(&mut verify_key[..]); + let measurements = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]; + + let mut sharded_measurements = Vec::new(); + for measurement in measurements { + let nonce = thread_rng().gen(); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + + sharded_measurements.push((public_share, nonce, input_shares)); + } + + let result = run_vdaf_sharded( + &vdaf, + &AggregationParam(aggregation_parameter), + sharded_measurements.clone(), + ) + .unwrap(); + assert_eq!( + result, + expected_aggregate_result(aggregation_parameter, measurements) + ); + } + + #[test] + fn single_round_agg_param_10() { + run_test(1, 10) + } + + #[test] + fn single_round_agg_param_20() { + run_test(1, 20) + } + + #[test] + fn single_round_agg_param_32() { + run_test(1, 32) + } + + #[test] + fn single_round_agg_param_u8_max() { + run_test(1, u8::MAX) + } + + #[test] + fn two_round_agg_param_10() { + run_test(2, 10) + } + + #[test] + fn two_round_agg_param_20() { + run_test(2, 20) + } + + #[test] + fn two_round_agg_param_32() { + run_test(2, 32) + } + + #[test] + fn two_round_agg_param_u8_max() { + run_test(2, u8::MAX) + } +} |