diff options
Diffstat (limited to 'third_party/rust/prio/src/vdaf')
-rw-r--r-- | third_party/rust/prio/src/vdaf/poplar1.rs | 933 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/prg.rs | 239 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/prio2.rs | 425 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/prio3.rs | 1168 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/prio3_test.rs | 162 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json | 7 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json | 37 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json | 53 | ||||
-rw-r--r-- | third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json | 38 |
9 files changed, 3062 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/vdaf/poplar1.rs b/third_party/rust/prio/src/vdaf/poplar1.rs new file mode 100644 index 0000000000..f6ab110ebb --- /dev/null +++ b/third_party/rust/prio/src/vdaf/poplar1.rs @@ -0,0 +1,933 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! **(NOTE: This module is experimental. Applications should not use it yet.)** This module +//! partially implements the core component of the Poplar protocol [[BBCG+21]]. Named for the +//! Poplar1 section of [[draft-irtf-cfrg-vdaf-03]], the specification of this VDAF is under active +//! development. Thus this code should be regarded as experimental and not compliant with any +//! existing speciication. +//! +//! TODO Make the input shares stateful so that applications can efficiently evaluate the IDPF over +//! multiple rounds. Question: Will this require API changes to [`crate::vdaf::Vdaf`]? +//! +//! TODO Update trait [`Idpf`] so that the IDPF can have a different field type at the leaves than +//! at the inner nodes. +//! +//! TODO Implement the efficient IDPF of [[BBCG+21]]. [`ToyIdpf`] is not space efficient and is +//! merely intended as a proof-of-concept. +//! +//! [BBCG+21]: https://eprint.iacr.org/2021/017 +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use std::cmp::Ordering; +use std::collections::{BTreeMap, BTreeSet}; +use std::convert::{TryFrom, TryInto}; +use std::fmt::Debug; +use std::io::Cursor; +use std::iter::FromIterator; +use std::marker::PhantomData; + +use crate::codec::{ + decode_u16_items, decode_u24_items, encode_u16_items, encode_u24_items, CodecError, Decode, + Encode, ParameterizedDecode, +}; +use crate::field::{split_vector, FieldElement}; +use crate::fp::log2; +use crate::prng::Prng; +use crate::vdaf::prg::{Prg, Seed}; +use crate::vdaf::{ + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, + Share, ShareDecodingParameter, Vdaf, VdafError, +}; + +/// An input for an IDPF ([`Idpf`]). +/// +/// TODO Make this an associated type of `Idpf`. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct IdpfInput { + index: usize, + level: usize, +} + +impl IdpfInput { + /// Constructs an IDPF input using the first `level` bits of `data`. + pub fn new(data: &[u8], level: usize) -> Result<Self, VdafError> { + if level > data.len() << 3 { + return Err(VdafError::Uncategorized(format!( + "desired bit length ({} bits) exceeds data length ({} bytes)", + level, + data.len() + ))); + } + + let mut index = 0; + let mut i = 0; + for byte in data { + for j in 0..8 { + let bit = (byte >> j) & 1; + if i < level { + index |= (bit as usize) << i; + } + i += 1; + } + } + + Ok(Self { index, level }) + } + + /// Construct a new input that is a prefix of `self`. Bounds checking is performed by the + /// caller. + fn prefix(&self, level: usize) -> Self { + let index = self.index & ((1 << level) - 1); + Self { index, level } + } + + /// Return the position of `self` in the look-up table of `ToyIdpf`. + fn data_index(&self) -> usize { + self.index | (1 << self.level) + } +} + +impl Ord for IdpfInput { + fn cmp(&self, other: &Self) -> Ordering { + match self.level.cmp(&other.level) { + Ordering::Equal => self.index.cmp(&other.index), + ord => ord, + } + } +} + +impl PartialOrd for IdpfInput { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Encode for IdpfInput { + fn encode(&self, bytes: &mut Vec<u8>) { + (self.index as u64).encode(bytes); + (self.level as u64).encode(bytes); + } +} + +impl Decode for IdpfInput { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let index = u64::decode(bytes)? as usize; + let level = u64::decode(bytes)? as usize; + + Ok(Self { index, level }) + } +} + +/// An Incremental Distributed Point Function (IDPF), as defined by [[BBCG+21]]. +/// +/// [BBCG+21]: https://eprint.iacr.org/2021/017 +// +// NOTE(cjpatton) The real IDPF API probably needs to be stateful. +pub trait Idpf<const KEY_LEN: usize, const OUT_LEN: usize>: + Sized + Clone + Debug + Encode + Decode +{ + /// The finite field over which the IDPF is defined. + // + // NOTE(cjpatton) The IDPF of [BBCG+21] might use different fields for different levels of the + // prefix tree. + type Field: FieldElement; + + /// Generate and return a sequence of IDPF shares for `input`. Parameter `output` is an + /// iterator that is invoked to get the output value for each successive level of the prefix + /// tree. + fn gen<M: IntoIterator<Item = [Self::Field; OUT_LEN]>>( + input: &IdpfInput, + values: M, + ) -> Result<[Self; KEY_LEN], VdafError>; + + /// Evaluate an IDPF share on `prefix`. + fn eval(&self, prefix: &IdpfInput) -> Result<[Self::Field; OUT_LEN], VdafError>; +} + +/// A "toy" IDPF used for demonstration purposes. The space consumed by each share is `O(2^n)`, +/// where `n` is the length of the input. The size of each share is restricted to 1MB, so this IDPF +/// is only suitable for very short inputs. +// +// NOTE(cjpatton) It would be straight-forward to generalize this construction to any `KEY_LEN` and +// `OUT_LEN`. +#[derive(Debug, Clone)] +pub struct ToyIdpf<F> { + data0: Vec<F>, + data1: Vec<F>, + level: usize, +} + +impl<F: FieldElement> Idpf<2, 2> for ToyIdpf<F> { + type Field = F; + + fn gen<M: IntoIterator<Item = [Self::Field; 2]>>( + input: &IdpfInput, + values: M, + ) -> Result<[Self; 2], VdafError> { + const MAX_DATA_BYTES: usize = 1024 * 1024; // 1MB + + let max_input_len = + usize::try_from(log2((MAX_DATA_BYTES / F::ENCODED_SIZE) as u128)).unwrap(); + if input.level > max_input_len { + return Err(VdafError::Uncategorized(format!( + "input length ({}) exceeds maximum of ({})", + input.level, max_input_len + ))); + } + + let data_len = 1 << (input.level + 1); + let mut data0 = vec![F::zero(); data_len]; + let mut data1 = vec![F::zero(); data_len]; + let mut values = values.into_iter(); + for level in 0..input.level + 1 { + let value = values.next().unwrap(); + let index = input.prefix(level).data_index(); + data0[index] = value[0]; + data1[index] = value[1]; + } + + let mut data0 = split_vector(&data0, 2)?.into_iter(); + let mut data1 = split_vector(&data1, 2)?.into_iter(); + Ok([ + ToyIdpf { + data0: data0.next().unwrap(), + data1: data1.next().unwrap(), + level: input.level, + }, + ToyIdpf { + data0: data0.next().unwrap(), + data1: data1.next().unwrap(), + level: input.level, + }, + ]) + } + + fn eval(&self, prefix: &IdpfInput) -> Result<[F; 2], VdafError> { + if prefix.level > self.level { + return Err(VdafError::Uncategorized(format!( + "prefix length ({}) exceeds input length ({})", + prefix.level, self.level + ))); + } + + let index = prefix.data_index(); + Ok([self.data0[index], self.data1[index]]) + } +} + +impl<F: FieldElement> Encode for ToyIdpf<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + encode_u24_items(bytes, &(), &self.data0); + encode_u24_items(bytes, &(), &self.data1); + (self.level as u64).encode(bytes); + } +} + +impl<F: FieldElement> Decode for ToyIdpf<F> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let data0 = decode_u24_items(&(), bytes)?; + let data1 = decode_u24_items(&(), bytes)?; + let level = u64::decode(bytes)? as usize; + + Ok(Self { + data0, + data1, + level, + }) + } +} + +impl Encode for BTreeSet<IdpfInput> { + fn encode(&self, bytes: &mut Vec<u8>) { + // Encodes the aggregation parameter as a variable length vector of + // [`IdpfInput`], because the size of the aggregation parameter is not + // determined by the VDAF. + let items: Vec<IdpfInput> = self.iter().map(IdpfInput::clone).collect(); + encode_u24_items(bytes, &(), &items); + } +} + +impl Decode for BTreeSet<IdpfInput> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let inputs = decode_u24_items(&(), bytes)?; + Ok(Self::from_iter(inputs.into_iter())) + } +} + +/// An input share for the `poplar1` VDAF. +#[derive(Debug, Clone)] +pub struct Poplar1InputShare<I: Idpf<2, 2>, const L: usize> { + /// IDPF share of input + idpf: I, + + /// PRNG seed used to generate the aggregator's share of the randomness used in the first part + /// of the sketching protocol. + sketch_start_seed: Seed<L>, + + /// Aggregator's share of the randomness used in the second part of the sketching protocol. + sketch_next: Share<I::Field, L>, +} + +impl<I: Idpf<2, 2>, const L: usize> Encode for Poplar1InputShare<I, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + self.idpf.encode(bytes); + self.sketch_start_seed.encode(bytes); + self.sketch_next.encode(bytes); + } +} + +impl<'a, I, P, const L: usize> ParameterizedDecode<(&'a Poplar1<I, P, L>, usize)> + for Poplar1InputShare<I, L> +where + I: Idpf<2, 2>, +{ + fn decode_with_param( + (poplar1, agg_id): &(&'a Poplar1<I, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let idpf = I::decode(bytes)?; + let sketch_start_seed = Seed::decode(bytes)?; + let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; + + let share_decoding_parameter = if is_leader { + // The sketch is two field elements for every bit of input, plus two more, corresponding + // to construction of shares in `Poplar1::shard`. + ShareDecodingParameter::Leader((poplar1.input_length + 1) * 2) + } else { + ShareDecodingParameter::Helper + }; + + let sketch_next = + <Share<I::Field, L>>::decode_with_param(&share_decoding_parameter, bytes)?; + + Ok(Self { + idpf, + sketch_start_seed, + sketch_next, + }) + } +} + +/// The poplar1 VDAF. +#[derive(Debug)] +pub struct Poplar1<I, P, const L: usize> { + input_length: usize, + phantom: PhantomData<(I, P)>, +} + +impl<I, P, const L: usize> Poplar1<I, P, L> { + /// Create an instance of the poplar1 VDAF. The caller provides a cipher suite `suite` used for + /// deriving pseudorandom sequences of field elements, and a input length in bits, corresponding + /// to `BITS` as defined in the [VDAF specification][1]. + /// + /// [1]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + pub fn new(bits: usize) -> Self { + Self { + input_length: bits, + phantom: PhantomData, + } + } +} + +impl<I, P, const L: usize> Clone for Poplar1<I, P, L> { + fn clone(&self) -> Self { + Self::new(self.input_length) + } +} +impl<I, P, const L: usize> Vdaf for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + // TODO: This currently uses a codepoint reserved for testing purposes. Replace it with + // 0x00001000 once the implementation is updated to match draft-irtf-cfrg-vdaf-03. + const ID: u32 = 0xFFFF0000; + type Measurement = IdpfInput; + type AggregateResult = BTreeMap<IdpfInput, u64>; + type AggregationParam = BTreeSet<IdpfInput>; + type PublicShare = (); // TODO: Replace this when the IDPF from [BBCGGI21] is implemented. + type InputShare = Poplar1InputShare<I, L>; + type OutputShare = OutputShare<I::Field>; + type AggregateShare = AggregateShare<I::Field>; + + fn num_aggregators(&self) -> usize { + 2 + } +} + +impl<I, P, const L: usize> Client for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + #[allow(clippy::many_single_char_names)] + fn shard(&self, input: &IdpfInput) -> Result<((), Vec<Poplar1InputShare<I, L>>), VdafError> { + let idpf_values: Vec<[I::Field; 2]> = Prng::new()? + .take(input.level + 1) + .map(|k| [I::Field::one(), k]) + .collect(); + + // For each level of the prefix tree, generate correlated randomness that the aggregators use + // to validate the output. See [BBCG+21, Appendix C.4]. + let leader_sketch_start_seed = Seed::generate()?; + let helper_sketch_start_seed = Seed::generate()?; + let helper_sketch_next_seed = Seed::generate()?; + let mut leader_sketch_start_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&leader_sketch_start_seed, b"")); + let mut helper_sketch_start_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper_sketch_start_seed, b"")); + let mut helper_sketch_next_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper_sketch_next_seed, b"")); + let mut leader_sketch_next: Vec<I::Field> = Vec::with_capacity(2 * idpf_values.len()); + for value in idpf_values.iter() { + let k = value[1]; + + // [BBCG+21, Appendix C.4] + // + // $(a, b, c)$ + let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + + // $A = -2a + k$ + // $B = a^2 + b + -ak + c$ + let d = k - (a + a); + let e = (a * a) + b - (a * k) + c; + leader_sketch_next.push(d - helper_sketch_next_prng.get()); + leader_sketch_next.push(e - helper_sketch_next_prng.get()); + } + + // Generate IDPF shares of the data and authentication vectors. + let idpf_shares = I::gen(input, idpf_values)?; + + Ok(( + (), + vec![ + Poplar1InputShare { + idpf: idpf_shares[0].clone(), + sketch_start_seed: leader_sketch_start_seed, + sketch_next: Share::Leader(leader_sketch_next), + }, + Poplar1InputShare { + idpf: idpf_shares[1].clone(), + sketch_start_seed: helper_sketch_start_seed, + sketch_next: Share::Helper(helper_sketch_next_seed), + }, + ], + )) + } +} + +fn get_level(agg_param: &BTreeSet<IdpfInput>) -> Result<usize, VdafError> { + let mut level = None; + for prefix in agg_param { + if let Some(l) = level { + if prefix.level != l { + return Err(VdafError::Uncategorized( + "prefixes must all have the same length".to_string(), + )); + } + } else { + level = Some(prefix.level); + } + } + + match level { + Some(level) => Ok(level), + None => Err(VdafError::Uncategorized("prefix set is empty".to_string())), + } +} + +impl<I, P, const L: usize> Aggregator<L> for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + type PrepareState = Poplar1PrepareState<I::Field>; + type PrepareShare = Poplar1PrepareMessage<I::Field>; + type PrepareMessage = Poplar1PrepareMessage<I::Field>; + + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; L], + agg_id: usize, + agg_param: &BTreeSet<IdpfInput>, + nonce: &[u8], + _public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result< + ( + Poplar1PrepareState<I::Field>, + Poplar1PrepareMessage<I::Field>, + ), + VdafError, + > { + let level = get_level(agg_param)?; + let is_leader = role_try_from(agg_id)?; + + // Derive the verification randomness. + let mut p = P::init(verify_key); + p.update(nonce); + let mut verify_rand_prng: Prng<I::Field, _> = Prng::from_seed_stream(p.into_seed_stream()); + + // Evaluate the IDPF shares and compute the polynomial coefficients. + let mut z = [I::Field::zero(); 3]; + let mut output_share = Vec::with_capacity(agg_param.len()); + for prefix in agg_param.iter() { + let value = input_share.idpf.eval(prefix)?; + let (v, k) = (value[0], value[1]); + let r = verify_rand_prng.get(); + + // [BBCG+21, Appendix C.4] + // + // $(z_\sigma, z^*_\sigma, z^{**}_\sigma)$ + let tmp = r * v; + z[0] += tmp; + z[1] += r * tmp; + z[2] += r * k; + output_share.push(v); + } + + // [BBCG+21, Appendix C.4] + // + // Add blind shares $(a_\sigma b_\sigma, c_\sigma)$ + // + // NOTE(cjpatton) We can make this faster by a factor of 3 by using three seed shares instead + // of one. On the other hand, if the input shares are made stateful, then we could store + // the PRNG state theire and avoid fast-forwarding. + let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream( + &input_share.sketch_start_seed, + b"", + )) + .skip(3 * level); + z[0] += prng.next().unwrap(); + z[1] += prng.next().unwrap(); + z[2] += prng.next().unwrap(); + + let (d, e) = match &input_share.sketch_next { + Share::Leader(data) => (data[2 * level], data[2 * level + 1]), + Share::Helper(seed) => { + let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(seed, b"")) + .skip(2 * level); + (prng.next().unwrap(), prng.next().unwrap()) + } + }; + + let x = if is_leader { + I::Field::one() + } else { + I::Field::zero() + }; + + Ok(( + Poplar1PrepareState { + sketch: SketchState::RoundOne, + output_share: OutputShare(output_share), + d, + e, + x, + }, + Poplar1PrepareMessage(z.to_vec()), + )) + } + + fn prepare_preprocess<M: IntoIterator<Item = Poplar1PrepareMessage<I::Field>>>( + &self, + inputs: M, + ) -> Result<Poplar1PrepareMessage<I::Field>, VdafError> { + let mut output: Option<Vec<I::Field>> = None; + let mut count = 0; + for data_share in inputs.into_iter() { + count += 1; + if let Some(ref mut data) = output { + if data_share.0.len() != data.len() { + return Err(VdafError::Uncategorized(format!( + "unexpected message length: got {}; want {}", + data_share.0.len(), + data.len(), + ))); + } + + for (x, y) in data.iter_mut().zip(data_share.0.iter()) { + *x += *y; + } + } else { + output = Some(data_share.0); + } + } + + if count != 2 { + return Err(VdafError::Uncategorized(format!( + "unexpected message count: got {}; want 2", + count, + ))); + } + + Ok(Poplar1PrepareMessage(output.unwrap())) + } + + fn prepare_step( + &self, + mut state: Poplar1PrepareState<I::Field>, + msg: Poplar1PrepareMessage<I::Field>, + ) -> Result<PrepareTransition<Self, L>, VdafError> { + match &state.sketch { + SketchState::RoundOne => { + if msg.0.len() != 3 { + return Err(VdafError::Uncategorized(format!( + "unexpected message length ({:?}): got {}; want 3", + state.sketch, + msg.0.len(), + ))); + } + + // Compute polynomial coefficients. + let z: [I::Field; 3] = msg.0.try_into().unwrap(); + let y_share = + vec![(state.d * z[0]) + state.e + state.x * ((z[0] * z[0]) - z[1] - z[2])]; + + state.sketch = SketchState::RoundTwo; + Ok(PrepareTransition::Continue( + state, + Poplar1PrepareMessage(y_share), + )) + } + + SketchState::RoundTwo => { + if msg.0.len() != 1 { + return Err(VdafError::Uncategorized(format!( + "unexpected message length ({:?}): got {}; want 1", + state.sketch, + msg.0.len(), + ))); + } + + let y = msg.0[0]; + if y != I::Field::zero() { + return Err(VdafError::Uncategorized(format!( + "output is invalid: polynomial evaluated to {}; want {}", + y, + I::Field::zero(), + ))); + } + + Ok(PrepareTransition::Finish(state.output_share)) + } + } + } + + fn aggregate<M: IntoIterator<Item = OutputShare<I::Field>>>( + &self, + agg_param: &BTreeSet<IdpfInput>, + output_shares: M, + ) -> Result<AggregateShare<I::Field>, VdafError> { + let mut agg_share = AggregateShare(vec![I::Field::zero(); agg_param.len()]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + + Ok(agg_share) + } +} + +/// A prepare message sent exchanged between Poplar1 aggregators +#[derive(Clone, Debug)] +pub struct Poplar1PrepareMessage<F>(Vec<F>); + +impl<F> AsRef<[F]> for Poplar1PrepareMessage<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F: FieldElement> Encode for Poplar1PrepareMessage<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + // TODO: This is encoded as a variable length vector of F, but we may + // be able to make this a fixed-length vector for specific Poplar1 + // instantations + encode_u16_items(bytes, &(), &self.0); + } +} + +impl<F: FieldElement> ParameterizedDecode<Poplar1PrepareState<F>> for Poplar1PrepareMessage<F> { + fn decode_with_param( + _decoding_parameter: &Poplar1PrepareState<F>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + // TODO: This is decoded as a variable length vector of F, but we may be + // able to make this a fixed-length vector for specific Poplar1 + // instantiations. + let items = decode_u16_items(&(), bytes)?; + + Ok(Self(items)) + } +} + +/// The state of each Aggregator during the Prepare process. +#[derive(Clone, Debug)] +pub struct Poplar1PrepareState<F> { + /// State of the secure sketching protocol. + sketch: SketchState, + + /// The output share. + output_share: OutputShare<F>, + + /// Aggregator's share of $A = -2a + k$. + d: F, + + /// Aggregator's share of $B = a^2 + b -ak + c$. + e: F, + + /// Equal to 1 if this Aggregator is the "leader" and 0 otherwise. + x: F, +} + +#[derive(Clone, Debug)] +enum SketchState { + RoundOne, + RoundTwo, +} + +impl<I, P, const L: usize> Collector for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + fn unshard<M: IntoIterator<Item = AggregateShare<I::Field>>>( + &self, + agg_param: &BTreeSet<IdpfInput>, + agg_shares: M, + _num_measurements: usize, + ) -> Result<BTreeMap<IdpfInput, u64>, VdafError> { + let mut agg_data = AggregateShare(vec![I::Field::zero(); agg_param.len()]); + for agg_share in agg_shares.into_iter() { + agg_data.merge(&agg_share)?; + } + + let mut agg = BTreeMap::new(); + for (prefix, count) in agg_param.iter().zip(agg_data.as_ref()) { + let count = <I::Field as FieldElement>::Integer::from(*count); + let count: u64 = count + .try_into() + .map_err(|_| VdafError::Uncategorized("aggregate overflow".to_string()))?; + agg.insert(*prefix, count); + } + Ok(agg) + } +} + +fn role_try_from(agg_id: usize) -> Result<bool, VdafError> { + match agg_id { + 0 => Ok(true), + 1 => Ok(false), + _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::field::Field128; + use crate::vdaf::prg::PrgAes128; + use crate::vdaf::{run_vdaf, run_vdaf_prepare}; + use rand::prelude::*; + + #[test] + fn test_idpf() { + // IDPF input equality tests. + assert_eq!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hello", 40).unwrap() + ); + assert_eq!( + IdpfInput::new(b"hi", 9).unwrap(), + IdpfInput::new(b"ha", 9).unwrap(), + ); + assert_eq!( + IdpfInput::new(b"hello", 25).unwrap(), + IdpfInput::new(b"help", 25).unwrap() + ); + assert_ne!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hello", 39).unwrap() + ); + assert_ne!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hell-", 40).unwrap() + ); + + // IDPF uniqueness tests + let mut unique = BTreeSet::new(); + assert!(unique.insert(IdpfInput::new(b"hello", 40).unwrap())); + assert!(!unique.insert(IdpfInput::new(b"hello", 40).unwrap())); + assert!(unique.insert(IdpfInput::new(b"hello", 39).unwrap())); + assert!(unique.insert(IdpfInput::new(b"bye", 20).unwrap())); + + // Generate IDPF keys. + let input = IdpfInput::new(b"hi", 16).unwrap(); + let keys = ToyIdpf::<Field128>::gen( + &input, + std::iter::repeat([Field128::one(), Field128::one()]), + ) + .unwrap(); + + // Try evaluating the IDPF keys on all prefixes. + for prefix_len in 0..input.level + 1 { + let res = eval_idpf( + &keys, + &input.prefix(prefix_len), + &[Field128::one(), Field128::one()], + ); + assert!(res.is_ok(), "prefix_len={} error: {:?}", prefix_len, res); + } + + // Try evaluating the IDPF keys on incorrect prefixes. + eval_idpf( + &keys, + &IdpfInput::new(&[2], 2).unwrap(), + &[Field128::zero(), Field128::zero()], + ) + .unwrap(); + + eval_idpf( + &keys, + &IdpfInput::new(&[23, 1], 12).unwrap(), + &[Field128::zero(), Field128::zero()], + ) + .unwrap(); + } + + fn eval_idpf<I, const KEY_LEN: usize, const OUT_LEN: usize>( + keys: &[I; KEY_LEN], + input: &IdpfInput, + expected_output: &[I::Field; OUT_LEN], + ) -> Result<(), VdafError> + where + I: Idpf<KEY_LEN, OUT_LEN>, + { + let mut output = [I::Field::zero(); OUT_LEN]; + for key in keys { + let output_share = key.eval(input)?; + for (x, y) in output.iter_mut().zip(output_share) { + *x += y; + } + } + + if expected_output != &output { + return Err(VdafError::Uncategorized(format!( + "eval_idpf(): unexpected output: got {:?}; want {:?}", + output, expected_output + ))); + } + + Ok(()) + } + + #[test] + fn test_poplar1() { + const INPUT_LEN: usize = 8; + + let vdaf: Poplar1<ToyIdpf<Field128>, PrgAes128, 16> = Poplar1::new(INPUT_LEN); + assert_eq!(vdaf.num_aggregators(), 2); + + // Run the VDAF input-distribution algorithm. + let input = vec![IdpfInput::new(&[0b0110_1000], INPUT_LEN).unwrap()]; + + let mut agg_param = BTreeSet::new(); + agg_param.insert(input[0]); + check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]); + + // Try evaluating the VDAF on each prefix of the input. + for prefix_len in 0..input[0].level + 1 { + let mut agg_param = BTreeSet::new(); + agg_param.insert(input[0].prefix(prefix_len)); + check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]); + } + + // Try various prefixes. + let prefix_len = 4; + let mut agg_param = BTreeSet::new(); + // At length 4, the next two prefixes are equal. Neither one matches the input. + agg_param.insert(IdpfInput::new(&[0b0000_0000], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0001_0000], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_0001], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap()); + // At length 4, the next two prefixes are equal. Both match the input. + agg_param.insert(IdpfInput::new(&[0b0111_1101], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap()); + let aggregate = run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(); + assert_eq!(aggregate.len(), agg_param.len()); + check_btree( + &aggregate, + // We put six prefixes in the aggregation parameter, but the vector we get back is only + // 4 elements because at the given prefix length, some of the prefixes are equal. + &[0, 0, 0, 1], + ); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"this is a nonce"; + + // Try evaluating the VDAF with an invalid aggregation parameter. (It's an error to have a + // mixture of prefix lengths.) + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 6).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1000], 7).unwrap()); + let (public_share, input_shares) = vdaf.shard(&input[0]).unwrap(); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + + // Try evaluating the VDAF with malformed inputs. + // + // This IDPF key pair evaluates to 1 everywhere, which is illegal. + let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap(); + for (i, x) in input_shares[0].idpf.data0.iter_mut().enumerate() { + if i != input[0].index { + *x += Field128::one(); + } + } + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap()); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + + // This IDPF key pair has a garbled authentication vector. + let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap(); + for x in input_shares[0].idpf.data1.iter_mut() { + *x = Field128::zero(); + } + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap()); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + } + + fn check_btree(btree: &BTreeMap<IdpfInput, u64>, counts: &[u64]) { + for (got, want) in btree.values().zip(counts.iter()) { + assert_eq!(got, want, "got {:?} want {:?}", btree.values(), counts); + } + } +} diff --git a/third_party/rust/prio/src/vdaf/prg.rs b/third_party/rust/prio/src/vdaf/prg.rs new file mode 100644 index 0000000000..a5930f1283 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prg.rs @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementations of PRGs specified in [[draft-irtf-cfrg-vdaf-03]]. +//! +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use crate::vdaf::{CodecError, Decode, Encode}; +#[cfg(feature = "crypto-dependencies")] +use aes::{ + cipher::{KeyIvInit, StreamCipher}, + Aes128, +}; +#[cfg(feature = "crypto-dependencies")] +use cmac::{Cmac, Mac}; +#[cfg(feature = "crypto-dependencies")] +use ctr::Ctr64BE; +#[cfg(feature = "crypto-dependencies")] +use std::fmt::Formatter; +use std::{ + fmt::Debug, + io::{Cursor, Read}, +}; + +/// Function pointer to fill a buffer with random bytes. Under normal operation, +/// `getrandom::getrandom()` will be used, but other implementations can be used to control +/// randomness when generating or verifying test vectors. +pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>; + +/// Input of [`Prg`]. +#[derive(Clone, Debug, Eq)] +pub struct Seed<const L: usize>(pub(crate) [u8; L]); + +impl<const L: usize> Seed<L> { + /// Generate a uniform random seed. + pub fn generate() -> Result<Self, getrandom::Error> { + Self::from_rand_source(getrandom::getrandom) + } + + pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> { + let mut seed = [0; L]; + rand_source(&mut seed)?; + Ok(Self(seed)) + } +} + +impl<const L: usize> AsRef<[u8; L]> for Seed<L> { + fn as_ref(&self) -> &[u8; L] { + &self.0 + } +} + +impl<const L: usize> PartialEq for Seed<L> { + fn eq(&self, other: &Self) -> bool { + // Do constant-time compare. + let mut r = 0; + for (x, y) in self.0[..].iter().zip(&other.0[..]) { + r |= x ^ y; + } + r == 0 + } +} + +impl<const L: usize> Encode for Seed<L> { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&self.0[..]); + } +} + +impl<const L: usize> Decode for Seed<L> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut seed = [0; L]; + bytes.read_exact(&mut seed)?; + Ok(Seed(seed)) + } +} + +/// A stream of pseudorandom bytes derived from a seed. +pub trait SeedStream { + /// Fill `buf` with the next `buf.len()` bytes of output. + fn fill(&mut self, buf: &mut [u8]); +} + +/// A pseudorandom generator (PRG) with the interface specified in [[draft-irtf-cfrg-vdaf-03]]. +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +pub trait Prg<const L: usize>: Clone + Debug { + /// The type of stream produced by this PRG. + type SeedStream: SeedStream; + + /// Construct an instance of [`Prg`] with the given seed. + fn init(seed_bytes: &[u8; L]) -> Self; + + /// Update the PRG state by passing in the next fragment of the info string. The final info + /// string is assembled from the concatenation of sequence of fragments passed to this method. + fn update(&mut self, data: &[u8]); + + /// Finalize the PRG state, producing a seed stream. + fn into_seed_stream(self) -> Self::SeedStream; + + /// Finalize the PRG state, producing a seed. + fn into_seed(self) -> Seed<L> { + let mut new_seed = [0; L]; + let mut seed_stream = self.into_seed_stream(); + seed_stream.fill(&mut new_seed); + Seed(new_seed) + } + + /// Construct a seed stream from the given seed and info string. + fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream { + let mut prg = Self::init(seed.as_ref()); + prg.update(info); + prg.into_seed_stream() + } +} + +/// The PRG based on AES128 as specified in [[draft-irtf-cfrg-vdaf-03]]. +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +#[derive(Clone, Debug)] +#[cfg(feature = "crypto-dependencies")] +pub struct PrgAes128(Cmac<Aes128>); + +#[cfg(feature = "crypto-dependencies")] +impl Prg<16> for PrgAes128 { + type SeedStream = SeedStreamAes128; + + fn init(seed_bytes: &[u8; 16]) -> Self { + Self(Cmac::new_from_slice(seed_bytes).unwrap()) + } + + fn update(&mut self, data: &[u8]) { + self.0.update(data); + } + + fn into_seed_stream(self) -> SeedStreamAes128 { + let key = self.0.finalize().into_bytes(); + SeedStreamAes128::new(&key, &[0; 16]) + } +} + +/// The key stream produced by AES128 in CTR-mode. +#[cfg(feature = "crypto-dependencies")] +pub struct SeedStreamAes128(Ctr64BE<Aes128>); + +#[cfg(feature = "crypto-dependencies")] +impl SeedStreamAes128 { + pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self { + SeedStreamAes128(Ctr64BE::<Aes128>::new(key.into(), iv.into())) + } +} + +#[cfg(feature = "crypto-dependencies")] +impl SeedStream for SeedStreamAes128 { + fn fill(&mut self, buf: &mut [u8]) { + buf.fill(0); + self.0.apply_keystream(buf); + } +} + +#[cfg(feature = "crypto-dependencies")] +impl Debug for SeedStreamAes128 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that + // with [`cipher::StreamCipherCoreWrapper::get_core`][2]. + // + // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html + // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html + self.0.get_core().fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{field::Field128, prng::Prng}; + use serde::{Deserialize, Serialize}; + use std::convert::TryInto; + + #[derive(Deserialize, Serialize)] + struct PrgTestVector { + #[serde(with = "hex")] + seed: Vec<u8>, + #[serde(with = "hex")] + info: Vec<u8>, + length: usize, + #[serde(with = "hex")] + derived_seed: Vec<u8>, + #[serde(with = "hex")] + expanded_vec_field128: Vec<u8>, + } + + // Test correctness of dervied methods. + fn test_prg<P, const L: usize>() + where + P: Prg<L>, + { + let seed = Seed::generate().unwrap(); + let info = b"info string"; + + let mut prg = P::init(seed.as_ref()); + prg.update(info); + + let mut want = Seed([0; L]); + prg.clone().into_seed_stream().fill(&mut want.0[..]); + let got = prg.clone().into_seed(); + assert_eq!(got, want); + + let mut want = [0; 45]; + prg.clone().into_seed_stream().fill(&mut want); + let mut got = [0; 45]; + P::seed_stream(&seed, info).fill(&mut got); + assert_eq!(got, want); + } + + #[test] + fn prg_aes128() { + let t: PrgTestVector = + serde_json::from_str(include_str!("test_vec/03/PrgAes128.json")).unwrap(); + let mut prg = PrgAes128::init(&t.seed.try_into().unwrap()); + prg.update(&t.info); + + assert_eq!( + prg.clone().into_seed(), + Seed(t.derived_seed.try_into().unwrap()) + ); + + let mut bytes = std::io::Cursor::new(t.expanded_vec_field128.as_slice()); + let mut want = Vec::with_capacity(t.length); + while (bytes.position() as usize) < t.expanded_vec_field128.len() { + want.push(Field128::decode(&mut bytes).unwrap()) + } + let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream()) + .take(t.length) + .collect(); + assert_eq!(got, want); + + test_prg::<PrgAes128, 16>(); + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs new file mode 100644 index 0000000000..47fc076790 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2.rs @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with +//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server). + +use crate::{ + client as v2_client, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{FieldElement, FieldPrio2}, + prng::Prng, + server as v2_server, + util::proof_length, + vdaf::{ + prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, + PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError, + }, +}; +use ring::hmac; +use std::{ + convert::{TryFrom, TryInto}, + io::Cursor, +}; + +/// The Prio2 VDAF. It supports the same measurement type as +/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system +/// and finite field deployed in ENPA. +#[derive(Clone, Debug)] +pub struct Prio2 { + input_len: usize, +} + +impl Prio2 { + /// Returns an instance of the VDAF for the given input length. + pub fn new(input_len: usize) -> Result<Self, VdafError> { + let n = (input_len + 1).next_power_of_two(); + if let Ok(size) = u32::try_from(2 * n) { + if size > FieldPrio2::generator_order() { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + } else { + return Err(VdafError::Uncategorized( + "input size exceeds memory capacity".into(), + )); + } + + Ok(Prio2 { input_len }) + } + + /// Prepare an input share for aggregation using the given field element `query_rand` to + /// compute the verifier share. + /// + /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the + /// query randomness is computed jointly by the Aggregators. This method is designed to be used + /// in applications, like ENPA, in which the query randomness is instead chosen by a + /// third-party. + pub fn prepare_init_with_query_rand( + &self, + query_rand: FieldPrio2, + input_share: &Share<FieldPrio2, 32>, + is_leader: bool, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let expanded_data: Option<Vec<FieldPrio2>> = match input_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + Some(prng.take(proof_length(self.input_len)).collect()) + } + }; + let data = match input_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_data.as_ref().unwrap(), + }; + + let mut mem = v2_server::ValidationMemory::new(self.input_len); + let verifier_share = v2_server::generate_verification_message( + self.input_len, + query_rand, + data, // Combined input and proof shares + is_leader, + &mut mem, + ) + .map_err(|e| VdafError::Uncategorized(e.to_string()))?; + + Ok(( + Prio2PrepareState(input_share.truncated(self.input_len)), + Prio2PrepareShare(verifier_share), + )) + } +} + +impl Vdaf for Prio2 { + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<u32>; + type AggregateResult = Vec<u32>; + type AggregationParam = (); + type PublicShare = (); + type InputShare = Share<FieldPrio2, 32>; + type OutputShare = OutputShare<FieldPrio2>; + type AggregateShare = AggregateShare<FieldPrio2>; + + fn num_aggregators(&self) -> usize { + // Prio2 can easily be extended to support more than two Aggregators. + 2 + } +} + +impl Client for Prio2 { + fn shard(&self, measurement: &Vec<u32>) -> Result<((), Vec<Share<FieldPrio2, 32>>), VdafError> { + if measurement.len() != self.input_len { + return Err(VdafError::Uncategorized("incorrect input length".into())); + } + let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len()); + for int in measurement { + input.push((*int).into()); + } + + let mut mem = v2_client::ClientMemory::new(self.input_len)?; + let copy_data = |share_data: &mut [FieldPrio2]| { + share_data[..].clone_from_slice(&input); + }; + let mut leader_data = mem.prove_with(self.input_len, copy_data); + + let helper_seed = Seed::generate()?; + let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref()); + for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) { + *s1 -= d; + } + + Ok(( + (), + vec![Share::Leader(leader_data), Share::Helper(helper_seed)], + )) + } +} + +/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio2PrepareState(Share<FieldPrio2, 32>); + +impl Encode for Prio2PrepareState { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let share_decoder = if *agg_id == 0 { + ShareDecodingParameter::Leader(prio2.input_len) + } else { + ShareDecodingParameter::Helper + }; + let out_share = Share::decode_with_param(&share_decoder, bytes)?; + Ok(Self(out_share)) + } +} + +/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug)] +pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>); + +impl Encode for Prio2PrepareShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.f_r.encode(bytes); + self.0.g_r.encode(bytes); + self.0.h_r.encode(bytes); + } +} + +impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare { + fn decode_with_param( + _state: &Prio2PrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Ok(Self(v2_server::VerificationMessage { + f_r: FieldPrio2::decode(bytes)?, + g_r: FieldPrio2::decode(bytes)?, + h_r: FieldPrio2::decode(bytes)?, + })) + } +} + +impl Aggregator<32> for Prio2 { + type PrepareState = Prio2PrepareState; + type PrepareShare = Prio2PrepareShare; + type PrepareMessage = (); + + fn prepare_init( + &self, + agg_key: &[u8; 32], + agg_id: usize, + _agg_param: &(), + nonce: &[u8], + _public_share: &Self::PublicShare, + input_share: &Share<FieldPrio2, 32>, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let is_leader = role_try_from(agg_id)?; + + // In the ENPA Prio system, the query randomness is generated by a third party and + // distributed to the Aggregators after they receive their input shares. In a VDAF, shared + // randomness is derived from a nonce selected by the client. For Prio2 we compute the + // query using HMAC-SHA256 evaluated over the nonce. + let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key); + let hmac_tag = hmac::sign(&hmac_key, nonce); + let query_rand = Prng::from_prio2_seed(hmac_tag.as_ref().try_into().unwrap()) + .next() + .unwrap(); + + self.prepare_init_with_query_rand(query_rand, input_share, is_leader) + } + + fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>( + &self, + inputs: M, + ) -> Result<(), VdafError> { + let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> = + inputs.into_iter().map(|msg| msg.0).collect(); + if verifier_shares.len() != 2 { + return Err(VdafError::Uncategorized( + "wrong number of verifier shares".into(), + )); + } + + if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )); + } + + Ok(()) + } + + fn prepare_step( + &self, + state: Prio2PrepareState, + _input: (), + ) -> Result<PrepareTransition<Self, 32>, VdafError> { + let data = match state.0 { + Share::Leader(data) => data, + Share::Helper(seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + prng.take(self.input_len).collect() + } + }; + Ok(PrepareTransition::Finish(OutputShare::from(data))) + } + + fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>( + &self, + _agg_param: &(), + out_shares: M, + ) -> Result<AggregateShare<FieldPrio2>, VdafError> { + let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for out_share in out_shares.into_iter() { + agg_share.accumulate(&out_share)?; + } + + Ok(agg_share) + } +} + +impl Collector for Prio2 { + fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>( + &self, + _agg_param: &(), + agg_shares: M, + _num_measurements: usize, + ) -> Result<Vec<u32>, VdafError> { + let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(agg.0.into_iter().map(u32::from).collect()) + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; + let decoder = if is_leader { + ShareDecodingParameter::Leader(proof_length(prio2.input_len)) + } else { + ShareDecodingParameter::Helper + }; + + Share::decode_with_param(&decoder, bytes) + } +} + +fn role_try_from(agg_id: usize) -> Result<bool, VdafError> { + match agg_id { + 0 => Ok(true), + 1 => Ok(false), + _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + client::encode_simple, + encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey}, + field::random_vector, + server::Server, + vdaf::{run_vdaf, run_vdaf_prepare}, + }; + use rand::prelude::*; + + const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw=="; + const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w=="; + + #[test] + fn run_prio2() { + let prio2 = Prio2::new(6).unwrap(); + + assert_eq!( + run_vdaf( + &prio2, + &(), + [ + vec![0, 0, 0, 0, 1, 0], + vec![0, 1, 0, 0, 0, 0], + vec![0, 1, 1, 0, 0, 0], + vec![1, 1, 1, 0, 0, 0], + vec![0, 0, 0, 0, 1, 1], + ] + ) + .unwrap(), + vec![1, 3, 2, 0, 2, 1], + ); + } + + #[test] + fn enpa_client_interop() { + let mut rng = thread_rng(); + let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap(); + let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap(); + let pub_key1 = PublicKey::from(&priv_key1); + let pub_key2 = PublicKey::from(&priv_key2); + + let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0] + .iter() + .map(|x| FieldPrio2::from(*x)) + .collect(); + let (encrypted_input_share1, encrypted_input_share2) = + encode_simple(&data, pub_key1, pub_key2).unwrap(); + + let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap(); + let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap(); + + let prio2 = Prio2::new(data.len()).unwrap(); + let input_shares = vec![ + Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(), + Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(), + ]; + + let verify_key = rng.gen(); + let mut nonce = [0; 16]; + rng.fill(&mut nonce); + run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, (), input_shares).unwrap(); + } + + #[test] + fn enpa_server_interop() { + let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap(); + let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap(); + let pub_key1 = PublicKey::from(&priv_key1); + let pub_key2 = PublicKey::from(&priv_key2); + + let data = vec![0, 0, 1, 1, 0]; + let prio2 = Prio2::new(data.len()).unwrap(); + let (_public_share, input_shares) = prio2.shard(&data).unwrap(); + + let encrypted_input_share1 = + encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap(); + let encrypted_input_share2 = + encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap(); + + let mut server1 = Server::new(data.len(), true, priv_key1).unwrap(); + let mut server2 = Server::new(data.len(), false, priv_key2).unwrap(); + + let eval_at: FieldPrio2 = random_vector(1).unwrap()[0]; + let verifier1 = server1 + .generate_verification_message(eval_at, &encrypted_input_share1) + .unwrap(); + let verifier2 = server2 + .generate_verification_message(eval_at, &encrypted_input_share2) + .unwrap(); + + server1 + .aggregate(&encrypted_input_share1, &verifier1, &verifier2) + .unwrap(); + server2 + .aggregate(&encrypted_input_share2, &verifier1, &verifier2) + .unwrap(); + } + + #[test] + fn prepare_state_serialization() { + let mut verify_key = [0; 32]; + thread_rng().fill(&mut verify_key[..]); + let data = vec![0, 0, 1, 1, 0]; + let prio2 = Prio2::new(data.len()).unwrap(); + let (public_share, input_shares) = prio2.shard(&data).unwrap(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (want, _msg) = prio2 + .prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share) + .unwrap(); + let got = + Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded()) + .expect("failed to decode prepare step"); + assert_eq!(got, want); + } + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs new file mode 100644 index 0000000000..31853f15ab --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3.rs @@ -0,0 +1,1168 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]]. +//! +//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone +//! significant security analysis. Use at your own risk. +//! +//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented +//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO +//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication +//! cost. +//! +//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into +//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of +//! which are instantiated here: +//! +//! - [`Prio3Aes128Count`] for aggregating a counter (*) +//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters +//! - [`Prio3Aes128Sum`] for copmputing the sum of integers (*) +//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*) +//! +//! Additional types can be constructed from [`Prio3`] as needed. +//! +//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]]. +//! +//! [BBCG+19]: https://ia.cr/2019/188 +//! [CGB17]: https://crypto.stanford.edu/prio/ +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +#[cfg(feature = "crypto-dependencies")] +use super::prg::PrgAes128; +use super::{DST_LEN, VERSION}; +use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; +use crate::field::FieldElement; +#[cfg(feature = "crypto-dependencies")] +use crate::field::{Field128, Field64}; +#[cfg(feature = "multithreaded")] +use crate::flp::gadgets::ParallelSumMultithreaded; +#[cfg(feature = "crypto-dependencies")] +use crate::flp::gadgets::{BlindPolyEval, ParallelSum}; +#[cfg(feature = "crypto-dependencies")] +use crate::flp::types::{Average, Count, CountVec, Histogram, Sum}; +use crate::flp::Type; +use crate::prng::Prng; +use crate::vdaf::prg::{Prg, RandSource, Seed}; +use crate::vdaf::{ + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, + Share, ShareDecodingParameter, Vdaf, VdafError, +}; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io::Cursor; +use std::iter::IntoIterator; +use std::marker::PhantomData; + +/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Count { + /// Construct an instance of Prio3Aes128Count with the given number of aggregators. + pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, Count::new()) + } +} + +/// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is +/// the element-wise sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128CountVec = + Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128CountVec { + /// Construct an instance of Prio3Aes1238CountVec with the given number of aggregators. `len` + /// defines the length of each measurement. + pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, CountVec::new(len)) + } +} + +/// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and +/// preparation time. Note that the improvement is only noticeable for very large input lengths, +/// e.g., 201 and up. (Your system's mileage may vary.) +#[cfg(feature = "multithreaded")] +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +pub type Prio3Aes128CountVecMultithreaded = Prio3< + CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>, + PrgAes128, + 16, +>; + +#[cfg(feature = "multithreaded")] +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +impl Prio3Aes128CountVecMultithreaded { + /// Construct an instance of Prio3Aes1238CountVecMultithreaded with the given number of + /// aggregators. `len` defines the length of each measurement. + pub fn new_aes128_count_vec_multithreaded( + num_aggregators: u8, + len: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, CountVec::new(len)) + } +} + +/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Sum { + /// Construct an instance of Prio3Aes128Sum with the given number of aggregators and required + /// bit length. The bit length must not exceed 64. + pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> { + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({}) exceeds limit for aggregate type (64)", + bits + ))); + } + + Prio3::new(num_aggregators, Sum::new(bits as usize)?) + } +} + +/// The histogram type. Each measurement is an unsigned integer and the result is a histogram +/// representation of the distribution. The bucket boundaries are fixed in advance. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Histogram { + /// Constructs an instance of Prio3Aes128Histogram with the given number of aggregators and + /// desired histogram bucket boundaries. + pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> { + let buckets = buckets.iter().map(|bucket| *bucket as u128).collect(); + + Prio3::new(num_aggregators, Histogram::new(buckets)?) + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and +/// the aggregate is the arithmetic average. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Average { + /// Construct an instance of Prio3Aes128Average with the given number of aggregators and + /// required bit length. The bit length must not exceed 64. + pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({}) exceeds limit for aggregate type (64)", + bits + ))); + } + + Ok(Prio3 { + num_aggregators, + typ: Average::new(bits as usize)?, + phantom: PhantomData, + }) + } +} + +/// The base type for Prio3. +/// +/// An instance of Prio3 is determined by: +/// +/// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and +/// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds. +/// +/// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an +/// alias for `Prio3<Count<Field64>, PrgAes128, 16>`. +/// +/// ``` +/// use prio::vdaf::{ +/// Aggregator, Client, Collector, PrepareTransition, +/// prio3::Prio3, +/// }; +/// use rand::prelude::*; +/// +/// let num_shares = 2; +/// let vdaf = Prio3::new_aes128_count(num_shares).unwrap(); +/// +/// let mut out_shares = vec![vec![]; num_shares.into()]; +/// let mut rng = thread_rng(); +/// let verify_key = rng.gen(); +/// let measurements = [0, 1, 1, 1, 0]; +/// for measurement in measurements { +/// // Shard +/// let (public_share, input_shares) = vdaf.shard(&measurement).unwrap(); +/// let mut nonce = [0; 16]; +/// rng.fill(&mut nonce); +/// +/// // Prepare +/// let mut prep_states = vec![]; +/// let mut prep_shares = vec![]; +/// for (agg_id, input_share) in input_shares.iter().enumerate() { +/// let (state, share) = vdaf.prepare_init( +/// &verify_key, +/// agg_id, +/// &(), +/// &nonce, +/// &public_share, +/// input_share +/// ).unwrap(); +/// prep_states.push(state); +/// prep_shares.push(share); +/// } +/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); +/// +/// for (agg_id, state) in prep_states.into_iter().enumerate() { +/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { +/// PrepareTransition::Finish(out_share) => out_share, +/// _ => panic!("unexpected transition"), +/// }; +/// out_shares[agg_id].push(out_share); +/// } +/// } +/// +/// // Aggregate +/// let agg_shares = out_shares.into_iter() +/// .map(|o| vdaf.aggregate(&(), o).unwrap()); +/// +/// // Unshard +/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap(); +/// assert_eq!(agg_res, 3); +/// ``` +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +#[derive(Clone, Debug)] +pub struct Prio3<T, P, const L: usize> +where + T: Type, + P: Prg<L>, +{ + num_aggregators: u8, + typ: T, + phantom: PhantomData<P>, +} + +impl<T, P, const L: usize> Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the + /// underlying type. + pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + Ok(Self { + num_aggregators, + typ, + phantom: PhantomData, + }) + } + + /// The output length of the underlying FLP. + pub fn output_len(&self) -> usize { + self.typ.output_len() + } + + /// The verifier length of the underlying FLP. + pub fn verifier_len(&self) -> usize { + self.typ.verifier_len() + } + + fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> { + let mut info = [0; VERSION.len() + 5]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes()); + info[VERSION.len() + 4] = 255; + let mut deriver = P::init(&[0; L]); + deriver.update(&info); + for part in parts { + deriver.update(part.as_ref()); + } + deriver.into_seed() + } + + fn shard_with_rand_source( + &self, + measurement: &T::Measurement, + rand_source: RandSource, + ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> { + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + + let num_aggregators = self.num_aggregators; + let input = self.typ.encode_measurement(measurement)?; + + // Generate the input shares and compute the joint randomness. + let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1); + let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 { + Some(Vec::with_capacity(num_aggregators as usize - 1)) + } else { + None + }; + let mut leader_input_share = input.clone(); + for agg_id in 1..num_aggregators { + let helper = HelperShare::from_rand_source(rand_source)?; + + let mut deriver = P::init(helper.joint_rand_param.blind.as_ref()); + info[DST_LEN] = agg_id; + deriver.update(&info); + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info)); + for (x, y) in leader_input_share + .iter_mut() + .zip(prng) + .take(self.typ.input_len()) + { + *x -= y; + deriver.update(&y.into()); + } + + if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() { + helper_joint_rand_parts.push(deriver.into_seed()); + } + helper_shares.push(helper); + } + + let leader_blind = Seed::from_rand_source(rand_source)?; + + info[DST_LEN] = 0; // ID of the leader + let mut deriver = P::init(leader_blind.as_ref()); + deriver.update(&info); + for x in leader_input_share.iter() { + deriver.update(&(*x).into()); + } + + let leader_joint_rand_seed_part = deriver.into_seed(); + + // Compute the joint randomness seed. + let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| { + Self::derive_joint_randomness( + std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()), + ) + }); + + // Run the proof-generation algorithm. + let domain_separation_tag = &info[..DST_LEN]; + let joint_rand: Vec<T::Field> = joint_rand_seed + .map(|joint_rand_seed| { + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); + prng.take(self.typ.joint_rand_len()).collect() + }) + .unwrap_or_default(); + let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream( + &Seed::from_rand_source(rand_source)?, + domain_separation_tag, + )); + let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect(); + let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?; + + // Generate the proof shares and distribute the joint randomness seed hints. + for (j, helper) in helper_shares.iter_mut().enumerate() { + info[DST_LEN] = j as u8 + 1; + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info)); + for (x, y) in leader_proof_share + .iter_mut() + .zip(prng) + .take(self.typ.proof_len()) + { + *x -= y; + } + + if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() { + let mut hint = Vec::with_capacity(num_aggregators as usize - 1); + hint.push(leader_joint_rand_seed_part.clone()); + hint.extend(helper_joint_rand_parts[..j].iter().cloned()); + hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned()); + helper.joint_rand_param.seed_hint = hint; + } + } + + let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 { + Some(JointRandParam { + seed_hint: helper_joint_rand_parts.unwrap_or_default(), + blind: leader_blind, + }) + } else { + None + }; + + // Prep the output messages. + let mut out = Vec::with_capacity(num_aggregators as usize); + out.push(Prio3InputShare { + input_share: Share::Leader(leader_input_share), + proof_share: Share::Leader(leader_proof_share), + joint_rand_param: leader_joint_rand_param, + }); + + for helper in helper_shares.into_iter() { + let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 { + Some(helper.joint_rand_param) + } else { + None + }; + + out.push(Prio3InputShare { + input_share: Share::Helper(helper.input_share), + proof_share: Share::Helper(helper.proof_share), + joint_rand_param: helper_joint_rand_param, + }); + } + + Ok(out) + } + + /// Shard measurement with constant randomness of repeated bytes. + /// This method is not secure. It is used for running test vectors for Prio3. + #[cfg(feature = "test-util")] + pub fn test_vec_shard( + &self, + measurement: &T::Measurement, + ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> { + self.shard_with_rand_source(measurement, |buf| { + buf.fill(1); + Ok(()) + }) + } + + fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> { + if agg_id >= self.num_aggregators as usize { + return Err(VdafError::Uncategorized("unexpected aggregator id".into())); + } + Ok(u8::try_from(agg_id).unwrap()) + } +} + +impl<T, P, const L: usize> Vdaf for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + const ID: u32 = T::ID; + type Measurement = T::Measurement; + type AggregateResult = T::AggregateResult; + type AggregationParam = (); + type PublicShare = (); + type InputShare = Prio3InputShare<T::Field, L>; + type OutputShare = OutputShare<T::Field>; + type AggregateShare = AggregateShare<T::Field>; + + fn num_aggregators(&self) -> usize { + self.num_aggregators as usize + } +} + +/// Message sent by the [`Client`](crate::vdaf::Client) to each +/// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio3InputShare<F, const L: usize> { + /// The input share. + input_share: Share<F, L>, + + /// The proof share. + proof_share: Share<F, L>, + + /// Parameters used by the Aggregator to compute the joint randomness. This field is optional + /// because not every [`Type`](`crate::flp::Type`) requires joint randomness. + joint_rand_param: Option<JointRandParam<L>>, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + if matches!( + (&self.input_share, &self.proof_share), + (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) + ) { + panic!("tried to encode input share with ambiguous encoding") + } + + self.input_share.encode(bytes); + self.proof_share.encode(bytes); + if let Some(ref param) = self.joint_rand_param { + param.blind.encode(bytes); + for part in param.seed_hint.iter() { + part.encode(bytes); + } + } + } +} + +impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)> + for Prio3InputShare<T::Field, L> +where + T: Type, + P: Prg<L>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + let (input_decoder, proof_decoder) = if agg_id == 0 { + ( + ShareDecodingParameter::Leader(prio3.typ.input_len()), + ShareDecodingParameter::Leader(prio3.typ.proof_len()), + ) + } else { + ( + ShareDecodingParameter::Helper, + ShareDecodingParameter::Helper, + ) + }; + + let input_share = Share::decode_with_param(&input_decoder, bytes)?; + let proof_share = Share::decode_with_param(&proof_decoder, bytes)?; + let joint_rand_param = if prio3.typ.joint_rand_len() > 0 { + let num_aggregators = prio3.num_aggregators(); + let blind = Seed::decode(bytes)?; + let seed_hint = std::iter::repeat_with(|| Seed::decode(bytes)) + .take(num_aggregators - 1) + .collect::<Result<Vec<_>, _>>()?; + Some(JointRandParam { blind, seed_hint }) + } else { + None + }; + + Ok(Prio3InputShare { + input_share, + proof_share, + joint_rand_param, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +/// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the +/// Preparation phase. +pub struct Prio3PrepareShare<F, const L: usize> { + /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).) + verifier: Vec<F>, + + /// A part of the joint randomness seed. + joint_rand_part: Option<Seed<L>>, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + for x in &self.verifier { + x.encode(bytes); + } + if let Some(ref seed) = self.joint_rand_part { + seed.encode(bytes); + } + } +} + +impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>> + for Prio3PrepareShare<F, L> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, L>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len); + for _ in 0..decoding_parameter.verifier_len { + verifier.push(F::decode(bytes)?); + } + + let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareShare { + verifier, + joint_rand_part, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +/// Result of combining a round of [`Prio3PrepareShare`] messages. +pub struct Prio3PrepareMessage<const L: usize> { + /// The joint randomness seed computed by the Aggregators. + joint_rand_seed: Option<Seed<L>>, +} + +impl<const L: usize> Encode for Prio3PrepareMessage<L> { + fn encode(&self, bytes: &mut Vec<u8>) { + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } +} + +impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>> + for Prio3PrepareMessage<L> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, L>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } +} + +impl<T, P, const L: usize> Client for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + #[allow(clippy::type_complexity)] + fn shard( + &self, + measurement: &T::Measurement, + ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> { + self.shard_with_rand_source(measurement, getrandom::getrandom) + .map(|input_shares| ((), input_shares)) + } +} + +/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio3PrepareState<F, const L: usize> { + input_share: Share<F, L>, + joint_rand_seed: Option<Seed<L>>, + agg_id: u8, + verifier_len: usize, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + fn encode(&self, bytes: &mut Vec<u8>) { + self.input_share.encode(bytes); + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } +} + +impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)> + for Prio3PrepareState<T::Field, L> +where + T: Type, + P: Prg<L>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + + let share_decoder = if agg_id == 0 { + ShareDecodingParameter::Leader(prio3.typ.input_len()) + } else { + ShareDecodingParameter::Helper + }; + let input_share = Share::decode_with_param(&share_decoder, bytes)?; + + let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Self { + input_share, + joint_rand_seed, + agg_id, + verifier_len: prio3.typ.verifier_len(), + }) + } +} + +impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + type PrepareState = Prio3PrepareState<T::Field, L>; + type PrepareShare = Prio3PrepareShare<T::Field, L>; + type PrepareMessage = Prio3PrepareMessage<L>; + + /// Begins the Prep process with the other aggregators. The result of this process is + /// the aggregator's output share. + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; L], + agg_id: usize, + _agg_param: &(), + nonce: &[u8], + _public_share: &(), + msg: &Prio3InputShare<T::Field, L>, + ) -> Result< + ( + Prio3PrepareState<T::Field, L>, + Prio3PrepareShare<T::Field, L>, + ), + VdafError, + > { + let agg_id = self.role_try_from(agg_id)?; + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + info[DST_LEN] = agg_id; + let domain_separation_tag = &info[..DST_LEN]; + + let mut deriver = P::init(verify_key); + deriver.update(domain_separation_tag); + deriver.update(&[255]); + deriver.update(nonce); + let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream()); + + // Create a reference to the (expanded) input share. + let expanded_input_share: Option<Vec<T::Field>> = match msg.input_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); + Some(prng.take(self.typ.input_len()).collect()) + } + }; + let input_share = match msg.input_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_input_share.as_ref().unwrap(), + }; + + // Create a reference to the (expanded) proof share. + let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); + Some(prng.take(self.typ.proof_len()).collect()) + } + }; + let proof_share = match msg.proof_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_proof_share.as_ref().unwrap(), + }; + + // Compute the joint randomness. + let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 { + let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref()); + deriver.update(&info); + for x in input_share { + deriver.update(&(*x).into()); + } + let joint_rand_seed_part = deriver.into_seed(); + + let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint; + let joint_rand_seed = Self::derive_joint_randomness( + hints[..agg_id as usize] + .iter() + .chain(std::iter::once(&joint_rand_seed_part)) + .chain(hints[agg_id as usize..].iter()), + ); + + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); + ( + Some(joint_rand_seed), + Some(joint_rand_seed_part), + prng.take(self.typ.joint_rand_len()).collect(), + ) + } else { + (None, None, Vec::new()) + }; + + // Compute the query randomness. + let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect(); + + // Run the query-generation algorithm. + let verifier_share = self.typ.query( + input_share, + proof_share, + &query_rand, + &joint_rand, + self.num_aggregators as usize, + )?; + + Ok(( + Prio3PrepareState { + input_share: msg.input_share.clone(), + joint_rand_seed, + agg_id, + verifier_len: verifier_share.len(), + }, + Prio3PrepareShare { + verifier: verifier_share, + joint_rand_part: joint_rand_seed_part, + }, + )) + } + + fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>( + &self, + inputs: M, + ) -> Result<Prio3PrepareMessage<L>, VdafError> { + let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; + let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators()); + let mut count = 0; + for share in inputs.into_iter() { + count += 1; + + if share.verifier.len() != verifier.len() { + return Err(VdafError::Uncategorized(format!( + "unexpected verifier share length: got {}; want {}", + share.verifier.len(), + verifier.len(), + ))); + } + + if self.typ.joint_rand_len() > 0 { + let joint_rand_seed_part = share.joint_rand_part.unwrap(); + joint_rand_parts.push(joint_rand_seed_part); + } + + for (x, y) in verifier.iter_mut().zip(share.verifier) { + *x += y; + } + } + + if count != self.num_aggregators { + return Err(VdafError::Uncategorized(format!( + "unexpected message count: got {}; want {}", + count, self.num_aggregators, + ))); + } + + // Check the proof verifier. + match self.typ.decide(&verifier) { + Ok(true) => (), + Ok(false) => { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )) + } + Err(err) => return Err(VdafError::from(err)), + }; + + let joint_rand_seed = if self.typ.joint_rand_len() > 0 { + Some(Self::derive_joint_randomness(joint_rand_parts.iter())) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } + + fn prepare_step( + &self, + step: Prio3PrepareState<T::Field, L>, + msg: Prio3PrepareMessage<L>, + ) -> Result<PrepareTransition<Self, L>, VdafError> { + if self.typ.joint_rand_len() > 0 { + // Check that the joint randomness was correct. + if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() { + return Err(VdafError::Uncategorized( + "joint randomness mismatch".to_string(), + )); + } + } + + // Compute the output share. + let input_share = match step.input_share { + Share::Leader(data) => data, + Share::Helper(seed) => { + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + info[DST_LEN] = step.agg_id; + let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info)); + prng.take(self.typ.input_len()).collect() + } + }; + + let output_share = match self.typ.truncate(input_share) { + Ok(data) => OutputShare(data), + Err(err) => { + return Err(VdafError::from(err)); + } + }; + + Ok(PrepareTransition::Finish(output_share)) + } + + /// Aggregates a sequence of output shares into an aggregate share. + fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>( + &self, + _agg_param: &(), + output_shares: It, + ) -> Result<AggregateShare<T::Field>, VdafError> { + let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + + Ok(agg_share) + } +} + +impl<T, P, const L: usize> Collector for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + /// Combines aggregate shares into the aggregate result. + fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>( + &self, + _agg_param: &(), + agg_shares: It, + num_measurements: usize, + ) -> Result<T::AggregateResult, VdafError> { + let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(self.typ.decode_result(&agg.0, num_measurements)?) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +struct JointRandParam<const L: usize> { + /// The joint randomness seed parts corresponding to the other Aggregators' shares. + seed_hint: Vec<Seed<L>>, + + /// The blinding factor, used to derive the aggregator's joint randomness seed part. + blind: Seed<L>, +} + +#[derive(Clone)] +struct HelperShare<const L: usize> { + input_share: Seed<L>, + proof_share: Seed<L>, + joint_rand_param: JointRandParam<L>, +} + +impl<const L: usize> HelperShare<L> { + fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> { + Ok(HelperShare { + input_share: Seed::from_rand_source(rand_source)?, + proof_share: Seed::from_rand_source(rand_source)?, + joint_rand_param: JointRandParam { + seed_hint: Vec::new(), + blind: Seed::from_rand_source(rand_source)?, + }, + }) + } +} + +fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> { + if num_aggregators == 0 { + return Err(VdafError::Uncategorized(format!( + "at least one aggregator is required; got {}", + num_aggregators + ))); + } else if num_aggregators > 254 { + return Err(VdafError::Uncategorized(format!( + "number of aggregators must not exceed 254; got {}", + num_aggregators + ))); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vdaf::{run_vdaf, run_vdaf_prepare}; + use assert_matches::assert_matches; + use rand::prelude::*; + + #[test] + fn test_prio3_count() { + let prio3 = Prio3::new_aes128_count(2).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"This is a good nonce."; + + let (public_share, input_shares) = prio3.shard(&0).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap(); + + let (public_share, input_shares) = prio3.shard(&1).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap(); + + test_prepare_state_serialization(&prio3, &1).unwrap(); + + let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap(); + assert_eq!( + run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(), + 3, + ); + } + + #[test] + fn test_prio3_sum() { + let prio3 = Prio3::new_aes128_sum(3, 16).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), + (1 << 16) + 1 + ); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"This is a good nonce."; + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255; + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255; + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + test_prepare_state_serialization(&prio3, &1).unwrap(); + } + + #[test] + fn test_prio3_countvec() { + let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [vec![ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, + ]] + ) + .unwrap(), + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] + ); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_prio3_countvec_multithreaded() { + let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [vec![ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, + ]] + ) + .unwrap(), + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] + ); + } + + #[test] + fn test_prio3_histogram() { + let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(), + vec![1, 1, 1, 1] + ); + assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]); + test_prepare_state_serialization(&prio3, &23).unwrap(); + } + + #[test] + fn test_prio3_average() { + let prio3 = Prio3::new_aes128_average(2, 64).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); + assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); + assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); + assert_eq!( + run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), + 207.5f64 + ); + } + + #[test] + fn test_prio3_input_share() { + let prio3 = Prio3::new_aes128_sum(5, 16).unwrap(); + let (_public_share, input_shares) = prio3.shard(&1).unwrap(); + + // Check that seed shares are distinct. + for (i, x) in input_shares.iter().enumerate() { + for (j, y) in input_shares.iter().enumerate() { + if i != j { + if let (Share::Helper(left), Share::Helper(right)) = + (&x.input_share, &y.input_share) + { + assert_ne!(left, right); + } + + if let (Share::Helper(left), Share::Helper(right)) = + (&x.proof_share, &y.proof_share) + { + assert_ne!(left, right); + } + + assert_ne!(x.joint_rand_param, y.joint_rand_param); + } + } + } + } + + fn test_prepare_state_serialization<T, P, const L: usize>( + prio3: &Prio3<T, P, L>, + measurement: &T::Measurement, + ) -> Result<(), VdafError> + where + T: Type, + P: Prg<L>, + { + let mut verify_key = [0; L]; + thread_rng().fill(&mut verify_key[..]); + let (public_share, input_shares) = prio3.shard(measurement)?; + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (want, _msg) = + prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?; + let got = + Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded()) + .expect("failed to decode prepare step"); + assert_eq!(got, want); + } + Ok(()) + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs new file mode 100644 index 0000000000..d4c9151ce0 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3_test.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{ + codec::{Encode, ParameterizedDecode}, + flp::Type, + vdaf::{ + prg::Prg, + prio3::{Prio3, Prio3InputShare, Prio3PrepareShare}, + Aggregator, PrepareTransition, + }, +}; +use serde::{Deserialize, Serialize}; +use std::{convert::TryInto, fmt::Debug}; + +#[derive(Debug, Deserialize, Serialize)] +struct TEncoded(#[serde(with = "hex")] Vec<u8>); + +impl AsRef<[u8]> for TEncoded { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Deserialize, Serialize)] +struct TPrio3Prep<M> { + measurement: M, + #[serde(with = "hex")] + nonce: Vec<u8>, + input_shares: Vec<TEncoded>, + prep_shares: Vec<Vec<TEncoded>>, + prep_messages: Vec<TEncoded>, + out_shares: Vec<Vec<M>>, +} + +#[derive(Deserialize, Serialize)] +struct TPrio3<M> { + verify_key: TEncoded, + prep: Vec<TPrio3Prep<M>>, +} + +macro_rules! err { + ( + $test_num:ident, + $error:expr, + $msg:expr + ) => { + panic!("test #{} failed: {} err: {}", $test_num, $msg, $error) + }; +} + +// TODO Generalize this method to work with any VDAF. To do so we would need to add +// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.) +fn check_prep_test_vec<M, T, P, const L: usize>( + prio3: &Prio3<T, P, L>, + verify_key: &[u8; L], + test_num: usize, + t: &TPrio3Prep<M>, +) where + T: Type<Measurement = M>, + P: Prg<L>, + M: From<<T as Type>::Field> + Debug + PartialEq, +{ + let input_shares = prio3 + .test_vec_shard(&t.measurement) + .expect("failed to generate input shares"); + + assert_eq!(2, t.input_shares.len(), "#{}", test_num); + for (agg_id, want) in t.input_shares.iter().enumerate() { + assert_eq!( + input_shares[agg_id], + Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")), + "#{}", + test_num + ); + assert_eq!( + input_shares[agg_id].get_encoded(), + want.as_ref(), + "#{}", + test_num + ) + } + + let mut states = Vec::new(); + let mut prep_shares = Vec::new(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (state, prep_share) = prio3 + .prepare_init(verify_key, agg_id, &(), &t.nonce, &(), input_share) + .unwrap_or_else(|e| err!(test_num, e, "prep state init")); + states.push(state); + prep_shares.push(prep_share); + } + + assert_eq!(1, t.prep_shares.len(), "#{}", test_num); + for (i, want) in t.prep_shares[0].iter().enumerate() { + assert_eq!( + prep_shares[i], + Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")), + "#{}", + test_num + ); + assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{}", test_num); + } + + let inbound = prio3 + .prepare_preprocess(prep_shares) + .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); + assert_eq!(t.prep_messages.len(), 1); + assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref()); + + let mut out_shares = Vec::new(); + for state in states.iter_mut() { + match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() { + PrepareTransition::Finish(out_share) => { + out_shares.push(out_share); + } + _ => panic!("unexpected transition"), + } + } + + for (got, want) in out_shares.iter().zip(t.out_shares.iter()) { + let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect(); + assert_eq!(&got, want); + } +} + +#[test] +fn test_vec_prio3_count() { + let t: TPrio3<u64> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Count_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_count(2).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} + +#[test] +fn test_vec_prio3_sum() { + let t: TPrio3<u128> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Sum_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_sum(2, 8).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} + +#[test] +fn test_vec_prio3_histogram() { + let t: TPrio3<u128> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Histogram_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json new file mode 100644 index 0000000000..e450665173 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json @@ -0,0 +1,7 @@ +{ + "derived_seed": "ccf3be704c982182ad2961e9795a88aa", + "expanded_vec_field128": "ccf3be704c982182ad2961e9795a88aa8df71c0b5ea5c13bcf3173c3f3626505e1bf4738874d5405805082cc38c55d1f04f85fbb88b8cf8592ffed8a4ac7f76991c58d850a15e8deb34fb289ab6fab584554ffef16c683228db2b76e792ca4f3c15760044d0703b438c2aefd7975c5dd4b9992ee6f87f20e570572dea18fa580ee17204903c1234f1332d47a442ea636580518ce7aa5943c415117460a049bc19cc81edbb0114d71890cbdbe4ea2664cd038e57b88fb7fd3557830ad363c20b9840d35fd6bee6c3c8424f026ee7fbca3daf3c396a4d6736d7bd3b65b2c228d22a40f4404e47c61b26ac3c88bebf2f268fa972f8831f18bee374a22af0f8bb94d9331a1584bdf8cf3e8a5318b546efee8acd28f6cba8b21b9d52acbae8e726500340da98d643d0a5f1270ecb94c574130cea61224b0bc6d438b2f4f74152e72b37e6a9541c9dc5515f8f98fd0d1bce8743f033ab3e8574180ffc3363f3a0490f6f9583bf73a87b9bb4b51bfd0ef260637a4288c37a491c6cbdc46b6a86cd26edf611793236e912e7227bfb85b560308b06238bbd978f72ed4a58583cf0c6e134066eb6b399ad2f26fa01d69a62d8a2d04b4b8acf82299b07a834d4c2f48fee23a24c20307f9cabcd34b6d69f1969588ebde777e46e9522e866e6dd1e14119a1cb4c0709fa9ea347d9f872e76a39313e7d49bfbf3e5ce807183f43271ba2b5c6aaeaef22da301327c1fd9fedde7c5a68d9b97fa6eb687ec8ca692cb0f631f46e6699a211a1254026c9a0a43eceb450dc97cfa923321baf1f4b6f233260d46182b844dccec153aaddd20f920e9e13ff11434bcd2aa632bf4f544f41b5ddced962939676476f70e0b8640c3471fc7af62d80053781295b070388f7b7f1fa66220cb3", + "info": "696e666f20737472696e67", + "length": 40, + "seed": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json new file mode 100644 index 0000000000..9e79888745 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json @@ -0,0 +1,37 @@ +{ + "agg_param": null, + "agg_result": 1, + "agg_shares": [ + "ad8bb894e3222b47", + "5274476a1cddd4bb" + ], + "prep": [ + { + "input_shares": [ + "ad8bb894e3222b47b70eb67d4f70cb78644826d67d31129e422b910cf0aab70c0b78fa57b4a7b3aaafae57bd1012e813", + "0101010101010101010101010101010101010101010101010101010101010101" + ], + "measurement": 1, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 12505291739929652039 + ], + [ + 5941452329484932283 + ] + ], + "prep_messages": [ + "" + ], + "prep_shares": [ + [ + "38d535dd68f3c02ed6681f7ff24239d46fde93c8402d24ebbafa25c77ca3535d", + "c72aca21970c3fd35274476a1cddd4bb4efa24ee0d71473e4a0a23713a347d78" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json new file mode 100644 index 0000000000..f5476455fa --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json @@ -0,0 +1,53 @@ +{ + "agg_param": null, + "agg_result": [ + 0, + 0, + 1, + 0 + ], + "agg_shares": [ + "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7", + "11ef893e143d2b59aa858efce43622a5632a16ee7f444ac4b53c996b9434056e46f786d42ea2bfb4b53725d9b1db5df39ba1097b8de7f38b6b4538fb51f98a2a" + ], + "buckets": [ + 1, + 10, + 100 + ], + "prep": [ + { + "input_shares": [ + "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7f16776df4f93852a40b514593a73be51ad64d8c28322a47af92c92223dd489998a3c6687861cdc2e4d834885d03d8d3273af0bf742c47985ae8fec6d16c31216792bb0cdca0d1d1fa2287414cd069f8caa42dc08f78dd43e14c4095e2ef9d9609937caebcd534e813136e79a4233e873397a6c7fd164928d43673b32e061139dc6650152d8433e2342f595149418929b74c9e23f1469ed1eebdaa57d0b5c62f90cb5a53dc68c8e030448bb2d9c07aeed50d82c93e1afe8febd68918933ed9b2dd36b9d8a35fd6c57cd76707011fca77526437aeb8392a2013f829c1e395f7f8ddef030f5bc869833f528ae2137a2e667aa648d8643f6c13e8d76e8832ab9ef7d0101010101010101010101010101010194c3f0f1061c8f440b51f806ad822510", + "0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101016195ec204fd5d65c14fac36b73723cde" + ], + "measurement": 50, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 316441748434879643753815489063091297628, + 208470253761472213750543248431791209107, + 245951175238245845331446316072865931791, + 133415875449384174923011884997795018199 + ], + [ + 23840618486058819193050284304809468581, + 131812113159466249196322524936109557102, + 94331191682692617615419457295034834419, + 206866491471554288023853888370105748010 + ] + ], + "prep_messages": [ + "7912f1157c2ce3a4dca6456224aeaeea" + ], + "prep_shares": [ + [ + "f2dc9e823b867d760b2169644633804eabec10e5869fe8f3030c5da6dc0fce03a433572cb8aaa7ca3559959f7bad68306195ec204fd5d65c14fac36b73723cde", + "0d23617dc479826df4de969bb9cc7fb3f5e934542e987db0271aee33551b28a4c16f7ad00127c43df9c433a1c224594d94c3f0f1061c8f440b51f806ad822510" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json new file mode 100644 index 0000000000..55d2a000db --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json @@ -0,0 +1,38 @@ +{ + "agg_param": null, + "agg_result": 100, + "agg_shares": [ + "ab3bcc5ef693737a7a3e76cd9face3e0", + "54c433a1096c8c6985c1893260531c85" + ], + "bits": 8, + "prep": [ + { + "input_shares": [ + "fc1e42a024e3d8b45f63f485ebe2dc8a356c5e5780a0bd751184e6a02a96c0767f518e87282ebdc039590aef02e40e5492c9eb69dd22b6b4f1d630e7ca8612b7a7e090b39460bc4036345f5ef537d691fd585bc05a2ea580c7e354680afd0fd49f3d083d5e383b97755a842cf5e69870a970b14a10595c0c639ad2e7bda42c7146c4b69fd79e7403d89dac5816d0dc6f2bb987fccca4c4aee64444b7f46431433c59c6e7f2839fe2b7ad9316d31a52dcc0df07f1da14aa38e0cd88de380fda29b33704e8c3439376762739aa5b5cff9e925939773d24ca0e75bcf87149c9bcc2f8462afa6b50513ab003ac00c9ae3685ea52bdee3c814ffd5afc8357d93454b3ffaf0b5e9fd351730f0d55aed54a9cfa86f9119601ce9857ee0af3f579251bcc7ffe51b8393adc36ab6142eb0e0d07c9b2d5ab71d8d5639f32c61f7d59b45a95129cbc76d7e30c02a1329454f843553413d4e84bcab2c3ba1a0150292026dfa37488da5dd639c53edd51bf4eb5aa54d5b165fcd55d10f3496008f4e3b6d3eb200c19c5b9c42ad4f12977a857d02f787b14ced27fc5eefb05722b372a7d48c1891d30a32d84ec8d1f9a783a38bfac2793f0da6796cff90521e1d73f497f7d2c910b7fbbea2ba4b906d437a53bcbed16986f5646fd238e736f1c3e9d3a910218ce7f48dea3e9a1a848c580a1c506a80edb0c0a973a269667475ce88f4424674b14a3a8f2b71ef529d2ca96a3c5e4da384545749a55188d4de0074ad601695e934c9fe71d27c139b7678ead7f904cd2ae2a3aafa96d8211579e391507df96bf42c383f2ac71d7a558ebf1e3d5ab086b65422415bd24be9c979ca5b4f381d51b06ec4f6740b1a084999cd95fe63fec4a019f635640ba18d42312de7d1994947502b9010101010101010101010101010101015f0721f50826593dc3908dad39353846", + "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101094240ceae2d63ba1bdda997fa0bcbd8" + ], + "measurement": 100, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 227608477929192160221239678567201956832 + ], + [ + 112673888991746302725626094800698809477 + ] + ], + "prep_messages": [ + "60af733578d766f2305c1d53c840b4b5" + ], + "prep_shares": [ + [ + "0a85b5e51cacf514ee9e9bbe5d3ac023795e910b765411e5cea8ff187973640694bd740cc15bc9cad60bc85785206062094240ceae2d63ba1bdda997fa0bcbd8", + "f57a4a1ae3530acf11616441a2c53fde804d262dc42e15e556ee02c588c3ca9d924eefa735a95f6e420f2c5161706e025f0721f50826593dc3908dad39353846" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} |