diff options
Diffstat (limited to 'third_party/rust/prio/src/flp.rs')
-rw-r--r-- | third_party/rust/prio/src/flp.rs | 241 |
1 files changed, 227 insertions, 14 deletions
diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs index 1912ebab14..5fd956155a 100644 --- a/third_party/rust/prio/src/flp.rs +++ b/third_party/rust/prio/src/flp.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 //! Implementation of the generic Fully Linear Proof (FLP) system specified in -//! [[draft-irtf-cfrg-vdaf-07]]. This is the main building block of [`Prio3`](crate::vdaf::prio3). +//! [[draft-irtf-cfrg-vdaf-08]]. This is the main building block of [`Prio3`](crate::vdaf::prio3). //! //! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation //! specifies a validity circuit that defines the set of valid measurements, as well as the finite @@ -24,7 +24,7 @@ //! //! // The prover chooses a measurement. //! let count = Count::new(); -//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap(); +//! let input: Vec<Field64> = count.encode_measurement(&false).unwrap(); //! //! // The prover and verifier agree on "joint randomness" used to generate and //! // check the proof. The application needs to ensure that the prover @@ -44,7 +44,7 @@ //! assert!(count.decide(&verifier).unwrap()); //! ``` //! -//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/ #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; @@ -61,6 +61,7 @@ pub mod types; /// Errors propagated by methods in this module. #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum FlpError { /// Calling [`Type::prove`] returned an error. #[error("prove error: {0}")] @@ -110,20 +111,12 @@ pub enum FlpError { /// An error happened during noising. #[error("differential privacy error: {0}")] DifferentialPrivacy(#[from] crate::dp::DpError), - - /// Unit test error. - #[cfg(test)] - #[error("test failed: {0}")] - Test(String), } /// A type. Implementations of this trait specify how a particular kind of measurement is encoded /// as a vector of field elements and how validity of the encoded measurement is determined. /// Validity is determined via an arithmetic circuit evaluated over the encoded measurement. pub trait Type: Sized + Eq + Clone + Debug { - /// The Prio3 VDAF identifier corresponding to this type. - const ID: u32; - /// The type of raw measurement to be encoded. type Measurement: Clone + Debug; @@ -178,7 +171,7 @@ pub trait Type: Sized + Eq + Clone + Debug { /// use prio::field::{random_vector, FieldElement, Field64}; /// /// let count = Count::new(); - /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap(); + /// let input: Vec<Field64> = count.encode_measurement(&true).unwrap(); /// let joint_rand = random_vector(count.joint_rand_len()).unwrap(); /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap(); /// assert_eq!(v, Field64::zero()); @@ -552,6 +545,7 @@ pub trait Type: Sized + Eq + Clone + Debug { /// A type which supports adding noise to aggregate shares for Server Differential Privacy. #[cfg(feature = "experimental")] +#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] pub trait TypeWithNoise<S>: Type where S: DifferentialPrivacyStrategy, @@ -754,6 +748,227 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi gadget_degree * (wire_poly_len - 1) + 1 } +/// Utilities for testing FLPs. +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub mod test_utils { + use super::*; + use crate::field::{random_vector, FieldElement, FieldElementWithInteger}; + + /// Various tests for an FLP. + #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] + pub struct FlpTest<'a, T: Type> { + /// The FLP. + pub flp: &'a T, + + /// Optional test name. + pub name: Option<&'a str>, + + /// The input to use for the tests. + pub input: &'a [T::Field], + + /// If set, the expected result of truncating the input. + pub expected_output: Option<&'a [T::Field]>, + + /// Whether the input is expected to be valid. + pub expect_valid: bool, + } + + impl<T: Type> FlpTest<'_, T> { + /// Construct a test and run it. Expect the input to be valid and compare the truncated + /// output to the provided value. + pub fn expect_valid<const SHARES: usize>( + flp: &T, + input: &[T::Field], + expected_output: &[T::Field], + ) { + FlpTest { + flp, + name: None, + input, + expected_output: Some(expected_output), + expect_valid: true, + } + .run::<SHARES>() + } + + /// Construct a test and run it. Expect the input to be invalid. + pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) { + FlpTest { + flp, + name: None, + input, + expect_valid: false, + expected_output: None, + } + .run::<SHARES>() + } + + /// Construct a test and run it. Expect the input to be valid. + pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) { + FlpTest { + flp, + name: None, + input, + expect_valid: true, + expected_output: None, + } + .run::<SHARES>() + } + + /// Run the tests. + pub fn run<const SHARES: usize>(&self) { + let name = self.name.unwrap_or("unnamed test"); + + assert_eq!( + self.input.len(), + self.flp.input_len(), + "{name}: unexpected input length" + ); + + let mut gadgets = self.flp.gadget(); + let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap(); + let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap(); + let query_rand = random_vector(self.flp.query_rand_len()).unwrap(); + assert_eq!( + self.flp.query_rand_len(), + gadgets.len(), + "{name}: unexpected number of gadgets" + ); + assert_eq!( + self.flp.joint_rand_len(), + joint_rand.len(), + "{name}: unexpected joint rand length" + ); + assert_eq!( + self.flp.prove_rand_len(), + prove_rand.len(), + "{name}: unexpected prove rand length", + ); + assert_eq!( + self.flp.query_rand_len(), + query_rand.len(), + "{name}: unexpected query rand length", + ); + + // Run the validity circuit. + let v = self + .flp + .valid(&mut gadgets, self.input, &joint_rand, 1) + .unwrap(); + assert_eq!( + v == T::Field::zero(), + self.expect_valid, + "{name}: unexpected output of valid() returned {v}", + ); + + // Generate the proof. + let proof = self + .flp + .prove(self.input, &prove_rand, &joint_rand) + .unwrap(); + assert_eq!( + proof.len(), + self.flp.proof_len(), + "{name}: unexpected proof length" + ); + + // Query the proof. + let verifier = self + .flp + .query(self.input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!( + verifier.len(), + self.flp.verifier_len(), + "{name}: unexpected verifier length" + ); + + // Decide if the input is valid. + let res = self.flp.decide(&verifier).unwrap(); + assert_eq!(res, self.expect_valid, "{name}: unexpected decision"); + + // Run distributed FLP. + let input_shares = split_vector::<_, SHARES>(self.input); + let proof_shares = split_vector::<_, SHARES>(&proof); + let verifier: Vec<T::Field> = (0..SHARES) + .map(|i| { + self.flp + .query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + + let res = self.flp.decide(&verifier).unwrap(); + assert_eq!( + res, self.expect_valid, + "{name}: unexpected distributed decision" + ); + + // Try verifying various proof mutants. + for i in 0..std::cmp::min(proof.len(), 10) { + let mut mutated_proof = proof.clone(); + mutated_proof[i] *= T::Field::from( + <T::Field as FieldElementWithInteger>::Integer::try_from(23).unwrap(), + ); + let verifier = self + .flp + .query(self.input, &mutated_proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert!( + !self.flp.decide(&verifier).unwrap(), + "{name}: proof mutant {} deemed valid", + i + ); + } + + // Try truncating the input. + if let Some(ref expected_output) = self.expected_output { + let output = self.flp.truncate(self.input.to_vec()).unwrap(); + + assert_eq!( + output.len(), + self.flp.output_len(), + "{name}: unexpected output length of truncate()" + ); + + assert_eq!( + &output, expected_output, + "{name}: unexpected output of truncate()" + ); + } + } + } + + fn split_vector<F: FieldElement, const SHARES: usize>(inp: &[F]) -> [Vec<F>; SHARES] { + let mut outp = Vec::with_capacity(SHARES); + outp.push(inp.to_vec()); + + for _ in 1..SHARES { + let share: Vec<F> = + random_vector(inp.len()).expect("failed to generate a random vector"); + for (x, y) in outp[0].iter_mut().zip(&share) { + *x -= *y; + } + outp.push(share); + } + + outp.try_into().unwrap() + } +} + #[cfg(test)] mod tests { use super::*; @@ -825,7 +1040,6 @@ mod tests { } impl<F: FftFriendlyFieldElement> Type for TestType<F> { - const ID: u32 = 0xFFFF0000; type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; @@ -960,7 +1174,6 @@ mod tests { } impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> { - const ID: u32 = 0xFFFF0000; type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; |