summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio3_test.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio3_test.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio3_test.rs162
1 files changed, 162 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs
new file mode 100644
index 0000000000..d4c9151ce0
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3_test.rs
@@ -0,0 +1,162 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::{
+ codec::{Encode, ParameterizedDecode},
+ flp::Type,
+ vdaf::{
+ prg::Prg,
+ prio3::{Prio3, Prio3InputShare, Prio3PrepareShare},
+ Aggregator, PrepareTransition,
+ },
+};
+use serde::{Deserialize, Serialize};
+use std::{convert::TryInto, fmt::Debug};
+
+#[derive(Debug, Deserialize, Serialize)]
+struct TEncoded(#[serde(with = "hex")] Vec<u8>);
+
+impl AsRef<[u8]> for TEncoded {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3Prep<M> {
+ measurement: M,
+ #[serde(with = "hex")]
+ nonce: Vec<u8>,
+ input_shares: Vec<TEncoded>,
+ prep_shares: Vec<Vec<TEncoded>>,
+ prep_messages: Vec<TEncoded>,
+ out_shares: Vec<Vec<M>>,
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3<M> {
+ verify_key: TEncoded,
+ prep: Vec<TPrio3Prep<M>>,
+}
+
+macro_rules! err {
+ (
+ $test_num:ident,
+ $error:expr,
+ $msg:expr
+ ) => {
+ panic!("test #{} failed: {} err: {}", $test_num, $msg, $error)
+ };
+}
+
+// TODO Generalize this method to work with any VDAF. To do so we would need to add
+// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.)
+fn check_prep_test_vec<M, T, P, const L: usize>(
+ prio3: &Prio3<T, P, L>,
+ verify_key: &[u8; L],
+ test_num: usize,
+ t: &TPrio3Prep<M>,
+) where
+ T: Type<Measurement = M>,
+ P: Prg<L>,
+ M: From<<T as Type>::Field> + Debug + PartialEq,
+{
+ let input_shares = prio3
+ .test_vec_shard(&t.measurement)
+ .expect("failed to generate input shares");
+
+ assert_eq!(2, t.input_shares.len(), "#{}", test_num);
+ for (agg_id, want) in t.input_shares.iter().enumerate() {
+ assert_eq!(
+ input_shares[agg_id],
+ Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
+ "#{}",
+ test_num
+ );
+ assert_eq!(
+ input_shares[agg_id].get_encoded(),
+ want.as_ref(),
+ "#{}",
+ test_num
+ )
+ }
+
+ let mut states = Vec::new();
+ let mut prep_shares = Vec::new();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (state, prep_share) = prio3
+ .prepare_init(verify_key, agg_id, &(), &t.nonce, &(), input_share)
+ .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
+ states.push(state);
+ prep_shares.push(prep_share);
+ }
+
+ assert_eq!(1, t.prep_shares.len(), "#{}", test_num);
+ for (i, want) in t.prep_shares[0].iter().enumerate() {
+ assert_eq!(
+ prep_shares[i],
+ Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
+ "#{}",
+ test_num
+ );
+ assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{}", test_num);
+ }
+
+ let inbound = prio3
+ .prepare_preprocess(prep_shares)
+ .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
+ assert_eq!(t.prep_messages.len(), 1);
+ assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());
+
+ let mut out_shares = Vec::new();
+ for state in states.iter_mut() {
+ match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() {
+ PrepareTransition::Finish(out_share) => {
+ out_shares.push(out_share);
+ }
+ _ => panic!("unexpected transition"),
+ }
+ }
+
+ for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
+ let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect();
+ assert_eq!(&got, want);
+ }
+}
+
+#[test]
+fn test_vec_prio3_count() {
+ let t: TPrio3<u64> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Count_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_count(2).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}
+
+#[test]
+fn test_vec_prio3_sum() {
+ let t: TPrio3<u128> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Sum_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_sum(2, 8).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}
+
+#[test]
+fn test_vec_prio3_histogram() {
+ let t: TPrio3<u128> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Histogram_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}