summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio2.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/prio/src/vdaf/prio2.rs
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio2.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio2.rs425
1 files changed, 425 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs
new file mode 100644
index 0000000000..47fc076790
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2.rs
@@ -0,0 +1,425 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with
+//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server).
+
+use crate::{
+ client as v2_client,
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{FieldElement, FieldPrio2},
+ prng::Prng,
+ server as v2_server,
+ util::proof_length,
+ vdaf::{
+ prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
+ PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
+ },
+};
+use ring::hmac;
+use std::{
+ convert::{TryFrom, TryInto},
+ io::Cursor,
+};
+
+/// The Prio2 VDAF. It supports the same measurement type as
+/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system
+/// and finite field deployed in ENPA.
+#[derive(Clone, Debug)]
+pub struct Prio2 {
+ input_len: usize,
+}
+
+impl Prio2 {
+ /// Returns an instance of the VDAF for the given input length.
+ pub fn new(input_len: usize) -> Result<Self, VdafError> {
+ let n = (input_len + 1).next_power_of_two();
+ if let Ok(size) = u32::try_from(2 * n) {
+ if size > FieldPrio2::generator_order() {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+ } else {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds memory capacity".into(),
+ ));
+ }
+
+ Ok(Prio2 { input_len })
+ }
+
+ /// Prepare an input share for aggregation using the given field element `query_rand` to
+ /// compute the verifier share.
+ ///
+ /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the
+ /// query randomness is computed jointly by the Aggregators. This method is designed to be used
+ /// in applications, like ENPA, in which the query randomness is instead chosen by a
+ /// third-party.
+ pub fn prepare_init_with_query_rand(
+ &self,
+ query_rand: FieldPrio2,
+ input_share: &Share<FieldPrio2, 32>,
+ is_leader: bool,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let expanded_data: Option<Vec<FieldPrio2>> = match input_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ Some(prng.take(proof_length(self.input_len)).collect())
+ }
+ };
+ let data = match input_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_data.as_ref().unwrap(),
+ };
+
+ let mut mem = v2_server::ValidationMemory::new(self.input_len);
+ let verifier_share = v2_server::generate_verification_message(
+ self.input_len,
+ query_rand,
+ data, // Combined input and proof shares
+ is_leader,
+ &mut mem,
+ )
+ .map_err(|e| VdafError::Uncategorized(e.to_string()))?;
+
+ Ok((
+ Prio2PrepareState(input_share.truncated(self.input_len)),
+ Prio2PrepareShare(verifier_share),
+ ))
+ }
+}
+
+impl Vdaf for Prio2 {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<u32>;
+ type AggregateResult = Vec<u32>;
+ type AggregationParam = ();
+ type PublicShare = ();
+ type InputShare = Share<FieldPrio2, 32>;
+ type OutputShare = OutputShare<FieldPrio2>;
+ type AggregateShare = AggregateShare<FieldPrio2>;
+
+ fn num_aggregators(&self) -> usize {
+ // Prio2 can easily be extended to support more than two Aggregators.
+ 2
+ }
+}
+
+impl Client for Prio2 {
+ fn shard(&self, measurement: &Vec<u32>) -> Result<((), Vec<Share<FieldPrio2, 32>>), VdafError> {
+ if measurement.len() != self.input_len {
+ return Err(VdafError::Uncategorized("incorrect input length".into()));
+ }
+ let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len());
+ for int in measurement {
+ input.push((*int).into());
+ }
+
+ let mut mem = v2_client::ClientMemory::new(self.input_len)?;
+ let copy_data = |share_data: &mut [FieldPrio2]| {
+ share_data[..].clone_from_slice(&input);
+ };
+ let mut leader_data = mem.prove_with(self.input_len, copy_data);
+
+ let helper_seed = Seed::generate()?;
+ let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref());
+ for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) {
+ *s1 -= d;
+ }
+
+ Ok((
+ (),
+ vec![Share::Leader(leader_data), Share::Helper(helper_seed)],
+ ))
+ }
+}
+
+/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio2PrepareState(Share<FieldPrio2, 32>);
+
+impl Encode for Prio2PrepareState {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes);
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let share_decoder = if *agg_id == 0 {
+ ShareDecodingParameter::Leader(prio2.input_len)
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let out_share = Share::decode_with_param(&share_decoder, bytes)?;
+ Ok(Self(out_share))
+ }
+}
+
+/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug)]
+pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);
+
+impl Encode for Prio2PrepareShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.f_r.encode(bytes);
+ self.0.g_r.encode(bytes);
+ self.0.h_r.encode(bytes);
+ }
+}
+
+impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
+ fn decode_with_param(
+ _state: &Prio2PrepareState,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Ok(Self(v2_server::VerificationMessage {
+ f_r: FieldPrio2::decode(bytes)?,
+ g_r: FieldPrio2::decode(bytes)?,
+ h_r: FieldPrio2::decode(bytes)?,
+ }))
+ }
+}
+
+impl Aggregator<32> for Prio2 {
+ type PrepareState = Prio2PrepareState;
+ type PrepareShare = Prio2PrepareShare;
+ type PrepareMessage = ();
+
+ fn prepare_init(
+ &self,
+ agg_key: &[u8; 32],
+ agg_id: usize,
+ _agg_param: &(),
+ nonce: &[u8],
+ _public_share: &Self::PublicShare,
+ input_share: &Share<FieldPrio2, 32>,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let is_leader = role_try_from(agg_id)?;
+
+ // In the ENPA Prio system, the query randomness is generated by a third party and
+ // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
+ // randomness is derived from a nonce selected by the client. For Prio2 we compute the
+ // query using HMAC-SHA256 evaluated over the nonce.
+ let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key);
+ let hmac_tag = hmac::sign(&hmac_key, nonce);
+ let query_rand = Prng::from_prio2_seed(hmac_tag.as_ref().try_into().unwrap())
+ .next()
+ .unwrap();
+
+ self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
+ }
+
+ fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>(
+ &self,
+ inputs: M,
+ ) -> Result<(), VdafError> {
+ let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
+ inputs.into_iter().map(|msg| msg.0).collect();
+ if verifier_shares.len() != 2 {
+ return Err(VdafError::Uncategorized(
+ "wrong number of verifier shares".into(),
+ ));
+ }
+
+ if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn prepare_step(
+ &self,
+ state: Prio2PrepareState,
+ _input: (),
+ ) -> Result<PrepareTransition<Self, 32>, VdafError> {
+ let data = match state.0 {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ prng.take(self.input_len).collect()
+ }
+ };
+ Ok(PrepareTransition::Finish(OutputShare::from(data)))
+ }
+
+ fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &(),
+ out_shares: M,
+ ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
+ let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for out_share in out_shares.into_iter() {
+ agg_share.accumulate(&out_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+impl Collector for Prio2 {
+ fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &(),
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<Vec<u32>, VdafError> {
+ let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(agg.0.into_iter().map(u32::from).collect())
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
+ let decoder = if is_leader {
+ ShareDecodingParameter::Leader(proof_length(prio2.input_len))
+ } else {
+ ShareDecodingParameter::Helper
+ };
+
+ Share::decode_with_param(&decoder, bytes)
+ }
+}
+
+fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
+ match agg_id {
+ 0 => Ok(true),
+ 1 => Ok(false),
+ _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ client::encode_simple,
+ encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey},
+ field::random_vector,
+ server::Server,
+ vdaf::{run_vdaf, run_vdaf_prepare},
+ };
+ use rand::prelude::*;
+
+ const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
+ const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";
+
+ #[test]
+ fn run_prio2() {
+ let prio2 = Prio2::new(6).unwrap();
+
+ assert_eq!(
+ run_vdaf(
+ &prio2,
+ &(),
+ [
+ vec![0, 0, 0, 0, 1, 0],
+ vec![0, 1, 0, 0, 0, 0],
+ vec![0, 1, 1, 0, 0, 0],
+ vec![1, 1, 1, 0, 0, 0],
+ vec![0, 0, 0, 0, 1, 1],
+ ]
+ )
+ .unwrap(),
+ vec![1, 3, 2, 0, 2, 1],
+ );
+ }
+
+ #[test]
+ fn enpa_client_interop() {
+ let mut rng = thread_rng();
+ let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
+ let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
+ let pub_key1 = PublicKey::from(&priv_key1);
+ let pub_key2 = PublicKey::from(&priv_key2);
+
+ let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0]
+ .iter()
+ .map(|x| FieldPrio2::from(*x))
+ .collect();
+ let (encrypted_input_share1, encrypted_input_share2) =
+ encode_simple(&data, pub_key1, pub_key2).unwrap();
+
+ let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap();
+ let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap();
+
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let input_shares = vec![
+ Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(),
+ Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(),
+ ];
+
+ let verify_key = rng.gen();
+ let mut nonce = [0; 16];
+ rng.fill(&mut nonce);
+ run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, (), input_shares).unwrap();
+ }
+
+ #[test]
+ fn enpa_server_interop() {
+ let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
+ let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
+ let pub_key1 = PublicKey::from(&priv_key1);
+ let pub_key2 = PublicKey::from(&priv_key2);
+
+ let data = vec![0, 0, 1, 1, 0];
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let (_public_share, input_shares) = prio2.shard(&data).unwrap();
+
+ let encrypted_input_share1 =
+ encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap();
+ let encrypted_input_share2 =
+ encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap();
+
+ let mut server1 = Server::new(data.len(), true, priv_key1).unwrap();
+ let mut server2 = Server::new(data.len(), false, priv_key2).unwrap();
+
+ let eval_at: FieldPrio2 = random_vector(1).unwrap()[0];
+ let verifier1 = server1
+ .generate_verification_message(eval_at, &encrypted_input_share1)
+ .unwrap();
+ let verifier2 = server2
+ .generate_verification_message(eval_at, &encrypted_input_share2)
+ .unwrap();
+
+ server1
+ .aggregate(&encrypted_input_share1, &verifier1, &verifier2)
+ .unwrap();
+ server2
+ .aggregate(&encrypted_input_share2, &verifier1, &verifier2)
+ .unwrap();
+ }
+
+ #[test]
+ fn prepare_state_serialization() {
+ let mut verify_key = [0; 32];
+ thread_rng().fill(&mut verify_key[..]);
+ let data = vec![0, 0, 1, 1, 0];
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let (public_share, input_shares) = prio2.shard(&data).unwrap();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (want, _msg) = prio2
+ .prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)
+ .unwrap();
+ let got =
+ Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded())
+ .expect("failed to decode prepare step");
+ assert_eq!(got, want);
+ }
+ }
+}