summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/dummy.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/dummy.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/dummy.rs147
1 files changed, 126 insertions, 21 deletions
diff --git a/third_party/rust/prio/src/vdaf/dummy.rs b/third_party/rust/prio/src/vdaf/dummy.rs
index 507e7916bb..2bb0f96b8a 100644
--- a/third_party/rust/prio/src/vdaf/dummy.rs
+++ b/third_party/rust/prio/src/vdaf/dummy.rs
@@ -12,6 +12,9 @@ use crate::{
use rand::random;
use std::{fmt::Debug, io::Cursor, sync::Arc};
+/// The Dummy VDAF does summation modulus 256 so we can predict aggregation results.
+const MODULUS: u64 = u8::MAX as u64 + 1;
+
type ArcPrepInitFn =
Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>;
type ArcPrepStepFn = Arc<
@@ -49,7 +52,9 @@ impl Vdaf {
move |state| -> Result<PrepareTransition<Self, 0, 16>, VdafError> {
let new_round = state.current_round + 1;
if new_round == rounds {
- Ok(PrepareTransition::Finish(OutputShare(state.input_share)))
+ Ok(PrepareTransition::Finish(OutputShare(u64::from(
+ state.input_share,
+ ))))
} else {
Ok(PrepareTransition::Continue(
PrepareState {
@@ -76,7 +81,7 @@ impl Vdaf {
self
}
- /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_step`].
+ /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_next`].
pub fn with_prep_step_fn<
F: Fn(&PrepareState) -> Result<PrepareTransition<Self, 0, 16>, VdafError>,
>(
@@ -98,16 +103,18 @@ impl Default for Vdaf {
}
impl vdaf::Vdaf for Vdaf {
- const ID: u32 = 0xFFFF0000;
-
type Measurement = u8;
- type AggregateResult = u8;
+ type AggregateResult = u64;
type AggregationParam = AggregationParam;
type PublicShare = ();
type InputShare = InputShare;
type OutputShare = OutputShare;
type AggregateShare = AggregateShare;
+ fn algorithm_id(&self) -> u32 {
+ 0xFFFF0000
+ }
+
fn num_aggregators(&self) -> usize {
2
}
@@ -155,7 +162,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf {
fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
&self,
- _: &Self::AggregationParam,
+ _aggregation_param: &Self::AggregationParam,
output_shares: M,
) -> Result<Self::AggregateShare, VdafError> {
let mut aggregate_share = AggregateShare(0);
@@ -184,12 +191,28 @@ impl vdaf::Client<16> for Vdaf {
}
}
+impl vdaf::Collector for Vdaf {
+ fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
+ &self,
+ aggregation_param: &Self::AggregationParam,
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<Self::AggregateResult, VdafError> {
+ Ok(agg_shares
+ .into_iter()
+ .fold(0, |acc, share| (acc + share.0) % MODULUS)
+ // Sum in the aggregation parameter so that collections over the same measurements with
+ // varying parameters will yield predictable but distinct results.
+ + u64::from(aggregation_param.0))
+ }
+}
+
/// A dummy input share.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct InputShare(pub u8);
impl Encode for InputShare {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
self.0.encode(bytes)
}
@@ -209,7 +232,7 @@ impl Decode for InputShare {
pub struct AggregationParam(pub u8);
impl Encode for AggregationParam {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
self.0.encode(bytes)
}
@@ -226,17 +249,17 @@ impl Decode for AggregationParam {
/// Dummy output share.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub struct OutputShare(pub u8);
+pub struct OutputShare(pub u64);
impl Decode for OutputShare {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
- Ok(Self(u8::decode(bytes)?))
+ Ok(Self(u64::decode(bytes)?))
}
}
impl Encode for OutputShare {
- fn encode(&self, bytes: &mut Vec<u8>) {
- self.0.encode(bytes);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ self.0.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -252,9 +275,9 @@ pub struct PrepareState {
}
impl Encode for PrepareState {
- fn encode(&self, bytes: &mut Vec<u8>) {
- self.input_share.encode(bytes);
- self.current_round.encode(bytes);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ self.input_share.encode(bytes)?;
+ self.current_round.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -282,31 +305,30 @@ impl Aggregatable for AggregateShare {
type OutputShare = OutputShare;
fn merge(&mut self, other: &Self) -> Result<(), VdafError> {
- self.0 += other.0;
+ self.0 = (self.0 + other.0) % MODULUS;
Ok(())
}
fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> {
- self.0 += u64::from(out_share.0);
+ self.0 = (self.0 + out_share.0) % MODULUS;
Ok(())
}
}
impl From<OutputShare> for AggregateShare {
fn from(out_share: OutputShare) -> Self {
- Self(u64::from(out_share.0))
+ Self(out_share.0)
}
}
impl Decode for AggregateShare {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
- let val = u64::decode(bytes)?;
- Ok(Self(val))
+ Ok(Self(u64::decode(bytes)?))
}
}
impl Encode for AggregateShare {
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
self.0.encode(bytes)
}
@@ -314,3 +336,86 @@ impl Encode for AggregateShare {
self.0.encoded_len()
}
}
+
+/// Returns the aggregate result that the dummy VDAF would compute over the provided measurements,
+/// for the provided aggregation parameter.
+pub fn expected_aggregate_result<M>(aggregation_parameter: u8, measurements: M) -> u64
+where
+ M: IntoIterator<Item = u8>,
+{
+ (measurements.into_iter().map(u64::from).sum::<u64>()) % MODULUS
+ + u64::from(aggregation_parameter)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::vdaf::{test_utils::run_vdaf_sharded, Client};
+ use rand::prelude::*;
+
+ fn run_test(rounds: u32, aggregation_parameter: u8) {
+ let vdaf = Vdaf::new(rounds);
+ let mut verify_key = [0; 0];
+ thread_rng().fill(&mut verify_key[..]);
+ let measurements = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100];
+
+ let mut sharded_measurements = Vec::new();
+ for measurement in measurements {
+ let nonce = thread_rng().gen();
+ let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap();
+
+ sharded_measurements.push((public_share, nonce, input_shares));
+ }
+
+ let result = run_vdaf_sharded(
+ &vdaf,
+ &AggregationParam(aggregation_parameter),
+ sharded_measurements.clone(),
+ )
+ .unwrap();
+ assert_eq!(
+ result,
+ expected_aggregate_result(aggregation_parameter, measurements)
+ );
+ }
+
+ #[test]
+ fn single_round_agg_param_10() {
+ run_test(1, 10)
+ }
+
+ #[test]
+ fn single_round_agg_param_20() {
+ run_test(1, 20)
+ }
+
+ #[test]
+ fn single_round_agg_param_32() {
+ run_test(1, 32)
+ }
+
+ #[test]
+ fn single_round_agg_param_u8_max() {
+ run_test(1, u8::MAX)
+ }
+
+ #[test]
+ fn two_round_agg_param_10() {
+ run_test(2, 10)
+ }
+
+ #[test]
+ fn two_round_agg_param_20() {
+ run_test(2, 20)
+ }
+
+ #[test]
+ fn two_round_agg_param_32() {
+ run_test(2, 32)
+ }
+
+ #[test]
+ fn two_round_agg_param_u8_max() {
+ run_test(2, u8::MAX)
+ }
+}