summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/flp')
-rw-r--r--third_party/rust/prio/src/flp/types.rs460
-rw-r--r--third_party/rust/prio/src/flp/types/fixedpoint_l2.rs131
2 files changed, 120 insertions, 471 deletions
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs
index 18c290355c..bca88a36cd 100644
--- a/third_party/rust/prio/src/flp/types.rs
+++ b/third_party/rust/prio/src/flp/types.rs
@@ -9,6 +9,7 @@ use crate::polynomial::poly_range_check;
use std::convert::TryInto;
use std::fmt::{self, Debug};
use std::marker::PhantomData;
+use subtle::Choice;
/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`).
#[derive(Clone, PartialEq, Eq)]
pub struct Count<F> {
@@ -37,18 +38,16 @@ impl<F: FftFriendlyFieldElement> Default for Count<F> {
}
impl<F: FftFriendlyFieldElement> Type for Count<F> {
- const ID: u32 = 0x00000000;
- type Measurement = F::Integer;
+ type Measurement = bool;
type AggregateResult = F::Integer;
type Field = F;
- fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> {
- let max = F::valid_integer_try_from(1)?;
- if *value > max {
- return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
- }
-
- Ok(vec![F::from(*value)])
+ fn encode_measurement(&self, value: &bool) -> Result<Vec<F>, FlpError> {
+ Ok(vec![F::conditional_select(
+ &F::zero(),
+ &F::one(),
+ Choice::from(u8::from(*value)),
+ )])
}
fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
@@ -140,13 +139,12 @@ impl<F: FftFriendlyFieldElement> Sum<F> {
}
impl<F: FftFriendlyFieldElement> Type for Sum<F> {
- const ID: u32 = 0x00000001;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
- let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ let v = F::encode_as_bitvector(*summand, self.bits)?.collect();
Ok(v)
}
@@ -174,7 +172,7 @@ impl<F: FftFriendlyFieldElement> Type for Sum<F> {
fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
self.truncate_call_check(&input)?;
- let res = F::decode_from_bitvector_representation(&input)?;
+ let res = F::decode_bitvector(&input)?;
Ok(vec![res])
}
@@ -239,13 +237,12 @@ impl<F: FftFriendlyFieldElement> Average<F> {
}
impl<F: FftFriendlyFieldElement> Type for Average<F> {
- const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = f64;
type Field = F;
fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
- let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ let v = F::encode_as_bitvector(*summand, self.bits)?.collect();
Ok(v)
}
@@ -279,7 +276,7 @@ impl<F: FftFriendlyFieldElement> Type for Average<F> {
fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
self.truncate_call_check(&input)?;
- let res = F::decode_from_bitvector_representation(&input)?;
+ let res = F::decode_bitvector(&input)?;
Ok(vec![res])
}
@@ -380,7 +377,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
- const ID: u32 = 0x00000003;
type Measurement = usize;
type AggregateResult = Vec<F::Integer>;
type Field = F;
@@ -574,7 +570,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
- const ID: u32 = 0x00000002;
type Measurement = Vec<F::Integer>;
type AggregateResult = Vec<F::Integer>;
type Field = F;
@@ -588,18 +583,15 @@ where
)));
}
- let mut flattened = vec![F::zero(); self.flattened_len];
- for (summand, chunk) in measurement
- .iter()
- .zip(flattened.chunks_exact_mut(self.bits))
- {
+ let mut flattened = Vec::with_capacity(self.flattened_len);
+ for summand in measurement.iter() {
if summand > &self.max {
return Err(FlpError::Encode(format!(
"summand exceeds maximum of 2^{}-1",
self.bits
)));
}
- F::fill_with_bitvector_representation(summand, chunk)?;
+ flattened.extend(F::encode_as_bitvector(*summand, self.bits)?);
}
Ok(flattened)
@@ -642,7 +634,7 @@ where
self.truncate_call_check(&input)?;
let mut unflattened = Vec::with_capacity(self.len);
for chunk in input.chunks(self.bits) {
- unflattened.push(F::decode_from_bitvector_representation(chunk)?);
+ unflattened.push(F::decode_bitvector(chunk)?);
}
Ok(unflattened)
}
@@ -783,7 +775,7 @@ mod tests {
use crate::flp::gadgets::ParallelSum;
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
- use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
+ use crate::flp::test_utils::FlpTest;
use std::cmp;
#[test]
@@ -797,7 +789,7 @@ mod tests {
count
.decode_result(
&count
- .truncate(count.encode_measurement(&1).unwrap())
+ .truncate(count.encode_measurement(&true).unwrap())
.unwrap(),
1
)
@@ -806,39 +798,11 @@ mod tests {
);
// Test FLP on valid input.
- flp_validity_test(
- &count,
- &count.encode_measurement(&1).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &count,
- &count.encode_measurement(&0).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]);
+ FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]);
// Test FLP on invalid input.
- flp_validity_test(
- &count,
- &[TestField::from(1337)],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]);
// Try running the validity circuit on an input that's too short.
count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err();
@@ -865,72 +829,22 @@ mod tests {
);
// Test FLP on valid input.
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum,
&sum.encode_measurement(&1337).unwrap(),
- &ValidityTestCase {
- expect_valid: true,
- expected_output: Some(vec![TestField::from(1337)]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(0).unwrap(),
- &[],
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(2).unwrap(),
- &[one, zero],
- &ValidityTestCase {
- expect_valid: true,
- expected_output: Some(vec![one]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
+ &[TestField::from(1337)],
+ );
+ FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]);
+ FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]);
+ FlpTest::expect_valid::<3>(
&Sum::new(9).unwrap(),
&[one, zero, one, one, zero, one, one, one, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![TestField::from(237)]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[TestField::from(237)],
+ );
// Test FLP on invalid input.
- flp_validity_test(
- &Sum::new(3).unwrap(),
- &[one, nine, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(5).unwrap(),
- &[zero, zero, zero, zero, nine],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]);
+ FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]);
}
#[test]
@@ -1000,83 +914,29 @@ mod tests {
);
// Test valid inputs.
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&0).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one, zero, zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[one, zero, zero],
+ );
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&1).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero, one, zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[zero, one, zero],
+ );
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&2).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero, zero, one]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[zero, zero, one],
+ );
// Test invalid inputs.
- flp_validity_test(
- &hist,
- &[zero, zero, nine],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[zero, one, one],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[one, one, one],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[zero, zero, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]);
+ FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]);
+ FlpTest::expect_invalid::<3>(&hist, &[one, one, one]);
+ FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]);
}
#[test]
@@ -1104,72 +964,38 @@ mod tests {
for len in 1..10 {
let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
let sum_vec = f(1, len, chunk_length).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid_no_output::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![1; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ );
}
let len = 100;
let sum_vec = f(1, len, 10).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![1; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &vec![one; len],
+ );
let len = 23;
let sum_vec = f(4, len, 4).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![9; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![nine; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &vec![nine; len],
+ );
// Test on invalid inputs.
for len in 1..10 {
let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
let sum_vec = f(1, len, chunk_length).unwrap();
- flp_validity_test(
- &sum_vec,
- &vec![nine; len],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]);
}
let len = 23;
let sum_vec = f(2, len, 4).unwrap();
- flp_validity_test(
- &sum_vec,
- &vec![nine; 2 * len],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]);
// Round trip
let want = vec![1; len];
@@ -1232,184 +1058,6 @@ mod tests {
}
}
-#[cfg(test)]
-mod test_utils {
- use super::*;
- use crate::field::{random_vector, split_vector, FieldElement};
-
- pub(crate) struct ValidityTestCase<F> {
- pub(crate) expect_valid: bool,
- pub(crate) expected_output: Option<Vec<F>>,
- // Number of shares to split input and proofs into in `flp_test`.
- pub(crate) num_shares: usize,
- }
-
- pub(crate) fn flp_validity_test<T: Type>(
- typ: &T,
- input: &[T::Field],
- t: &ValidityTestCase<T::Field>,
- ) -> Result<(), FlpError> {
- let mut gadgets = typ.gadget();
-
- if input.len() != typ.input_len() {
- return Err(FlpError::Test(format!(
- "unexpected input length: got {}; want {}",
- input.len(),
- typ.input_len()
- )));
- }
-
- if typ.query_rand_len() != gadgets.len() {
- return Err(FlpError::Test(format!(
- "query rand length: got {}; want {}",
- typ.query_rand_len(),
- gadgets.len()
- )));
- }
-
- let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
- let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
- let query_rand = random_vector(typ.query_rand_len()).unwrap();
-
- // Run the validity circuit.
- let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?;
- if v != T::Field::zero() && t.expect_valid {
- return Err(FlpError::Test(format!(
- "expected valid input: valid() returned {v}"
- )));
- }
- if v == T::Field::zero() && !t.expect_valid {
- return Err(FlpError::Test(format!(
- "expected invalid input: valid() returned {v}"
- )));
- }
-
- // Generate the proof.
- let proof = typ.prove(input, &prove_rand, &joint_rand)?;
- if proof.len() != typ.proof_len() {
- return Err(FlpError::Test(format!(
- "unexpected proof length: got {}; want {}",
- proof.len(),
- typ.proof_len()
- )));
- }
-
- // Query the proof.
- let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?;
- if verifier.len() != typ.verifier_len() {
- return Err(FlpError::Test(format!(
- "unexpected verifier length: got {}; want {}",
- verifier.len(),
- typ.verifier_len()
- )));
- }
-
- // Decide if the input is valid.
- let res = typ.decide(&verifier)?;
- if res != t.expect_valid {
- return Err(FlpError::Test(format!(
- "decision is {}; want {}",
- res, t.expect_valid,
- )));
- }
-
- // Run distributed FLP.
- let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares)
- .unwrap()
- .into_iter()
- .collect();
-
- let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares)
- .unwrap()
- .into_iter()
- .collect();
-
- let verifier: Vec<T::Field> = (0..t.num_shares)
- .map(|i| {
- typ.query(
- &input_shares[i],
- &proof_shares[i],
- &query_rand,
- &joint_rand,
- t.num_shares,
- )
- .unwrap()
- })
- .reduce(|mut left, right| {
- for (x, y) in left.iter_mut().zip(right.iter()) {
- *x += *y;
- }
- left
- })
- .unwrap();
-
- let res = typ.decide(&verifier)?;
- if res != t.expect_valid {
- return Err(FlpError::Test(format!(
- "distributed decision is {}; want {}",
- res, t.expect_valid,
- )));
- }
-
- // Try verifying various proof mutants.
- for i in 0..proof.len() {
- let mut mutated_proof = proof.clone();
- mutated_proof[i] += T::Field::one();
- let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?;
- if typ.decide(&verifier)? {
- return Err(FlpError::Test(format!(
- "decision for proof mutant {} is {}; want {}",
- i, true, false,
- )));
- }
- }
-
- // Try verifying a proof that is too short.
- let mut mutated_proof = proof.clone();
- mutated_proof.truncate(gadgets[0].arity() - 1);
- if typ
- .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
- .is_ok()
- {
- return Err(FlpError::Test(
- "query on short proof succeeded; want failure".to_string(),
- ));
- }
-
- // Try verifying a proof that is too long.
- let mut mutated_proof = proof;
- mutated_proof.extend_from_slice(&[T::Field::one(); 17]);
- if typ
- .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
- .is_ok()
- {
- return Err(FlpError::Test(
- "query on long proof succeeded; want failure".to_string(),
- ));
- }
-
- if let Some(ref want) = t.expected_output {
- let got = typ.truncate(input.to_vec())?;
-
- if got.len() != typ.output_len() {
- return Err(FlpError::Test(format!(
- "unexpected output length: got {}; want {}",
- got.len(),
- typ.output_len()
- )));
- }
-
- if &got != want {
- return Err(FlpError::Test(format!(
- "unexpected output: got {got:?}; want {want:?}"
- )));
- }
- }
-
- Ok(())
- }
-}
-
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub mod fixedpoint_l2;
diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
index b5aa2fd116..8766c035b8 100644
--- a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
+++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
@@ -172,15 +172,17 @@
pub mod compatible_float;
use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError};
-use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt};
+use crate::field::{
+ Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt, Integer,
+};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
use crate::flp::types::parallel_sum_range_checks;
use crate::flp::{FlpError, Gadget, Type, TypeWithNoise};
-use crate::vdaf::xof::SeedStreamSha3;
+use crate::vdaf::xof::SeedStreamTurboShake128;
use fixed::traits::Fixed;
use num_bigint::{BigInt, BigUint, TryFromBigIntError};
-use num_integer::Integer;
+use num_integer::Integer as _;
use num_rational::Ratio;
use rand::{distributions::Distribution, Rng};
use rand_core::SeedableRng;
@@ -250,7 +252,7 @@ where
/// fixed point vector with `entries` entries.
pub fn new(entries: usize) -> Result<Self, FlpError> {
// (0) initialize constants
- let fi_one = u128::from(Field128::one());
+ let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();
// (I) Check that the fixed type is compatible.
//
@@ -400,7 +402,6 @@ where
SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static,
SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static,
{
- const ID: u32 = 0xFFFF0000;
type Measurement = Vec<T>;
type AggregateResult = Vec<f64>;
type Field = Field128;
@@ -419,12 +420,9 @@ where
// Encode the integer entries bitwise, and write them into the `encoded`
// vector.
let mut encoded: Vec<Field128> =
- vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm];
- for (l, entry) in integer_entries.clone().enumerate() {
- Field128::fill_with_bitvector_representation(
- &entry,
- &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry],
- )?;
+ Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm);
+ for entry in integer_entries.clone() {
+ encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?);
}
// (II) Vector norm.
@@ -434,10 +432,7 @@ where
let norm_int = u128::from(norm);
// Write the norm into the `entries` vector.
- Field128::fill_with_bitvector_representation(
- &norm_int,
- &mut encoded[self.range_norm_begin..self.range_norm_end],
- )?;
+ encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?);
Ok(encoded)
}
@@ -535,7 +530,7 @@ where
// decode the bit-encoded entries into elements in the range [0,2^n):
let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry]
.chunks(self.bits_per_entry)
- .map(Field128::decode_from_bitvector_representation)
+ .map(Field128::decode_bitvector)
.collect();
// run parallel sum gadget on the decoded entries
@@ -544,7 +539,7 @@ where
// Chunks which are too short need to be extended with a share of the
// encoded zero value, that is: 1/num_shares * (2^(n-1))
- let fi_one = u128::from(Field128::one());
+ let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();
let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1));
let zero_enc_share = zero_enc * num_shares_inverse;
@@ -567,7 +562,7 @@ where
// The submitted norm is also decoded from its bit-encoding, and
// compared with the computed norm.
let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end];
- let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?;
+ let submitted_norm = Field128::decode_bitvector(submitted_norm_enc)?;
let norm_check = computed_norm - submitted_norm;
@@ -586,7 +581,7 @@ where
let start = i_entry * self.bits_per_entry;
let end = (i_entry + 1) * self.bits_per_entry;
- let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?;
+ let decoded = Field128::decode_bitvector(&input[start..end])?;
decoded_vector.push(decoded);
}
Ok(decoded_vector)
@@ -644,7 +639,11 @@ where
agg_result: &mut [Self::Field],
_num_measurements: usize,
) -> Result<(), FlpError> {
- self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy())
+ self.add_noise(
+ dp_strategy,
+ agg_result,
+ &mut SeedStreamTurboShake128::from_entropy(),
+ )
}
}
@@ -686,8 +685,8 @@ mod tests {
use crate::dp::{Rational, ZCdpBudget};
use crate::field::{random_vector, Field128, FieldElement};
use crate::flp::gadgets::ParallelSum;
- use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
- use crate::vdaf::xof::SeedStreamSha3;
+ use crate::flp::test_utils::FlpTest;
+ use crate::vdaf::xof::SeedStreamTurboShake128;
use fixed::types::extra::{U127, U14, U63};
use fixed::{FixedI128, FixedI16, FixedI64};
use fixed_macro::fixed;
@@ -768,15 +767,23 @@ mod tests {
let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new(
Rational::from_unsigned(100u8, 3u8).unwrap(),
));
- vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16]))
- .unwrap();
+ vsum.add_noise(
+ &strategy,
+ &mut v,
+ &mut SeedStreamTurboShake128::from_seed([0u8; 16]),
+ )
+ .unwrap();
assert_eq!(
vsum.decode_result(&v, 1).unwrap(),
match n {
// sensitivity depends on encoding so the noise differs
- 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625],
- 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445],
- 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425],
+ 16 => vec![0.288970947265625, 0.168853759765625, 0.085662841796875],
+ 32 => vec![0.257810294162482, 0.10634658299386501, 0.10149003705009818],
+ 64 => vec![
+ 0.37697368351762867,
+ -0.02388947667663828,
+ 0.19813152630930916
+ ],
_ => panic!("unsupported bitsize"),
}
);
@@ -785,52 +792,46 @@ mod tests {
let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
assert_eq!(input[0], Field128::zero());
input[0] = one; // it was zero
- flp_validity_test(
- &vsum,
- &input,
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![
- Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0
- Field128::from(enc_vec[1]),
- Field128::from(enc_vec[2]),
- ]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &input,
+ expected_output: Some(&[
+ Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ expect_valid: false,
+ }
+ .run::<3>();
// encoding contains entries that are not zero or one
let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
input2[0] = one + one;
- flp_validity_test(
- &vsum,
- &input2,
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![
- Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0
- Field128::from(enc_vec[1]),
- Field128::from(enc_vec[2]),
- ]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &input2,
+ expected_output: Some(&[
+ Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ expect_valid: false,
+ }
+ .run::<3>();
// norm is too big
// 2^n - 1, the field element encoded by the all-1 vector
let one_enc = Field128::from(((2_u128) << (n - 1)) - 1);
- flp_validity_test(
- &vsum,
- &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![one_enc; 3]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors
+ expected_output: Some(&[one_enc; 3]),
+ expect_valid: false,
+ }
+ .run::<3>();
// invalid submission length, should be 3n + (2*n - 2) for a
// 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm.