summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio3.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/prio/src/vdaf/prio3.rs
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio3.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio3.rs1168
1 files changed, 1168 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs
new file mode 100644
index 0000000000..31853f15ab
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3.rs
@@ -0,0 +1,1168 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone
+//! significant security analysis. Use at your own risk.
+//!
+//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented
+//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
+//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
+//! cost.
+//!
+//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into
+//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
+//! which are instantiated here:
+//!
+//! - [`Prio3Aes128Count`] for aggregating a counter (*)
+//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters
+//! - [`Prio3Aes128Sum`] for copmputing the sum of integers (*)
+//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*)
+//!
+//! Additional types can be constructed from [`Prio3`] as needed.
+//!
+//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! [BBCG+19]: https://ia.cr/2019/188
+//! [CGB17]: https://crypto.stanford.edu/prio/
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+#[cfg(feature = "crypto-dependencies")]
+use super::prg::PrgAes128;
+use super::{DST_LEN, VERSION};
+use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
+use crate::field::FieldElement;
+#[cfg(feature = "crypto-dependencies")]
+use crate::field::{Field128, Field64};
+#[cfg(feature = "multithreaded")]
+use crate::flp::gadgets::ParallelSumMultithreaded;
+#[cfg(feature = "crypto-dependencies")]
+use crate::flp::gadgets::{BlindPolyEval, ParallelSum};
+#[cfg(feature = "crypto-dependencies")]
+use crate::flp::types::{Average, Count, CountVec, Histogram, Sum};
+use crate::flp::Type;
+use crate::prng::Prng;
+use crate::vdaf::prg::{Prg, RandSource, Seed};
+use crate::vdaf::{
+ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
+ Share, ShareDecodingParameter, Vdaf, VdafError,
+};
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::io::Cursor;
+use std::iter::IntoIterator;
+use std::marker::PhantomData;
+
+/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Count {
+ /// Construct an instance of Prio3Aes128Count with the given number of aggregators.
+ pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, Count::new())
+ }
+}
+
+/// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is
+/// the element-wise sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128CountVec =
+ Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128CountVec {
+ /// Construct an instance of Prio3Aes1238CountVec with the given number of aggregators. `len`
+ /// defines the length of each measurement.
+ pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, CountVec::new(len))
+ }
+}
+
+/// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and
+/// preparation time. Note that the improvement is only noticeable for very large input lengths,
+/// e.g., 201 and up. (Your system's mileage may vary.)
+#[cfg(feature = "multithreaded")]
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+pub type Prio3Aes128CountVecMultithreaded = Prio3<
+ CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>,
+ PrgAes128,
+ 16,
+>;
+
+#[cfg(feature = "multithreaded")]
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+impl Prio3Aes128CountVecMultithreaded {
+ /// Construct an instance of Prio3Aes1238CountVecMultithreaded with the given number of
+ /// aggregators. `len` defines the length of each measurement.
+ pub fn new_aes128_count_vec_multithreaded(
+ num_aggregators: u8,
+ len: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, CountVec::new(len))
+ }
+}
+
+/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Sum {
+ /// Construct an instance of Prio3Aes128Sum with the given number of aggregators and required
+ /// bit length. The bit length must not exceed 64.
+ pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({}) exceeds limit for aggregate type (64)",
+ bits
+ )));
+ }
+
+ Prio3::new(num_aggregators, Sum::new(bits as usize)?)
+ }
+}
+
+/// The histogram type. Each measurement is an unsigned integer and the result is a histogram
+/// representation of the distribution. The bucket boundaries are fixed in advance.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Histogram {
+ /// Constructs an instance of Prio3Aes128Histogram with the given number of aggregators and
+ /// desired histogram bucket boundaries.
+ pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> {
+ let buckets = buckets.iter().map(|bucket| *bucket as u128).collect();
+
+ Prio3::new(num_aggregators, Histogram::new(buckets)?)
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and
+/// the aggregate is the arithmetic average.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Average {
+ /// Construct an instance of Prio3Aes128Average with the given number of aggregators and
+ /// required bit length. The bit length must not exceed 64.
+ pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({}) exceeds limit for aggregate type (64)",
+ bits
+ )));
+ }
+
+ Ok(Prio3 {
+ num_aggregators,
+ typ: Average::new(bits as usize)?,
+ phantom: PhantomData,
+ })
+ }
+}
+
+/// The base type for Prio3.
+///
+/// An instance of Prio3 is determined by:
+///
+/// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and
+/// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds.
+///
+/// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an
+/// alias for `Prio3<Count<Field64>, PrgAes128, 16>`.
+///
+/// ```
+/// use prio::vdaf::{
+/// Aggregator, Client, Collector, PrepareTransition,
+/// prio3::Prio3,
+/// };
+/// use rand::prelude::*;
+///
+/// let num_shares = 2;
+/// let vdaf = Prio3::new_aes128_count(num_shares).unwrap();
+///
+/// let mut out_shares = vec![vec![]; num_shares.into()];
+/// let mut rng = thread_rng();
+/// let verify_key = rng.gen();
+/// let measurements = [0, 1, 1, 1, 0];
+/// for measurement in measurements {
+/// // Shard
+/// let (public_share, input_shares) = vdaf.shard(&measurement).unwrap();
+/// let mut nonce = [0; 16];
+/// rng.fill(&mut nonce);
+///
+/// // Prepare
+/// let mut prep_states = vec![];
+/// let mut prep_shares = vec![];
+/// for (agg_id, input_share) in input_shares.iter().enumerate() {
+/// let (state, share) = vdaf.prepare_init(
+/// &verify_key,
+/// agg_id,
+/// &(),
+/// &nonce,
+/// &public_share,
+/// input_share
+/// ).unwrap();
+/// prep_states.push(state);
+/// prep_shares.push(share);
+/// }
+/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap();
+///
+/// for (agg_id, state) in prep_states.into_iter().enumerate() {
+/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() {
+/// PrepareTransition::Finish(out_share) => out_share,
+/// _ => panic!("unexpected transition"),
+/// };
+/// out_shares[agg_id].push(out_share);
+/// }
+/// }
+///
+/// // Aggregate
+/// let agg_shares = out_shares.into_iter()
+/// .map(|o| vdaf.aggregate(&(), o).unwrap());
+///
+/// // Unshard
+/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap();
+/// assert_eq!(agg_res, 3);
+/// ```
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+#[derive(Clone, Debug)]
+pub struct Prio3<T, P, const L: usize>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ num_aggregators: u8,
+ typ: T,
+ phantom: PhantomData<P>,
+}
+
+impl<T, P, const L: usize> Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the
+ /// underlying type.
+ pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+ Ok(Self {
+ num_aggregators,
+ typ,
+ phantom: PhantomData,
+ })
+ }
+
+ /// The output length of the underlying FLP.
+ pub fn output_len(&self) -> usize {
+ self.typ.output_len()
+ }
+
+ /// The verifier length of the underlying FLP.
+ pub fn verifier_len(&self) -> usize {
+ self.typ.verifier_len()
+ }
+
+ fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> {
+ let mut info = [0; VERSION.len() + 5];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes());
+ info[VERSION.len() + 4] = 255;
+ let mut deriver = P::init(&[0; L]);
+ deriver.update(&info);
+ for part in parts {
+ deriver.update(part.as_ref());
+ }
+ deriver.into_seed()
+ }
+
+ fn shard_with_rand_source(
+ &self,
+ measurement: &T::Measurement,
+ rand_source: RandSource,
+ ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+
+ let num_aggregators = self.num_aggregators;
+ let input = self.typ.encode_measurement(measurement)?;
+
+ // Generate the input shares and compute the joint randomness.
+ let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
+ let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
+ Some(Vec::with_capacity(num_aggregators as usize - 1))
+ } else {
+ None
+ };
+ let mut leader_input_share = input.clone();
+ for agg_id in 1..num_aggregators {
+ let helper = HelperShare::from_rand_source(rand_source)?;
+
+ let mut deriver = P::init(helper.joint_rand_param.blind.as_ref());
+ info[DST_LEN] = agg_id;
+ deriver.update(&info);
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info));
+ for (x, y) in leader_input_share
+ .iter_mut()
+ .zip(prng)
+ .take(self.typ.input_len())
+ {
+ *x -= y;
+ deriver.update(&y.into());
+ }
+
+ if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() {
+ helper_joint_rand_parts.push(deriver.into_seed());
+ }
+ helper_shares.push(helper);
+ }
+
+ let leader_blind = Seed::from_rand_source(rand_source)?;
+
+ info[DST_LEN] = 0; // ID of the leader
+ let mut deriver = P::init(leader_blind.as_ref());
+ deriver.update(&info);
+ for x in leader_input_share.iter() {
+ deriver.update(&(*x).into());
+ }
+
+ let leader_joint_rand_seed_part = deriver.into_seed();
+
+ // Compute the joint randomness seed.
+ let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| {
+ Self::derive_joint_randomness(
+ std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()),
+ )
+ });
+
+ // Run the proof-generation algorithm.
+ let domain_separation_tag = &info[..DST_LEN];
+ let joint_rand: Vec<T::Field> = joint_rand_seed
+ .map(|joint_rand_seed| {
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
+ prng.take(self.typ.joint_rand_len()).collect()
+ })
+ .unwrap_or_default();
+ let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
+ &Seed::from_rand_source(rand_source)?,
+ domain_separation_tag,
+ ));
+ let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect();
+ let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?;
+
+ // Generate the proof shares and distribute the joint randomness seed hints.
+ for (j, helper) in helper_shares.iter_mut().enumerate() {
+ info[DST_LEN] = j as u8 + 1;
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info));
+ for (x, y) in leader_proof_share
+ .iter_mut()
+ .zip(prng)
+ .take(self.typ.proof_len())
+ {
+ *x -= y;
+ }
+
+ if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() {
+ let mut hint = Vec::with_capacity(num_aggregators as usize - 1);
+ hint.push(leader_joint_rand_seed_part.clone());
+ hint.extend(helper_joint_rand_parts[..j].iter().cloned());
+ hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned());
+ helper.joint_rand_param.seed_hint = hint;
+ }
+ }
+
+ let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 {
+ Some(JointRandParam {
+ seed_hint: helper_joint_rand_parts.unwrap_or_default(),
+ blind: leader_blind,
+ })
+ } else {
+ None
+ };
+
+ // Prep the output messages.
+ let mut out = Vec::with_capacity(num_aggregators as usize);
+ out.push(Prio3InputShare {
+ input_share: Share::Leader(leader_input_share),
+ proof_share: Share::Leader(leader_proof_share),
+ joint_rand_param: leader_joint_rand_param,
+ });
+
+ for helper in helper_shares.into_iter() {
+ let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 {
+ Some(helper.joint_rand_param)
+ } else {
+ None
+ };
+
+ out.push(Prio3InputShare {
+ input_share: Share::Helper(helper.input_share),
+ proof_share: Share::Helper(helper.proof_share),
+ joint_rand_param: helper_joint_rand_param,
+ });
+ }
+
+ Ok(out)
+ }
+
+ /// Shard measurement with constant randomness of repeated bytes.
+ /// This method is not secure. It is used for running test vectors for Prio3.
+ #[cfg(feature = "test-util")]
+ pub fn test_vec_shard(
+ &self,
+ measurement: &T::Measurement,
+ ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
+ self.shard_with_rand_source(measurement, |buf| {
+ buf.fill(1);
+ Ok(())
+ })
+ }
+
+ fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
+ if agg_id >= self.num_aggregators as usize {
+ return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
+ }
+ Ok(u8::try_from(agg_id).unwrap())
+ }
+}
+
+impl<T, P, const L: usize> Vdaf for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ const ID: u32 = T::ID;
+ type Measurement = T::Measurement;
+ type AggregateResult = T::AggregateResult;
+ type AggregationParam = ();
+ type PublicShare = ();
+ type InputShare = Prio3InputShare<T::Field, L>;
+ type OutputShare = OutputShare<T::Field>;
+ type AggregateShare = AggregateShare<T::Field>;
+
+ fn num_aggregators(&self) -> usize {
+ self.num_aggregators as usize
+ }
+}
+
+/// Message sent by the [`Client`](crate::vdaf::Client) to each
+/// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio3InputShare<F, const L: usize> {
+ /// The input share.
+ input_share: Share<F, L>,
+
+ /// The proof share.
+ proof_share: Share<F, L>,
+
+ /// Parameters used by the Aggregator to compute the joint randomness. This field is optional
+ /// because not every [`Type`](`crate::flp::Type`) requires joint randomness.
+ joint_rand_param: Option<JointRandParam<L>>,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if matches!(
+ (&self.input_share, &self.proof_share),
+ (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_))
+ ) {
+ panic!("tried to encode input share with ambiguous encoding")
+ }
+
+ self.input_share.encode(bytes);
+ self.proof_share.encode(bytes);
+ if let Some(ref param) = self.joint_rand_param {
+ param.blind.encode(bytes);
+ for part in param.seed_hint.iter() {
+ part.encode(bytes);
+ }
+ }
+ }
+}
+
+impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
+ for Prio3InputShare<T::Field, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+ let (input_decoder, proof_decoder) = if agg_id == 0 {
+ (
+ ShareDecodingParameter::Leader(prio3.typ.input_len()),
+ ShareDecodingParameter::Leader(prio3.typ.proof_len()),
+ )
+ } else {
+ (
+ ShareDecodingParameter::Helper,
+ ShareDecodingParameter::Helper,
+ )
+ };
+
+ let input_share = Share::decode_with_param(&input_decoder, bytes)?;
+ let proof_share = Share::decode_with_param(&proof_decoder, bytes)?;
+ let joint_rand_param = if prio3.typ.joint_rand_len() > 0 {
+ let num_aggregators = prio3.num_aggregators();
+ let blind = Seed::decode(bytes)?;
+ let seed_hint = std::iter::repeat_with(|| Seed::decode(bytes))
+ .take(num_aggregators - 1)
+ .collect::<Result<Vec<_>, _>>()?;
+ Some(JointRandParam { blind, seed_hint })
+ } else {
+ None
+ };
+
+ Ok(Prio3InputShare {
+ input_share,
+ proof_share,
+ joint_rand_param,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+/// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the
+/// Preparation phase.
+pub struct Prio3PrepareShare<F, const L: usize> {
+ /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).)
+ verifier: Vec<F>,
+
+ /// A part of the joint randomness seed.
+ joint_rand_part: Option<Seed<L>>,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ for x in &self.verifier {
+ x.encode(bytes);
+ }
+ if let Some(ref seed) = self.joint_rand_part {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
+ for Prio3PrepareShare<F, L>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, L>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len);
+ for _ in 0..decoding_parameter.verifier_len {
+ verifier.push(F::decode(bytes)?);
+ }
+
+ let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareShare {
+ verifier,
+ joint_rand_part,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+/// Result of combining a round of [`Prio3PrepareShare`] messages.
+pub struct Prio3PrepareMessage<const L: usize> {
+ /// The joint randomness seed computed by the Aggregators.
+ joint_rand_seed: Option<Seed<L>>,
+}
+
+impl<const L: usize> Encode for Prio3PrepareMessage<L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
+ for Prio3PrepareMessage<L>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, L>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+}
+
+impl<T, P, const L: usize> Client for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ #[allow(clippy::type_complexity)]
+ fn shard(
+ &self,
+ measurement: &T::Measurement,
+ ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> {
+ self.shard_with_rand_source(measurement, getrandom::getrandom)
+ .map(|input_shares| ((), input_shares))
+ }
+}
+
+/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio3PrepareState<F, const L: usize> {
+ input_share: Share<F, L>,
+ joint_rand_seed: Option<Seed<L>>,
+ agg_id: u8,
+ verifier_len: usize,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.input_share.encode(bytes);
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
+ for Prio3PrepareState<T::Field, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+
+ let share_decoder = if agg_id == 0 {
+ ShareDecodingParameter::Leader(prio3.typ.input_len())
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let input_share = Share::decode_with_param(&share_decoder, bytes)?;
+
+ let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Self {
+ input_share,
+ joint_rand_seed,
+ agg_id,
+ verifier_len: prio3.typ.verifier_len(),
+ })
+ }
+}
+
+impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ type PrepareState = Prio3PrepareState<T::Field, L>;
+ type PrepareShare = Prio3PrepareShare<T::Field, L>;
+ type PrepareMessage = Prio3PrepareMessage<L>;
+
+ /// Begins the Prep process with the other aggregators. The result of this process is
+ /// the aggregator's output share.
+ #[allow(clippy::type_complexity)]
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; L],
+ agg_id: usize,
+ _agg_param: &(),
+ nonce: &[u8],
+ _public_share: &(),
+ msg: &Prio3InputShare<T::Field, L>,
+ ) -> Result<
+ (
+ Prio3PrepareState<T::Field, L>,
+ Prio3PrepareShare<T::Field, L>,
+ ),
+ VdafError,
+ > {
+ let agg_id = self.role_try_from(agg_id)?;
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+ info[DST_LEN] = agg_id;
+ let domain_separation_tag = &info[..DST_LEN];
+
+ let mut deriver = P::init(verify_key);
+ deriver.update(domain_separation_tag);
+ deriver.update(&[255]);
+ deriver.update(nonce);
+ let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream());
+
+ // Create a reference to the (expanded) input share.
+ let expanded_input_share: Option<Vec<T::Field>> = match msg.input_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
+ Some(prng.take(self.typ.input_len()).collect())
+ }
+ };
+ let input_share = match msg.input_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_input_share.as_ref().unwrap(),
+ };
+
+ // Create a reference to the (expanded) proof share.
+ let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
+ Some(prng.take(self.typ.proof_len()).collect())
+ }
+ };
+ let proof_share = match msg.proof_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_proof_share.as_ref().unwrap(),
+ };
+
+ // Compute the joint randomness.
+ let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 {
+ let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref());
+ deriver.update(&info);
+ for x in input_share {
+ deriver.update(&(*x).into());
+ }
+ let joint_rand_seed_part = deriver.into_seed();
+
+ let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint;
+ let joint_rand_seed = Self::derive_joint_randomness(
+ hints[..agg_id as usize]
+ .iter()
+ .chain(std::iter::once(&joint_rand_seed_part))
+ .chain(hints[agg_id as usize..].iter()),
+ );
+
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
+ (
+ Some(joint_rand_seed),
+ Some(joint_rand_seed_part),
+ prng.take(self.typ.joint_rand_len()).collect(),
+ )
+ } else {
+ (None, None, Vec::new())
+ };
+
+ // Compute the query randomness.
+ let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect();
+
+ // Run the query-generation algorithm.
+ let verifier_share = self.typ.query(
+ input_share,
+ proof_share,
+ &query_rand,
+ &joint_rand,
+ self.num_aggregators as usize,
+ )?;
+
+ Ok((
+ Prio3PrepareState {
+ input_share: msg.input_share.clone(),
+ joint_rand_seed,
+ agg_id,
+ verifier_len: verifier_share.len(),
+ },
+ Prio3PrepareShare {
+ verifier: verifier_share,
+ joint_rand_part: joint_rand_seed_part,
+ },
+ ))
+ }
+
+ fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>(
+ &self,
+ inputs: M,
+ ) -> Result<Prio3PrepareMessage<L>, VdafError> {
+ let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
+ let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
+ let mut count = 0;
+ for share in inputs.into_iter() {
+ count += 1;
+
+ if share.verifier.len() != verifier.len() {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected verifier share length: got {}; want {}",
+ share.verifier.len(),
+ verifier.len(),
+ )));
+ }
+
+ if self.typ.joint_rand_len() > 0 {
+ let joint_rand_seed_part = share.joint_rand_part.unwrap();
+ joint_rand_parts.push(joint_rand_seed_part);
+ }
+
+ for (x, y) in verifier.iter_mut().zip(share.verifier) {
+ *x += y;
+ }
+ }
+
+ if count != self.num_aggregators {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message count: got {}; want {}",
+ count, self.num_aggregators,
+ )));
+ }
+
+ // Check the proof verifier.
+ match self.typ.decide(&verifier) {
+ Ok(true) => (),
+ Ok(false) => {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ))
+ }
+ Err(err) => return Err(VdafError::from(err)),
+ };
+
+ let joint_rand_seed = if self.typ.joint_rand_len() > 0 {
+ Some(Self::derive_joint_randomness(joint_rand_parts.iter()))
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+
+ fn prepare_step(
+ &self,
+ step: Prio3PrepareState<T::Field, L>,
+ msg: Prio3PrepareMessage<L>,
+ ) -> Result<PrepareTransition<Self, L>, VdafError> {
+ if self.typ.joint_rand_len() > 0 {
+ // Check that the joint randomness was correct.
+ if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() {
+ return Err(VdafError::Uncategorized(
+ "joint randomness mismatch".to_string(),
+ ));
+ }
+ }
+
+ // Compute the output share.
+ let input_share = match step.input_share {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+ info[DST_LEN] = step.agg_id;
+ let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info));
+ prng.take(self.typ.input_len()).collect()
+ }
+ };
+
+ let output_share = match self.typ.truncate(input_share) {
+ Ok(data) => OutputShare(data),
+ Err(err) => {
+ return Err(VdafError::from(err));
+ }
+ };
+
+ Ok(PrepareTransition::Finish(output_share))
+ }
+
+ /// Aggregates a sequence of output shares into an aggregate share.
+ fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
+ &self,
+ _agg_param: &(),
+ output_shares: It,
+ ) -> Result<AggregateShare<T::Field>, VdafError> {
+ let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for output_share in output_shares.into_iter() {
+ agg_share.accumulate(&output_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+impl<T, P, const L: usize> Collector for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ /// Combines aggregate shares into the aggregate result.
+ fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
+ &self,
+ _agg_param: &(),
+ agg_shares: It,
+ num_measurements: usize,
+ ) -> Result<T::AggregateResult, VdafError> {
+ let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(self.typ.decode_result(&agg.0, num_measurements)?)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct JointRandParam<const L: usize> {
+ /// The joint randomness seed parts corresponding to the other Aggregators' shares.
+ seed_hint: Vec<Seed<L>>,
+
+ /// The blinding factor, used to derive the aggregator's joint randomness seed part.
+ blind: Seed<L>,
+}
+
+#[derive(Clone)]
+struct HelperShare<const L: usize> {
+ input_share: Seed<L>,
+ proof_share: Seed<L>,
+ joint_rand_param: JointRandParam<L>,
+}
+
+impl<const L: usize> HelperShare<L> {
+ fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> {
+ Ok(HelperShare {
+ input_share: Seed::from_rand_source(rand_source)?,
+ proof_share: Seed::from_rand_source(rand_source)?,
+ joint_rand_param: JointRandParam {
+ seed_hint: Vec::new(),
+ blind: Seed::from_rand_source(rand_source)?,
+ },
+ })
+ }
+}
+
+fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
+ if num_aggregators == 0 {
+ return Err(VdafError::Uncategorized(format!(
+ "at least one aggregator is required; got {}",
+ num_aggregators
+ )));
+ } else if num_aggregators > 254 {
+ return Err(VdafError::Uncategorized(format!(
+ "number of aggregators must not exceed 254; got {}",
+ num_aggregators
+ )));
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::vdaf::{run_vdaf, run_vdaf_prepare};
+ use assert_matches::assert_matches;
+ use rand::prelude::*;
+
+ #[test]
+ fn test_prio3_count() {
+ let prio3 = Prio3::new_aes128_count(2).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3);
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = b"This is a good nonce.";
+
+ let (public_share, input_shares) = prio3.shard(&0).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
+
+ let (public_share, input_shares) = prio3.shard(&1).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
+
+ test_prepare_state_serialization(&prio3, &1).unwrap();
+
+ let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap();
+ assert_eq!(
+ run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(),
+ 3,
+ );
+ }
+
+ #[test]
+ fn test_prio3_sum() {
+ let prio3 = Prio3::new_aes128_sum(3, 16).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
+ (1 << 16) + 1
+ );
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = b"This is a good nonce.";
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ test_prepare_state_serialization(&prio3, &1).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_countvec() {
+ let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [vec![
+ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
+ ]]
+ )
+ .unwrap(),
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
+ );
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_prio3_countvec_multithreaded() {
+ let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [vec![
+ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
+ ]]
+ )
+ .unwrap(),
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
+ );
+ }
+
+ #[test]
+ fn test_prio3_histogram() {
+ let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(),
+ vec![1, 1, 1, 1]
+ );
+ assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]);
+ test_prepare_state_serialization(&prio3, &23).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_average() {
+ let prio3 = Prio3::new_aes128_average(2, 64).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
+ assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
+ assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
+ assert_eq!(
+ run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
+ 207.5f64
+ );
+ }
+
+ #[test]
+ fn test_prio3_input_share() {
+ let prio3 = Prio3::new_aes128_sum(5, 16).unwrap();
+ let (_public_share, input_shares) = prio3.shard(&1).unwrap();
+
+ // Check that seed shares are distinct.
+ for (i, x) in input_shares.iter().enumerate() {
+ for (j, y) in input_shares.iter().enumerate() {
+ if i != j {
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.input_share, &y.input_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.proof_share, &y.proof_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ assert_ne!(x.joint_rand_param, y.joint_rand_param);
+ }
+ }
+ }
+ }
+
+ fn test_prepare_state_serialization<T, P, const L: usize>(
+ prio3: &Prio3<T, P, L>,
+ measurement: &T::Measurement,
+ ) -> Result<(), VdafError>
+ where
+ T: Type,
+ P: Prg<L>,
+ {
+ let mut verify_key = [0; L];
+ thread_rng().fill(&mut verify_key[..]);
+ let (public_share, input_shares) = prio3.shard(measurement)?;
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (want, _msg) =
+ prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?;
+ let got =
+ Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded())
+ .expect("failed to decode prepare step");
+ assert_eq!(got, want);
+ }
+ Ok(())
+ }
+}