diff options
Diffstat (limited to 'third_party/rust/prio/src/flp/types.rs')
-rw-r--r-- | third_party/rust/prio/src/flp/types.rs | 460 |
1 files changed, 54 insertions, 406 deletions
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs index 18c290355c..bca88a36cd 100644 --- a/third_party/rust/prio/src/flp/types.rs +++ b/third_party/rust/prio/src/flp/types.rs @@ -9,6 +9,7 @@ use crate::polynomial::poly_range_check; use std::convert::TryInto; use std::fmt::{self, Debug}; use std::marker::PhantomData; +use subtle::Choice; /// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`). #[derive(Clone, PartialEq, Eq)] pub struct Count<F> { @@ -37,18 +38,16 @@ impl<F: FftFriendlyFieldElement> Default for Count<F> { } impl<F: FftFriendlyFieldElement> Type for Count<F> { - const ID: u32 = 0x00000000; - type Measurement = F::Integer; + type Measurement = bool; type AggregateResult = F::Integer; type Field = F; - fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> { - let max = F::valid_integer_try_from(1)?; - if *value > max { - return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); - } - - Ok(vec![F::from(*value)]) + fn encode_measurement(&self, value: &bool) -> Result<Vec<F>, FlpError> { + Ok(vec![F::conditional_select( + &F::zero(), + &F::one(), + Choice::from(u8::from(*value)), + )]) } fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { @@ -140,13 +139,12 @@ impl<F: FftFriendlyFieldElement> Sum<F> { } impl<F: FftFriendlyFieldElement> Type for Sum<F> { - const ID: u32 = 0x00000001; type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { - let v = F::encode_into_bitvector_representation(summand, self.bits)?; + let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); Ok(v) } @@ -174,7 +172,7 @@ impl<F: FftFriendlyFieldElement> Type for Sum<F> { fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_from_bitvector_representation(&input)?; + let res = F::decode_bitvector(&input)?; Ok(vec![res]) } @@ -239,13 +237,12 @@ impl<F: FftFriendlyFieldElement> Average<F> { } impl<F: FftFriendlyFieldElement> Type for Average<F> { - const ID: u32 = 0xFFFF0000; type Measurement = F::Integer; type AggregateResult = f64; type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { - let v = F::encode_into_bitvector_representation(summand, self.bits)?; + let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); Ok(v) } @@ -279,7 +276,7 @@ impl<F: FftFriendlyFieldElement> Type for Average<F> { fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_from_bitvector_representation(&input)?; + let res = F::decode_bitvector(&input)?; Ok(vec![res]) } @@ -380,7 +377,6 @@ where F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, { - const ID: u32 = 0x00000003; type Measurement = usize; type AggregateResult = Vec<F::Integer>; type Field = F; @@ -574,7 +570,6 @@ where F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, { - const ID: u32 = 0x00000002; type Measurement = Vec<F::Integer>; type AggregateResult = Vec<F::Integer>; type Field = F; @@ -588,18 +583,15 @@ where ))); } - let mut flattened = vec![F::zero(); self.flattened_len]; - for (summand, chunk) in measurement - .iter() - .zip(flattened.chunks_exact_mut(self.bits)) - { + let mut flattened = Vec::with_capacity(self.flattened_len); + for summand in measurement.iter() { if summand > &self.max { return Err(FlpError::Encode(format!( "summand exceeds maximum of 2^{}-1", self.bits ))); } - F::fill_with_bitvector_representation(summand, chunk)?; + flattened.extend(F::encode_as_bitvector(*summand, self.bits)?); } Ok(flattened) @@ -642,7 +634,7 @@ where self.truncate_call_check(&input)?; let mut unflattened = Vec::with_capacity(self.len); for chunk in input.chunks(self.bits) { - unflattened.push(F::decode_from_bitvector_representation(chunk)?); + unflattened.push(F::decode_bitvector(chunk)?); } Ok(unflattened) } @@ -783,7 +775,7 @@ mod tests { use crate::flp::gadgets::ParallelSum; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use crate::flp::test_utils::FlpTest; use std::cmp; #[test] @@ -797,7 +789,7 @@ mod tests { count .decode_result( &count - .truncate(count.encode_measurement(&1).unwrap()) + .truncate(count.encode_measurement(&true).unwrap()) .unwrap(), 1 ) @@ -806,39 +798,11 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( - &count, - &count.encode_measurement(&1).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &count, - &count.encode_measurement(&0).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]); // Test FLP on invalid input. - flp_validity_test( - &count, - &[TestField::from(1337)], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]); // Try running the validity circuit on an input that's too short. count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); @@ -865,72 +829,22 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( + FlpTest::expect_valid::<3>( &sum, &sum.encode_measurement(&1337).unwrap(), - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![TestField::from(1337)]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(0).unwrap(), - &[], - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(2).unwrap(), - &[one, zero], - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( + &[TestField::from(1337)], + ); + FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]); + FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]); + FlpTest::expect_valid::<3>( &Sum::new(9).unwrap(), &[one, zero, one, one, zero, one, one, one, zero], - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![TestField::from(237)]), - num_shares: 3, - }, - ) - .unwrap(); + &[TestField::from(237)], + ); // Test FLP on invalid input. - flp_validity_test( - &Sum::new(3).unwrap(), - &[one, nine, zero], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(5).unwrap(), - &[zero, zero, zero, zero, nine], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]); + FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]); } #[test] @@ -1000,83 +914,29 @@ mod tests { ); // Test valid inputs. - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&0).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one, zero, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[one, zero, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&1).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero, one, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, one, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&2).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![zero, zero, one]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, zero, one], + ); // Test invalid inputs. - flp_validity_test( - &hist, - &[zero, zero, nine], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, one, one], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[one, one, one], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, zero, zero], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]); + FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[one, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]); } #[test] @@ -1104,72 +964,38 @@ mod tests { for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( + FlpTest::expect_valid_no_output::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + ); } let len = 100; let sum_vec = f(1, len, 10).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![one; len], + ); let len = 23; let sum_vec = f(4, len, 4).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![9; len]).unwrap(), - &ValidityTestCase::<TestField> { - expect_valid: true, - expected_output: Some(vec![nine; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![nine; len], + ); // Test on invalid inputs. for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; len], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]); } let len = 23; let sum_vec = f(2, len, 4).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; 2 * len], - &ValidityTestCase::<TestField> { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]); // Round trip let want = vec![1; len]; @@ -1232,184 +1058,6 @@ mod tests { } } -#[cfg(test)] -mod test_utils { - use super::*; - use crate::field::{random_vector, split_vector, FieldElement}; - - pub(crate) struct ValidityTestCase<F> { - pub(crate) expect_valid: bool, - pub(crate) expected_output: Option<Vec<F>>, - // Number of shares to split input and proofs into in `flp_test`. - pub(crate) num_shares: usize, - } - - pub(crate) fn flp_validity_test<T: Type>( - typ: &T, - input: &[T::Field], - t: &ValidityTestCase<T::Field>, - ) -> Result<(), FlpError> { - let mut gadgets = typ.gadget(); - - if input.len() != typ.input_len() { - return Err(FlpError::Test(format!( - "unexpected input length: got {}; want {}", - input.len(), - typ.input_len() - ))); - } - - if typ.query_rand_len() != gadgets.len() { - return Err(FlpError::Test(format!( - "query rand length: got {}; want {}", - typ.query_rand_len(), - gadgets.len() - ))); - } - - let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); - let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); - let query_rand = random_vector(typ.query_rand_len()).unwrap(); - - // Run the validity circuit. - let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; - if v != T::Field::zero() && t.expect_valid { - return Err(FlpError::Test(format!( - "expected valid input: valid() returned {v}" - ))); - } - if v == T::Field::zero() && !t.expect_valid { - return Err(FlpError::Test(format!( - "expected invalid input: valid() returned {v}" - ))); - } - - // Generate the proof. - let proof = typ.prove(input, &prove_rand, &joint_rand)?; - if proof.len() != typ.proof_len() { - return Err(FlpError::Test(format!( - "unexpected proof length: got {}; want {}", - proof.len(), - typ.proof_len() - ))); - } - - // Query the proof. - let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; - if verifier.len() != typ.verifier_len() { - return Err(FlpError::Test(format!( - "unexpected verifier length: got {}; want {}", - verifier.len(), - typ.verifier_len() - ))); - } - - // Decide if the input is valid. - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Run distributed FLP. - let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let verifier: Vec<T::Field> = (0..t.num_shares) - .map(|i| { - typ.query( - &input_shares[i], - &proof_shares[i], - &query_rand, - &joint_rand, - t.num_shares, - ) - .unwrap() - }) - .reduce(|mut left, right| { - for (x, y) in left.iter_mut().zip(right.iter()) { - *x += *y; - } - left - }) - .unwrap(); - - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "distributed decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Try verifying various proof mutants. - for i in 0..proof.len() { - let mut mutated_proof = proof.clone(); - mutated_proof[i] += T::Field::one(); - let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; - if typ.decide(&verifier)? { - return Err(FlpError::Test(format!( - "decision for proof mutant {} is {}; want {}", - i, true, false, - ))); - } - } - - // Try verifying a proof that is too short. - let mut mutated_proof = proof.clone(); - mutated_proof.truncate(gadgets[0].arity() - 1); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on short proof succeeded; want failure".to_string(), - )); - } - - // Try verifying a proof that is too long. - let mut mutated_proof = proof; - mutated_proof.extend_from_slice(&[T::Field::one(); 17]); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on long proof succeeded; want failure".to_string(), - )); - } - - if let Some(ref want) = t.expected_output { - let got = typ.truncate(input.to_vec())?; - - if got.len() != typ.output_len() { - return Err(FlpError::Test(format!( - "unexpected output length: got {}; want {}", - got.len(), - typ.output_len() - ))); - } - - if &got != want { - return Err(FlpError::Test(format!( - "unexpected output: got {got:?}; want {want:?}" - ))); - } - } - - Ok(()) - } -} - #[cfg(feature = "experimental")] #[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] pub mod fixedpoint_l2; |