diff options
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio3_test.rs')
-rw-r--r-- | third_party/rust/prio/src/vdaf/prio3_test.rs | 162 |
1 files changed, 102 insertions, 60 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs index 372a2c8560..9a3dfd85f4 100644 --- a/third_party/rust/prio/src/vdaf/prio3_test.rs +++ b/third_party/rust/prio/src/vdaf/prio3_test.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +//! Tools for evaluating Prio3 test vectors. + use crate::{ codec::{Encode, ParameterizedDecode}, flp::Type, @@ -58,19 +60,21 @@ macro_rules! err { // TODO Generalize this method to work with any VDAF. To do so we would need to add // `shard_with_random()` to traits. (There may be a less invasive alternative.) -fn check_prep_test_vec<M, T, P, const SEED_SIZE: usize>( +fn check_prep_test_vec<MS, MP, T, P, const SEED_SIZE: usize>( prio3: &Prio3<T, P, SEED_SIZE>, verify_key: &[u8; SEED_SIZE], test_num: usize, - t: &TPrio3Prep<M>, + t: &TPrio3Prep<MS>, ) -> Vec<OutputShare<T::Field>> where - T: Type<Measurement = M>, + MS: Clone, + MP: From<MS>, + T: Type<Measurement = MP>, P: Xof<SEED_SIZE>, { let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap(); let (public_share, input_shares) = prio3 - .shard_with_random(&t.measurement, &nonce, &t.rand) + .shard_with_random(&t.measurement.clone().into(), &nonce, &t.rand) .expect("failed to generate input shares"); assert_eq!( @@ -86,7 +90,7 @@ where "#{test_num}" ); assert_eq!( - input_shares[agg_id].get_encoded(), + input_shares[agg_id].get_encoded().unwrap(), want.as_ref(), "#{test_num}" ) @@ -110,14 +114,18 @@ where .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")), "#{test_num}" ); - assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{test_num}"); + assert_eq!( + prep_shares[i].get_encoded().unwrap(), + want.as_ref(), + "#{test_num}" + ); } let inbound = prio3 .prepare_shares_to_prepare_message(&(), prep_shares) .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); assert_eq!(t.prep_messages.len(), 1); - assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref()); + assert_eq!(inbound.get_encoded().unwrap(), t.prep_messages[0].as_ref()); let mut out_shares = Vec::new(); for state in states.iter_mut() { @@ -130,7 +138,11 @@ where } for (got, want) in out_shares.iter().zip(t.out_shares.iter()) { - let got: Vec<Vec<u8>> = got.as_ref().iter().map(|x| x.get_encoded()).collect(); + let got: Vec<Vec<u8>> = got + .as_ref() + .iter() + .map(|x| x.get_encoded().unwrap()) + .collect(); assert_eq!(got.len(), want.len()); for (got_elem, want_elem) in got.iter().zip(want.iter()) { assert_eq!(got_elem.as_slice(), want_elem.as_ref()); @@ -141,12 +153,14 @@ where } #[must_use] -fn check_aggregate_test_vec<M, T, P, const SEED_SIZE: usize>( +fn check_aggregate_test_vec<MS, MP, T, P, const SEED_SIZE: usize>( prio3: &Prio3<T, P, SEED_SIZE>, - t: &TPrio3<M>, + t: &TPrio3<MS>, ) -> T::AggregateResult where - T: Type<Measurement = M>, + MS: Clone, + MP: From<MS>, + T: Type<Measurement = MP>, P: Xof<SEED_SIZE>, { let verify_key = t.verify_key.as_ref().try_into().unwrap(); @@ -167,85 +181,113 @@ where .collect::<Vec<_>>(); for (got, want) in aggregate_shares.iter().zip(t.agg_shares.iter()) { - let got = got.get_encoded(); + let got = got.get_encoded().unwrap(); assert_eq!(got.as_slice(), want.as_ref()); } prio3.unshard(&(), aggregate_shares, 1).unwrap() } +/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback, +/// which takes in the VDAF parameters encoded by the test vectors and the number of shares. +/// +/// This version allows customizing the deserialization of measurements, via an additional type +/// parameter. +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub fn check_test_vec_custom_de<MS, MP, A, T, P, const SEED_SIZE: usize>( + test_vec_json_str: &str, + new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>, +) where + MS: for<'de> Deserialize<'de> + Clone, + MP: From<MS>, + A: for<'de> Deserialize<'de> + Debug + Eq, + T: Type<Measurement = MP, AggregateResult = A>, + P: Xof<SEED_SIZE>, +{ + let t: TPrio3<MS> = serde_json::from_str(test_vec_json_str).unwrap(); + let vdaf = new_vdaf(&t.other_params, t.shares); + let agg_result = check_aggregate_test_vec(&vdaf, &t); + assert_eq!(agg_result, serde_json::from_value(t.agg_result).unwrap()); +} + +/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback, +/// which takes in the VDAF parameters encoded by the test vectors and the number of shares. +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub fn check_test_vec<M, A, T, P, const SEED_SIZE: usize>( + test_vec_json_str: &str, + new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>, +) where + M: for<'de> Deserialize<'de> + Clone, + A: for<'de> Deserialize<'de> + Debug + Eq, + T: Type<Measurement = M, AggregateResult = A>, + P: Xof<SEED_SIZE>, +{ + check_test_vec_custom_de::<M, M, _, _, _, SEED_SIZE>(test_vec_json_str, new_vdaf) +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(transparent)] +struct Prio3CountMeasurement(u8); + +impl From<Prio3CountMeasurement> for bool { + fn from(value: Prio3CountMeasurement) -> Self { + value.0 != 0 + } +} + #[test] fn test_vec_prio3_count() { for test_vector_str in [ - include_str!("test_vec/07/Prio3Count_0.json"), - include_str!("test_vec/07/Prio3Count_1.json"), + include_str!("test_vec/08/Prio3Count_0.json"), + include_str!("test_vec/08/Prio3Count_1.json"), ] { - let t: TPrio3<u64> = serde_json::from_str(test_vector_str).unwrap(); - let prio3 = Prio3::new_count(t.shares).unwrap(); - - let aggregate_result = check_aggregate_test_vec(&prio3, &t); - assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap()); + check_test_vec_custom_de::<Prio3CountMeasurement, _, _, _, _, 16>( + test_vector_str, + |_json_params, num_shares| Prio3::new_count(num_shares).unwrap(), + ); } } #[test] fn test_vec_prio3_sum() { for test_vector_str in [ - include_str!("test_vec/07/Prio3Sum_0.json"), - include_str!("test_vec/07/Prio3Sum_1.json"), + include_str!("test_vec/08/Prio3Sum_0.json"), + include_str!("test_vec/08/Prio3Sum_1.json"), ] { - let t: TPrio3<u128> = serde_json::from_str(test_vector_str).unwrap(); - let bits = t.other_params["bits"].as_u64().unwrap() as usize; - let prio3 = Prio3::new_sum(t.shares, bits).unwrap(); - - let aggregate_result = check_aggregate_test_vec(&prio3, &t); - assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap() as u128); + check_test_vec(test_vector_str, |json_params, num_shares| { + let bits = json_params["bits"].as_u64().unwrap() as usize; + Prio3::new_sum(num_shares, bits).unwrap() + }); } } #[test] fn test_vec_prio3_sum_vec() { for test_vector_str in [ - include_str!("test_vec/07/Prio3SumVec_0.json"), - include_str!("test_vec/07/Prio3SumVec_1.json"), + include_str!("test_vec/08/Prio3SumVec_0.json"), + include_str!("test_vec/08/Prio3SumVec_1.json"), ] { - let t: TPrio3<Vec<u128>> = serde_json::from_str(test_vector_str).unwrap(); - let bits = t.other_params["bits"].as_u64().unwrap() as usize; - let length = t.other_params["length"].as_u64().unwrap() as usize; - let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize; - let prio3 = Prio3::new_sum_vec(t.shares, bits, length, chunk_length).unwrap(); - - let aggregate_result = check_aggregate_test_vec(&prio3, &t); - let expected_aggregate_result = t - .agg_result - .as_array() - .unwrap() - .iter() - .map(|val| val.as_u64().unwrap() as u128) - .collect::<Vec<u128>>(); - assert_eq!(aggregate_result, expected_aggregate_result); + check_test_vec(test_vector_str, |json_params, num_shares| { + let bits = json_params["bits"].as_u64().unwrap() as usize; + let length = json_params["length"].as_u64().unwrap() as usize; + let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize; + Prio3::new_sum_vec(num_shares, bits, length, chunk_length).unwrap() + }); } } #[test] fn test_vec_prio3_histogram() { for test_vector_str in [ - include_str!("test_vec/07/Prio3Histogram_0.json"), - include_str!("test_vec/07/Prio3Histogram_1.json"), + include_str!("test_vec/08/Prio3Histogram_0.json"), + include_str!("test_vec/08/Prio3Histogram_1.json"), ] { - let t: TPrio3<usize> = serde_json::from_str(test_vector_str).unwrap(); - let length = t.other_params["length"].as_u64().unwrap() as usize; - let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize; - let prio3 = Prio3::new_histogram(t.shares, length, chunk_length).unwrap(); - - let aggregate_result = check_aggregate_test_vec(&prio3, &t); - let expected_aggregate_result = t - .agg_result - .as_array() - .unwrap() - .iter() - .map(|val| val.as_u64().unwrap() as u128) - .collect::<Vec<u128>>(); - assert_eq!(aggregate_result, expected_aggregate_result); + check_test_vec(test_vector_str, |json_params, num_shares| { + let length = json_params["length"].as_u64().unwrap() as usize; + let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize; + Prio3::new_histogram(num_shares, length, chunk_length).unwrap() + }); } } |