summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio3_test.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio3_test.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio3_test.rs162
1 files changed, 102 insertions, 60 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs
index 372a2c8560..9a3dfd85f4 100644
--- a/third_party/rust/prio/src/vdaf/prio3_test.rs
+++ b/third_party/rust/prio/src/vdaf/prio3_test.rs
@@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0
+//! Tools for evaluating Prio3 test vectors.
+
use crate::{
codec::{Encode, ParameterizedDecode},
flp::Type,
@@ -58,19 +60,21 @@ macro_rules! err {
// TODO Generalize this method to work with any VDAF. To do so we would need to add
// `shard_with_random()` to traits. (There may be a less invasive alternative.)
-fn check_prep_test_vec<M, T, P, const SEED_SIZE: usize>(
+fn check_prep_test_vec<MS, MP, T, P, const SEED_SIZE: usize>(
prio3: &Prio3<T, P, SEED_SIZE>,
verify_key: &[u8; SEED_SIZE],
test_num: usize,
- t: &TPrio3Prep<M>,
+ t: &TPrio3Prep<MS>,
) -> Vec<OutputShare<T::Field>>
where
- T: Type<Measurement = M>,
+ MS: Clone,
+ MP: From<MS>,
+ T: Type<Measurement = MP>,
P: Xof<SEED_SIZE>,
{
let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap();
let (public_share, input_shares) = prio3
- .shard_with_random(&t.measurement, &nonce, &t.rand)
+ .shard_with_random(&t.measurement.clone().into(), &nonce, &t.rand)
.expect("failed to generate input shares");
assert_eq!(
@@ -86,7 +90,7 @@ where
"#{test_num}"
);
assert_eq!(
- input_shares[agg_id].get_encoded(),
+ input_shares[agg_id].get_encoded().unwrap(),
want.as_ref(),
"#{test_num}"
)
@@ -110,14 +114,18 @@ where
.unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
"#{test_num}"
);
- assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{test_num}");
+ assert_eq!(
+ prep_shares[i].get_encoded().unwrap(),
+ want.as_ref(),
+ "#{test_num}"
+ );
}
let inbound = prio3
.prepare_shares_to_prepare_message(&(), prep_shares)
.unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
assert_eq!(t.prep_messages.len(), 1);
- assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());
+ assert_eq!(inbound.get_encoded().unwrap(), t.prep_messages[0].as_ref());
let mut out_shares = Vec::new();
for state in states.iter_mut() {
@@ -130,7 +138,11 @@ where
}
for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
- let got: Vec<Vec<u8>> = got.as_ref().iter().map(|x| x.get_encoded()).collect();
+ let got: Vec<Vec<u8>> = got
+ .as_ref()
+ .iter()
+ .map(|x| x.get_encoded().unwrap())
+ .collect();
assert_eq!(got.len(), want.len());
for (got_elem, want_elem) in got.iter().zip(want.iter()) {
assert_eq!(got_elem.as_slice(), want_elem.as_ref());
@@ -141,12 +153,14 @@ where
}
#[must_use]
-fn check_aggregate_test_vec<M, T, P, const SEED_SIZE: usize>(
+fn check_aggregate_test_vec<MS, MP, T, P, const SEED_SIZE: usize>(
prio3: &Prio3<T, P, SEED_SIZE>,
- t: &TPrio3<M>,
+ t: &TPrio3<MS>,
) -> T::AggregateResult
where
- T: Type<Measurement = M>,
+ MS: Clone,
+ MP: From<MS>,
+ T: Type<Measurement = MP>,
P: Xof<SEED_SIZE>,
{
let verify_key = t.verify_key.as_ref().try_into().unwrap();
@@ -167,85 +181,113 @@ where
.collect::<Vec<_>>();
for (got, want) in aggregate_shares.iter().zip(t.agg_shares.iter()) {
- let got = got.get_encoded();
+ let got = got.get_encoded().unwrap();
assert_eq!(got.as_slice(), want.as_ref());
}
prio3.unshard(&(), aggregate_shares, 1).unwrap()
}
+/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback,
+/// which takes in the VDAF parameters encoded by the test vectors and the number of shares.
+///
+/// This version allows customizing the deserialization of measurements, via an additional type
+/// parameter.
+#[cfg(feature = "test-util")]
+#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+pub fn check_test_vec_custom_de<MS, MP, A, T, P, const SEED_SIZE: usize>(
+ test_vec_json_str: &str,
+ new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>,
+) where
+ MS: for<'de> Deserialize<'de> + Clone,
+ MP: From<MS>,
+ A: for<'de> Deserialize<'de> + Debug + Eq,
+ T: Type<Measurement = MP, AggregateResult = A>,
+ P: Xof<SEED_SIZE>,
+{
+ let t: TPrio3<MS> = serde_json::from_str(test_vec_json_str).unwrap();
+ let vdaf = new_vdaf(&t.other_params, t.shares);
+ let agg_result = check_aggregate_test_vec(&vdaf, &t);
+ assert_eq!(agg_result, serde_json::from_value(t.agg_result).unwrap());
+}
+
+/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback,
+/// which takes in the VDAF parameters encoded by the test vectors and the number of shares.
+#[cfg(feature = "test-util")]
+#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+pub fn check_test_vec<M, A, T, P, const SEED_SIZE: usize>(
+ test_vec_json_str: &str,
+ new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>,
+) where
+ M: for<'de> Deserialize<'de> + Clone,
+ A: for<'de> Deserialize<'de> + Debug + Eq,
+ T: Type<Measurement = M, AggregateResult = A>,
+ P: Xof<SEED_SIZE>,
+{
+ check_test_vec_custom_de::<M, M, _, _, _, SEED_SIZE>(test_vec_json_str, new_vdaf)
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(transparent)]
+struct Prio3CountMeasurement(u8);
+
+impl From<Prio3CountMeasurement> for bool {
+ fn from(value: Prio3CountMeasurement) -> Self {
+ value.0 != 0
+ }
+}
+
#[test]
fn test_vec_prio3_count() {
for test_vector_str in [
- include_str!("test_vec/07/Prio3Count_0.json"),
- include_str!("test_vec/07/Prio3Count_1.json"),
+ include_str!("test_vec/08/Prio3Count_0.json"),
+ include_str!("test_vec/08/Prio3Count_1.json"),
] {
- let t: TPrio3<u64> = serde_json::from_str(test_vector_str).unwrap();
- let prio3 = Prio3::new_count(t.shares).unwrap();
-
- let aggregate_result = check_aggregate_test_vec(&prio3, &t);
- assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap());
+ check_test_vec_custom_de::<Prio3CountMeasurement, _, _, _, _, 16>(
+ test_vector_str,
+ |_json_params, num_shares| Prio3::new_count(num_shares).unwrap(),
+ );
}
}
#[test]
fn test_vec_prio3_sum() {
for test_vector_str in [
- include_str!("test_vec/07/Prio3Sum_0.json"),
- include_str!("test_vec/07/Prio3Sum_1.json"),
+ include_str!("test_vec/08/Prio3Sum_0.json"),
+ include_str!("test_vec/08/Prio3Sum_1.json"),
] {
- let t: TPrio3<u128> = serde_json::from_str(test_vector_str).unwrap();
- let bits = t.other_params["bits"].as_u64().unwrap() as usize;
- let prio3 = Prio3::new_sum(t.shares, bits).unwrap();
-
- let aggregate_result = check_aggregate_test_vec(&prio3, &t);
- assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap() as u128);
+ check_test_vec(test_vector_str, |json_params, num_shares| {
+ let bits = json_params["bits"].as_u64().unwrap() as usize;
+ Prio3::new_sum(num_shares, bits).unwrap()
+ });
}
}
#[test]
fn test_vec_prio3_sum_vec() {
for test_vector_str in [
- include_str!("test_vec/07/Prio3SumVec_0.json"),
- include_str!("test_vec/07/Prio3SumVec_1.json"),
+ include_str!("test_vec/08/Prio3SumVec_0.json"),
+ include_str!("test_vec/08/Prio3SumVec_1.json"),
] {
- let t: TPrio3<Vec<u128>> = serde_json::from_str(test_vector_str).unwrap();
- let bits = t.other_params["bits"].as_u64().unwrap() as usize;
- let length = t.other_params["length"].as_u64().unwrap() as usize;
- let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
- let prio3 = Prio3::new_sum_vec(t.shares, bits, length, chunk_length).unwrap();
-
- let aggregate_result = check_aggregate_test_vec(&prio3, &t);
- let expected_aggregate_result = t
- .agg_result
- .as_array()
- .unwrap()
- .iter()
- .map(|val| val.as_u64().unwrap() as u128)
- .collect::<Vec<u128>>();
- assert_eq!(aggregate_result, expected_aggregate_result);
+ check_test_vec(test_vector_str, |json_params, num_shares| {
+ let bits = json_params["bits"].as_u64().unwrap() as usize;
+ let length = json_params["length"].as_u64().unwrap() as usize;
+ let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
+ Prio3::new_sum_vec(num_shares, bits, length, chunk_length).unwrap()
+ });
}
}
#[test]
fn test_vec_prio3_histogram() {
for test_vector_str in [
- include_str!("test_vec/07/Prio3Histogram_0.json"),
- include_str!("test_vec/07/Prio3Histogram_1.json"),
+ include_str!("test_vec/08/Prio3Histogram_0.json"),
+ include_str!("test_vec/08/Prio3Histogram_1.json"),
] {
- let t: TPrio3<usize> = serde_json::from_str(test_vector_str).unwrap();
- let length = t.other_params["length"].as_u64().unwrap() as usize;
- let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
- let prio3 = Prio3::new_histogram(t.shares, length, chunk_length).unwrap();
-
- let aggregate_result = check_aggregate_test_vec(&prio3, &t);
- let expected_aggregate_result = t
- .agg_result
- .as_array()
- .unwrap()
- .iter()
- .map(|val| val.as_u64().unwrap() as u128)
- .collect::<Vec<u128>>();
- assert_eq!(aggregate_result, expected_aggregate_result);
+ check_test_vec(test_vector_str, |json_params, num_shares| {
+ let length = json_params["length"].as_u64().unwrap() as usize;
+ let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
+ Prio3::new_histogram(num_shares, length, chunk_length).unwrap()
+ });
}
}