summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/flp.rs')
-rw-r--r--third_party/rust/prio/src/flp.rs241
1 files changed, 227 insertions, 14 deletions
diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs
index 1912ebab14..5fd956155a 100644
--- a/third_party/rust/prio/src/flp.rs
+++ b/third_party/rust/prio/src/flp.rs
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: MPL-2.0
//! Implementation of the generic Fully Linear Proof (FLP) system specified in
-//! [[draft-irtf-cfrg-vdaf-07]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
+//! [[draft-irtf-cfrg-vdaf-08]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
//!
//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation
//! specifies a validity circuit that defines the set of valid measurements, as well as the finite
@@ -24,7 +24,7 @@
//!
//! // The prover chooses a measurement.
//! let count = Count::new();
-//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap();
+//! let input: Vec<Field64> = count.encode_measurement(&false).unwrap();
//!
//! // The prover and verifier agree on "joint randomness" used to generate and
//! // check the proof. The application needs to ensure that the prover
@@ -44,7 +44,7 @@
//! assert!(count.decide(&verifier).unwrap());
//! ```
//!
-//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/
#[cfg(feature = "experimental")]
use crate::dp::DifferentialPrivacyStrategy;
@@ -61,6 +61,7 @@ pub mod types;
/// Errors propagated by methods in this module.
#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
pub enum FlpError {
/// Calling [`Type::prove`] returned an error.
#[error("prove error: {0}")]
@@ -110,20 +111,12 @@ pub enum FlpError {
/// An error happened during noising.
#[error("differential privacy error: {0}")]
DifferentialPrivacy(#[from] crate::dp::DpError),
-
- /// Unit test error.
- #[cfg(test)]
- #[error("test failed: {0}")]
- Test(String),
}
/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
pub trait Type: Sized + Eq + Clone + Debug {
- /// The Prio3 VDAF identifier corresponding to this type.
- const ID: u32;
-
/// The type of raw measurement to be encoded.
type Measurement: Clone + Debug;
@@ -178,7 +171,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// use prio::field::{random_vector, FieldElement, Field64};
///
/// let count = Count::new();
- /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap();
+ /// let input: Vec<Field64> = count.encode_measurement(&true).unwrap();
/// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
/// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
/// assert_eq!(v, Field64::zero());
@@ -552,6 +545,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// A type which supports adding noise to aggregate shares for Server Differential Privacy.
#[cfg(feature = "experimental")]
+#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub trait TypeWithNoise<S>: Type
where
S: DifferentialPrivacyStrategy,
@@ -754,6 +748,227 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi
gadget_degree * (wire_poly_len - 1) + 1
}
+/// Utilities for testing FLPs.
+#[cfg(feature = "test-util")]
+#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+pub mod test_utils {
+ use super::*;
+ use crate::field::{random_vector, FieldElement, FieldElementWithInteger};
+
+ /// Various tests for an FLP.
+ #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+ pub struct FlpTest<'a, T: Type> {
+ /// The FLP.
+ pub flp: &'a T,
+
+ /// Optional test name.
+ pub name: Option<&'a str>,
+
+ /// The input to use for the tests.
+ pub input: &'a [T::Field],
+
+ /// If set, the expected result of truncating the input.
+ pub expected_output: Option<&'a [T::Field]>,
+
+ /// Whether the input is expected to be valid.
+ pub expect_valid: bool,
+ }
+
+ impl<T: Type> FlpTest<'_, T> {
+ /// Construct a test and run it. Expect the input to be valid and compare the truncated
+ /// output to the provided value.
+ pub fn expect_valid<const SHARES: usize>(
+ flp: &T,
+ input: &[T::Field],
+ expected_output: &[T::Field],
+ ) {
+ FlpTest {
+ flp,
+ name: None,
+ input,
+ expected_output: Some(expected_output),
+ expect_valid: true,
+ }
+ .run::<SHARES>()
+ }
+
+ /// Construct a test and run it. Expect the input to be invalid.
+ pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) {
+ FlpTest {
+ flp,
+ name: None,
+ input,
+ expect_valid: false,
+ expected_output: None,
+ }
+ .run::<SHARES>()
+ }
+
+ /// Construct a test and run it. Expect the input to be valid.
+ pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) {
+ FlpTest {
+ flp,
+ name: None,
+ input,
+ expect_valid: true,
+ expected_output: None,
+ }
+ .run::<SHARES>()
+ }
+
+ /// Run the tests.
+ pub fn run<const SHARES: usize>(&self) {
+ let name = self.name.unwrap_or("unnamed test");
+
+ assert_eq!(
+ self.input.len(),
+ self.flp.input_len(),
+ "{name}: unexpected input length"
+ );
+
+ let mut gadgets = self.flp.gadget();
+ let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap();
+ let query_rand = random_vector(self.flp.query_rand_len()).unwrap();
+ assert_eq!(
+ self.flp.query_rand_len(),
+ gadgets.len(),
+ "{name}: unexpected number of gadgets"
+ );
+ assert_eq!(
+ self.flp.joint_rand_len(),
+ joint_rand.len(),
+ "{name}: unexpected joint rand length"
+ );
+ assert_eq!(
+ self.flp.prove_rand_len(),
+ prove_rand.len(),
+ "{name}: unexpected prove rand length",
+ );
+ assert_eq!(
+ self.flp.query_rand_len(),
+ query_rand.len(),
+ "{name}: unexpected query rand length",
+ );
+
+ // Run the validity circuit.
+ let v = self
+ .flp
+ .valid(&mut gadgets, self.input, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(
+ v == T::Field::zero(),
+ self.expect_valid,
+ "{name}: unexpected output of valid() returned {v}",
+ );
+
+ // Generate the proof.
+ let proof = self
+ .flp
+ .prove(self.input, &prove_rand, &joint_rand)
+ .unwrap();
+ assert_eq!(
+ proof.len(),
+ self.flp.proof_len(),
+ "{name}: unexpected proof length"
+ );
+
+ // Query the proof.
+ let verifier = self
+ .flp
+ .query(self.input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(
+ verifier.len(),
+ self.flp.verifier_len(),
+ "{name}: unexpected verifier length"
+ );
+
+ // Decide if the input is valid.
+ let res = self.flp.decide(&verifier).unwrap();
+ assert_eq!(res, self.expect_valid, "{name}: unexpected decision");
+
+ // Run distributed FLP.
+ let input_shares = split_vector::<_, SHARES>(self.input);
+ let proof_shares = split_vector::<_, SHARES>(&proof);
+ let verifier: Vec<T::Field> = (0..SHARES)
+ .map(|i| {
+ self.flp
+ .query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ SHARES,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+
+ let res = self.flp.decide(&verifier).unwrap();
+ assert_eq!(
+ res, self.expect_valid,
+ "{name}: unexpected distributed decision"
+ );
+
+ // Try verifying various proof mutants.
+ for i in 0..std::cmp::min(proof.len(), 10) {
+ let mut mutated_proof = proof.clone();
+ mutated_proof[i] *= T::Field::from(
+ <T::Field as FieldElementWithInteger>::Integer::try_from(23).unwrap(),
+ );
+ let verifier = self
+ .flp
+ .query(self.input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert!(
+ !self.flp.decide(&verifier).unwrap(),
+ "{name}: proof mutant {} deemed valid",
+ i
+ );
+ }
+
+ // Try truncating the input.
+ if let Some(ref expected_output) = self.expected_output {
+ let output = self.flp.truncate(self.input.to_vec()).unwrap();
+
+ assert_eq!(
+ output.len(),
+ self.flp.output_len(),
+ "{name}: unexpected output length of truncate()"
+ );
+
+ assert_eq!(
+ &output, expected_output,
+ "{name}: unexpected output of truncate()"
+ );
+ }
+ }
+ }
+
+ fn split_vector<F: FieldElement, const SHARES: usize>(inp: &[F]) -> [Vec<F>; SHARES] {
+ let mut outp = Vec::with_capacity(SHARES);
+ outp.push(inp.to_vec());
+
+ for _ in 1..SHARES {
+ let share: Vec<F> =
+ random_vector(inp.len()).expect("failed to generate a random vector");
+ for (x, y) in outp[0].iter_mut().zip(&share) {
+ *x -= *y;
+ }
+ outp.push(share);
+ }
+
+ outp.try_into().unwrap()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -825,7 +1040,6 @@ mod tests {
}
impl<F: FftFriendlyFieldElement> Type for TestType<F> {
- const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
@@ -960,7 +1174,6 @@ mod tests {
}
impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
- const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;