diff options
Diffstat (limited to 'third_party/rust/prio/src/flp')
-rw-r--r-- | third_party/rust/prio/src/flp/types.rs | 460 | ||||
-rw-r--r-- | third_party/rust/prio/src/flp/types/fixedpoint_l2.rs | 131 |
2 files changed, 120 insertions, 471 deletions
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs index 18c290355c..bca88a36cd 100644 --- a/third_party/rust/prio/src/flp/types.rs +++ b/third_party/rust/prio/src/flp/types.rs @@ -9,6 +9,7 @@ use crate::polynomial::poly_range_check; use std::convert::TryInto; use std::fmt::{self, Debug}; use std::marker::PhantomData; +use subtle::Choice; /// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`). #[derive(Clone, PartialEq, Eq)] pub struct Count<F> { @@ -37,18 +38,16 @@ impl<F: FftFriendlyFieldElement> Default for Count<F> { } impl<F: FftFriendlyFieldElement> Type for Count<F> { - const ID: u32 = 0x00000000; - type Measurement = F::Integer; + type Measurement = bool; type AggregateResult = F::Integer; type Field = F; - fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> { - let max = F::valid_integer_try_from(1)?; - if *value > max { - return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); - } - - Ok(vec![F::from(*value)]) + fn encode_measurement(&self, value: &bool) -> Result<Vec<F>, FlpError> { + Ok(vec![F::conditional_select( + &F::zero(), + &F::one(), + Choice::from(u8::from(*value)), + )]) } fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { @@ -140,13 +139,12 @@ impl<F: FftFriendlyFieldElement> Sum<F> { } impl<F: FftFriendlyFieldElement> Type for Sum<F> { - const ID: u32 = 0x00000001; type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { - let v = F::encode_into_bitvector_representation(summand, self.bits)?; + let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); Ok(v) } @@ -174,7 +172,7 @@ impl<F: FftFriendlyFieldElement> Type for Sum<F> { fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_from_bitvector_representation(&input)?; + let res = F::decode_bitvector(&input)?; Ok(vec![res]) } @@ -239,13 +237,12 @@ impl<F: FftFriendlyFieldElement> Average<F> { } impl<F: FftFriendlyFieldElement> Type for Average<F> { - const ID: u32 = 0xFFFF0000; type Measurement = F::Integer; type AggregateResult = f64; type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { - let v = F::encode_into_bitvector_representation(summand, self.bits)?; + let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); Ok(v) } @@ -279,7 +276,7 @@ impl<F: FftFriendlyFieldElement> Type for Average<F> { fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_from_bitvector_representation(&input)?; + let res = F::decode_bitvector(&input)?; Ok(vec![res]) } @@ -380,7 +377,6 @@ where F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, { - const ID: u32 = 0x00000003; type Measurement = usize; type AggregateResult = Vec<F::Integer>; type Field = F; @@ -574,7 +570,6 @@ where F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, { - const ID: u32 = 0x00000002; type Measurement = Vec<F::Integer>; type AggregateResult = Vec<F::Integer>; type Field = F; @@ -588,18 +583,15 @@ where ))); } - let mut flattened = vec![F::zero(); self.flattened_len]; - for (summand, chunk) in measurement - .iter() - .zip(flattened.chunks_exact_mut(self.bits)) - { + let mut flattened = Vec::with_capacity(self.flattened_len); + for summand in measurement.iter() { if summand > &self.max { return Err(FlpError::Encode(format!( "summand exceeds maximum of 2^{}-1", self.bits ))); } - F::fill_with_bitvector_representation(summand, chunk)?; + flattened.extend(F::encode_as_bitvector(*summand, self.bits)?); } Ok(flattened) @@ -642,7 +634,7 @@ where self.truncate_call_check(&input)?; let mut unflattened = Vec::with_capacity(self.len); for chunk in input.chunks(self.bits) { - unflattened.push(F::decode_from_bitvector_representation(chunk)?); + unflattened.push(F::decode_bitvector(chunk)?); } Ok(unflattened) } @@ -783,7 +775,7 @@ mod tests { use crate::flp::gadgets::ParallelSum; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use crate::flp::test_utils::FlpTest; use std::cmp; #[test] @@ -797,7 +789,7 @@ mod tests { count .decode_result( &count - .truncate(count.encode_measurement(&1).unwrap()) + .truncate(count.encode_measurement(&true).unwrap()) .unwrap(), 1 ) @@ -806,39 +798,11 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( - &count, - &count.encode_measurement(&1).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &count, - &count.encode_measurement(&0).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]); // Test FLP on invalid input. - flp_validity_test( - &count, - &[TestField::from(1337)], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]); // Try running the validity circuit on an input that's too short. count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); @@ -865,72 +829,22 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( + FlpTest::expect_valid::<3>( &sum, &sum.encode_measurement(&1337).unwrap(), - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![TestField::from(1337)]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(0).unwrap(), - &[], - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(2).unwrap(), - &[one, zero], - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( + &[TestField::from(1337)], + ); + FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]); + FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]); + FlpTest::expect_valid::<3>( &Sum::new(9).unwrap(), &[one, zero, one, one, zero, one, one, one, zero], - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![TestField::from(237)]), - num_shares: 3, - }, - ) - .unwrap(); + &[TestField::from(237)], + ); // Test FLP on invalid input. - flp_validity_test( - &Sum::new(3).unwrap(), - &[one, nine, zero], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(5).unwrap(), - &[zero, zero, zero, zero, nine], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]); + FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]); } #[test] @@ -1000,83 +914,29 @@ mod tests { ); // Test valid inputs. - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&0).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one, zero, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[one, zero, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&1).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero, one, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, one, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&2).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero, zero, one]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, zero, one], + ); // Test invalid inputs. - flp_validity_test( - &hist, - &[zero, zero, nine], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, one, one], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[one, one, one], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, zero, zero], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]); + FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[one, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]); } #[test] @@ -1104,72 +964,38 @@ mod tests { for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( + FlpTest::expect_valid_no_output::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + ); } let len = 100; let sum_vec = f(1, len, 10).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![one; len], + ); let len = 23; let sum_vec = f(4, len, 4).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![9; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![nine; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![nine; len], + ); // Test on invalid inputs. for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; len], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]); } let len = 23; let sum_vec = f(2, len, 4).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; 2 * len], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]); // Round trip let want = vec![1; len]; @@ -1232,184 +1058,6 @@ mod tests { } } -#[cfg(test)] -mod test_utils { - use super::*; - use crate::field::{random_vector, split_vector, FieldElement}; - - pub(crate) struct ValidityTestCase<F> { - pub(crate) expect_valid: bool, - pub(crate) expected_output: Option<Vec<F>>, - // Number of shares to split input and proofs into in `flp_test`. - pub(crate) num_shares: usize, - } - - pub(crate) fn flp_validity_test<T: Type>( - typ: &T, - input: &[T::Field], - t: &ValidityTestCase<T::Field>, - ) -> Result<(), FlpError> { - let mut gadgets = typ.gadget(); - - if input.len() != typ.input_len() { - return Err(FlpError::Test(format!( - "unexpected input length: got {}; want {}", - input.len(), - typ.input_len() - ))); - } - - if typ.query_rand_len() != gadgets.len() { - return Err(FlpError::Test(format!( - "query rand length: got {}; want {}", - typ.query_rand_len(), - gadgets.len() - ))); - } - - let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); - let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); - let query_rand = random_vector(typ.query_rand_len()).unwrap(); - - // Run the validity circuit. - let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; - if v != T::Field::zero() && t.expect_valid { - return Err(FlpError::Test(format!( - "expected valid input: valid() returned {v}" - ))); - } - if v == T::Field::zero() && !t.expect_valid { - return Err(FlpError::Test(format!( - "expected invalid input: valid() returned {v}" - ))); - } - - // Generate the proof. - let proof = typ.prove(input, &prove_rand, &joint_rand)?; - if proof.len() != typ.proof_len() { - return Err(FlpError::Test(format!( - "unexpected proof length: got {}; want {}", - proof.len(), - typ.proof_len() - ))); - } - - // Query the proof. - let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; - if verifier.len() != typ.verifier_len() { - return Err(FlpError::Test(format!( - "unexpected verifier length: got {}; want {}", - verifier.len(), - typ.verifier_len() - ))); - } - - // Decide if the input is valid. - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Run distributed FLP. - let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let verifier: Vec<T::Field> = (0..t.num_shares) - .map(|i| { - typ.query( - &input_shares[i], - &proof_shares[i], - &query_rand, - &joint_rand, - t.num_shares, - ) - .unwrap() - }) - .reduce(|mut left, right| { - for (x, y) in left.iter_mut().zip(right.iter()) { - *x += *y; - } - left - }) - .unwrap(); - - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "distributed decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Try verifying various proof mutants. - for i in 0..proof.len() { - let mut mutated_proof = proof.clone(); - mutated_proof[i] += T::Field::one(); - let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; - if typ.decide(&verifier)? { - return Err(FlpError::Test(format!( - "decision for proof mutant {} is {}; want {}", - i, true, false, - ))); - } - } - - // Try verifying a proof that is too short. - let mut mutated_proof = proof.clone(); - mutated_proof.truncate(gadgets[0].arity() - 1); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on short proof succeeded; want failure".to_string(), - )); - } - - // Try verifying a proof that is too long. - let mut mutated_proof = proof; - mutated_proof.extend_from_slice(&[T::Field::one(); 17]); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on long proof succeeded; want failure".to_string(), - )); - } - - if let Some(ref want) = t.expected_output { - let got = typ.truncate(input.to_vec())?; - - if got.len() != typ.output_len() { - return Err(FlpError::Test(format!( - "unexpected output length: got {}; want {}", - got.len(), - typ.output_len() - ))); - } - - if &got != want { - return Err(FlpError::Test(format!( - "unexpected output: got {got:?}; want {want:?}" - ))); - } - } - - Ok(()) - } -} - #[cfg(feature = "experimental")] #[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] pub mod fixedpoint_l2; diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs index b5aa2fd116..8766c035b8 100644 --- a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs +++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs @@ -172,15 +172,17 @@ pub mod compatible_float; use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError}; -use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt}; +use crate::field::{ + Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt, Integer, +}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat; use crate::flp::types::parallel_sum_range_checks; use crate::flp::{FlpError, Gadget, Type, TypeWithNoise}; -use crate::vdaf::xof::SeedStreamSha3; +use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::traits::Fixed; use num_bigint::{BigInt, BigUint, TryFromBigIntError}; -use num_integer::Integer; +use num_integer::Integer as _; use num_rational::Ratio; use rand::{distributions::Distribution, Rng}; use rand_core::SeedableRng; @@ -250,7 +252,7 @@ where /// fixed point vector with `entries` entries. pub fn new(entries: usize) -> Result<Self, FlpError> { // (0) initialize constants - let fi_one = u128::from(Field128::one()); + let fi_one = <Field128 as FieldElementWithInteger>::Integer::one(); // (I) Check that the fixed type is compatible. // @@ -400,7 +402,6 @@ where SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static, SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static, { - const ID: u32 = 0xFFFF0000; type Measurement = Vec<T>; type AggregateResult = Vec<f64>; type Field = Field128; @@ -419,12 +420,9 @@ where // Encode the integer entries bitwise, and write them into the `encoded` // vector. let mut encoded: Vec<Field128> = - vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm]; - for (l, entry) in integer_entries.clone().enumerate() { - Field128::fill_with_bitvector_representation( - &entry, - &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry], - )?; + Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm); + for entry in integer_entries.clone() { + encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?); } // (II) Vector norm. @@ -434,10 +432,7 @@ where let norm_int = u128::from(norm); // Write the norm into the `entries` vector. - Field128::fill_with_bitvector_representation( - &norm_int, - &mut encoded[self.range_norm_begin..self.range_norm_end], - )?; + encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?); Ok(encoded) } @@ -535,7 +530,7 @@ where // decode the bit-encoded entries into elements in the range [0,2^n): let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry] .chunks(self.bits_per_entry) - .map(Field128::decode_from_bitvector_representation) + .map(Field128::decode_bitvector) .collect(); // run parallel sum gadget on the decoded entries @@ -544,7 +539,7 @@ where // Chunks which are too short need to be extended with a share of the // encoded zero value, that is: 1/num_shares * (2^(n-1)) - let fi_one = u128::from(Field128::one()); + let fi_one = <Field128 as FieldElementWithInteger>::Integer::one(); let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1)); let zero_enc_share = zero_enc * num_shares_inverse; @@ -567,7 +562,7 @@ where // The submitted norm is also decoded from its bit-encoding, and // compared with the computed norm. let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end]; - let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?; + let submitted_norm = Field128::decode_bitvector(submitted_norm_enc)?; let norm_check = computed_norm - submitted_norm; @@ -586,7 +581,7 @@ where let start = i_entry * self.bits_per_entry; let end = (i_entry + 1) * self.bits_per_entry; - let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?; + let decoded = Field128::decode_bitvector(&input[start..end])?; decoded_vector.push(decoded); } Ok(decoded_vector) @@ -644,7 +639,11 @@ where agg_result: &mut [Self::Field], _num_measurements: usize, ) -> Result<(), FlpError> { - self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy()) + self.add_noise( + dp_strategy, + agg_result, + &mut SeedStreamTurboShake128::from_entropy(), + ) } } @@ -686,8 +685,8 @@ mod tests { use crate::dp::{Rational, ZCdpBudget}; use crate::field::{random_vector, Field128, FieldElement}; use crate::flp::gadgets::ParallelSum; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; - use crate::vdaf::xof::SeedStreamSha3; + use crate::flp::test_utils::FlpTest; + use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::types::extra::{U127, U14, U63}; use fixed::{FixedI128, FixedI16, FixedI64}; use fixed_macro::fixed; @@ -768,15 +767,23 @@ mod tests { let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new( Rational::from_unsigned(100u8, 3u8).unwrap(), )); - vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16])) - .unwrap(); + vsum.add_noise( + &strategy, + &mut v, + &mut SeedStreamTurboShake128::from_seed([0u8; 16]), + ) + .unwrap(); assert_eq!( vsum.decode_result(&v, 1).unwrap(), match n { // sensitivity depends on encoding so the noise differs - 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625], - 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445], - 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425], + 16 => vec![0.288970947265625, 0.168853759765625, 0.085662841796875], + 32 => vec![0.257810294162482, 0.10634658299386501, 0.10149003705009818], + 64 => vec![ + 0.37697368351762867, + -0.02388947667663828, + 0.19813152630930916 + ], _ => panic!("unsupported bitsize"), } ); @@ -785,52 +792,46 @@ mod tests { let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); assert_eq!(input[0], Field128::zero()); input[0] = one; // it was zero - flp_validity_test( - &vsum, - &input, - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // encoding contains entries that are not zero or one let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); input2[0] = one + one; - flp_validity_test( - &vsum, - &input2, - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input2, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // norm is too big // 2^n - 1, the field element encoded by the all-1 vector let one_enc = Field128::from(((2_u128) << (n - 1)) - 1); - flp_validity_test( - &vsum, - &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![one_enc; 3]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors + expected_output: Some(&[one_enc; 3]), + expect_valid: false, + } + .run::<3>(); // invalid submission length, should be 3n + (2*n - 2) for a // 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm. |