diff options
Diffstat (limited to 'third_party/rust/prio/src/flp')
-rw-r--r-- | third_party/rust/prio/src/flp/gadgets.rs | 715 | ||||
-rw-r--r-- | third_party/rust/prio/src/flp/types.rs | 1199 |
2 files changed, 1914 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs new file mode 100644 index 0000000000..fd2be84eaa --- /dev/null +++ b/third_party/rust/prio/src/flp/gadgets.rs @@ -0,0 +1,715 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of gadgets. + +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; +use crate::field::FieldElement; +use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; +use crate::polynomial::{poly_deg, poly_eval, poly_mul}; + +#[cfg(feature = "multithreaded")] +use rayon::prelude::*; + +use std::any::Any; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::marker::PhantomData; + +/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for +/// polynomial multiplication. Otherwise, the gadget uses direct multiplication. +const FFT_THRESHOLD: usize = 60; + +/// An arity-2 gadget that multiples its inputs. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Mul<F: FieldElement> { + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> Mul<F> { + /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be + /// called by the validity circuit. + pub fn new(num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(2, num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + n, + n_inv, + num_calls, + } + } + + // Multiply input polynomials directly. + pub(crate) fn call_poly_direct( + &mut self, + outp: &mut [F], + inp: &[Vec<F>], + ) -> Result<(), FlpError> { + let v = poly_mul(&inp[0], &inp[1]); + outp[..v.len()].clone_from_slice(&v); + Ok(()) + } + + // Multiply input polynomials using FFT. + pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let mut buf = vec![F::zero(); n]; + + discrete_fourier_transform(&mut buf, &inp[0], n)?; + discrete_fourier_transform(outp, &inp[1], n)?; + + for i in 0..n { + buf[i] *= outp[i]; + } + + discrete_fourier_transform(outp, &buf, n)?; + discrete_fourier_transform_inv_finish(outp, n, self.n_inv); + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for Mul<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(inp[0] * inp[1]) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 2 + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// An arity-1 gadget that evaluates its input on some polynomial. +// +// TODO Make `poly` an array of length determined by a const generic. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PolyEval<F: FieldElement> { + poly: Vec<F>, + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> PolyEval<F> { + /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times + /// this gadget is called by the validity circuit. + pub fn new(poly: Vec<F>, num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + poly, + n, + n_inv, + num_calls, + } + } +} + +impl<F: FieldElement> PolyEval<F> { + // Multiply input polynomials directly. + fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + outp[0] = self.poly[0]; + let mut x = inp[0].to_vec(); + for i in 1..self.poly.len() { + for j in 0..x.len() { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + x = poly_mul(&x, &inp[0]); + } + } + Ok(()) + } + + // Multiply input polynomials using FFT. + fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let inp = &inp[0]; + + let mut inp_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut inp_vals, inp, n)?; + + let mut x_vals = inp_vals.clone(); + let mut x = vec![F::zero(); n]; + x[..inp.len()].clone_from_slice(inp); + + outp[0] = self.poly[0]; + for i in 1..self.poly.len() { + for j in 0..n { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + for j in 0..n { + x_vals[j] *= inp_vals[j]; + } + + discrete_fourier_transform(&mut x, &x_vals, n)?; + discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv); + } + } + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for PolyEval<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(poly_eval(&self.poly, inp[0])) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for item in outp.iter_mut() { + *item = F::zero(); + } + + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + poly_deg(&self.poly) + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct BlindPolyEval<F: FieldElement> { + poly: Vec<F>, + /// Size of buffer for the outer FFT multiplication. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> BlindPolyEval<F> { + /// Returns a `BlindPolyEval` gadget for polynomial `poly`. + pub fn new(poly: Vec<F>, num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + poly, + n, + n_inv, + num_calls, + } + } + + fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let x = &inp[0]; + let y = &inp[1]; + + let mut z = y.to_vec(); + for i in 0..self.poly.len() { + for j in 0..z.len() { + outp[j] += self.poly[i] * z[j]; + } + + if i < self.poly.len() - 1 { + z = poly_mul(&z, x); + } + } + Ok(()) + } + + fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let x = &inp[0]; + let y = &inp[1]; + + let mut x_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut x_vals, x, n)?; + + let mut z_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut z_vals, y, n)?; + + let mut z = vec![F::zero(); n]; + let mut z_len = y.len(); + z[..y.len()].clone_from_slice(y); + + for i in 0..self.poly.len() { + for j in 0..z_len { + outp[j] += self.poly[i] * z[j]; + } + + if i < self.poly.len() - 1 { + for j in 0..n { + z_vals[j] *= x_vals[j]; + } + + discrete_fourier_transform(&mut z, &z_vals, n)?; + discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv); + z_len += x.len(); + } + } + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(inp[1] * poly_eval(&self.poly, inp[0])) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for x in outp.iter_mut() { + *x = F::zero(); + } + + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + poly_deg(&self.poly) + 1 + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// Marker trait for abstracting over [`ParallelSum`]. +pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug { + /// Wraps `inner` into a sum gadget with `chunks` chunks + fn new(inner: G, chunks: usize) -> Self; +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSum<F: FieldElement, G: Gadget<F>> { + inner: G, + chunks: usize, + phantom: PhantomData<F>, +} + +impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> { + fn new(inner: G, chunks: usize) -> Self { + Self { + inner, + chunks, + phantom: PhantomData, + } + } +} + +impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + let mut outp = F::zero(); + for chunk in inp.chunks(self.inner.arity()) { + outp += self.inner.call(chunk)?; + } + Ok(outp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for x in outp.iter_mut() { + *x = F::zero(); + } + + let mut partial_outp = vec![F::zero(); outp.len()]; + + for chunk in inp.chunks(self.inner.arity()) { + self.inner.call_poly(&mut partial_outp, chunk)?; + for i in 0..outp.len() { + outp[i] += partial_outp[i] + } + } + + Ok(()) + } + + fn arity(&self) -> usize { + self.chunks * self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum +/// evaluation is multithreaded. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> { + serial_sum: ParallelSum<F, G>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G> +where + F: FieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn new(inner: G, chunks: usize) -> Self { + Self { + serial_sum: ParallelSum::new(inner, chunks), + } + } +} + +/// Data structures passed between fold operations in [`ParallelSumMultithreaded`]. +#[cfg(feature = "multithreaded")] +struct ParallelSumFoldState<F, G> { + /// Inner gadget. + inner: G, + /// Output buffer for `call_poly()`. + partial_output: Vec<F>, + /// Sum accumulator. + partial_sum: Vec<F>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumFoldState<F, G> { + fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G> + where + G: Clone, + F: FieldElement, + { + ParallelSumFoldState { + inner: gadget.clone(), + partial_output: vec![F::zero(); length], + partial_sum: vec![F::zero(); length], + } + } +} + +#[cfg(feature = "multithreaded")] +impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G> +where + F: FieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + self.serial_sum.call(inp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the + // gadget on each input polynomial, using the first temporary buffer as an output buffer. + // Then accumulate that result into the second temporary buffer, which acts as a running + // sum. Then, discard everything but the partial sums, add them, and finally copy the sum + // to the output parameter. This is equivalent to the single threaded calculation in + // ParallelSum, since we only rearrange additions, and field addition is associative. + let res = inp + .par_chunks(self.serial_sum.inner.arity()) + .fold( + || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()), + |mut state, chunk| { + state + .inner + .call_poly(&mut state.partial_output, chunk) + .unwrap(); + for (sum_elem, output_elem) in state + .partial_sum + .iter_mut() + .zip(state.partial_output.iter()) + { + *sum_elem += *output_elem; + } + state + }, + ) + .map(|state| state.partial_sum) + .reduce( + || vec![F::zero(); outp.len()], + |mut x, y| { + for (xi, yi) in x.iter_mut().zip(y.iter()) { + *xi += *yi; + } + x + }, + ); + + outp.copy_from_slice(&res[..]); + Ok(()) + } + + fn arity(&self) -> usize { + self.serial_sum.arity() + } + + fn degree(&self) -> usize { + self.serial_sum.degree() + } + + fn calls(&self) -> usize { + self.serial_sum.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +// Check that the input parameters of g.call() are well-formed. +fn gadget_call_check<F: FieldElement, G: Gadget<F>>( + gadget: &G, + in_len: usize, +) -> Result<(), FlpError> { + if in_len != gadget.arity() { + return Err(FlpError::Gadget(format!( + "unexpected number of inputs: got {}; want {}", + in_len, + gadget.arity() + ))); + } + + if in_len == 0 { + return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string())); + } + + Ok(()) +} + +// Check that the input parameters of g.call_poly() are well-formed. +fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>( + gadget: &G, + outp: &[F], + inp: &[Vec<F>], +) -> Result<(), FlpError> +where + G: Gadget<F>, +{ + gadget_call_check(gadget, inp.len())?; + + for i in 1..inp.len() { + if inp[i].len() != inp[0].len() { + return Err(FlpError::Gadget( + "gadget called on wire polynomials with different lengths".to_string(), + )); + } + } + + let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two(); + if outp.len() != expected { + return Err(FlpError::Gadget(format!( + "incorrect output length: got {}; want {}", + outp.len(), + expected + ))); + } + + Ok(()) +} + +#[inline] +fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize { + gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two() +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::field::{random_vector, Field96 as TestField}; + use crate::prng::Prng; + + #[test] + fn test_mul() { + // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the + // naive multiplication code path. + let num_calls = FFT_THRESHOLD / 2; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + + // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises + // FFT-based polynomial multiplication. + let num_calls = FFT_THRESHOLD; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_poly_eval() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + + let num_calls = FFT_THRESHOLD / 2; + let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls); + gadget_test(&mut g, num_calls); + + let num_calls = FFT_THRESHOLD; + let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_blind_poly_eval() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + + let num_calls = FFT_THRESHOLD / 2; + let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls); + gadget_test(&mut g, num_calls); + + let num_calls = FFT_THRESHOLD; + let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_parallel_sum() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + let num_calls = 10; + let chunks = 23; + + let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks); + gadget_test(&mut g, num_calls); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_parallel_sum_multithreaded() { + use std::iter; + + for num_calls in [1, 10, 100] { + let poly: Vec<TestField> = random_vector(10).unwrap(); + let chunks = 23; + + let mut g = + ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks); + gadget_test(&mut g, num_calls); + + // Test that the multithreaded version has the same output as the normal version. + let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks); + assert_eq!(g.arity(), g_serial.arity()); + assert_eq!(g.degree(), g_serial.degree()); + assert_eq!(g.calls(), g_serial.calls()); + + let arity = g.arity(); + let degree = g.degree(); + + // Test that both gadgets evaluate to the same value when run on scalar inputs. + let inp: Vec<TestField> = random_vector(arity).unwrap(); + let result = g.call(&inp).unwrap(); + let result_serial = g_serial.call(&inp).unwrap(); + assert_eq!(result, result_serial); + + // Test that both gadgets evaluate to the same value when run on polynomial inputs. + let mut poly_outp = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut poly_outp_serial = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut prng: Prng<TestField, _> = Prng::new().unwrap(); + let poly_inp: Vec<_> = iter::repeat_with(|| { + iter::repeat_with(|| prng.get()) + .take(1 + num_calls) + .collect::<Vec<_>>() + }) + .take(arity) + .collect(); + + g.call_poly(&mut poly_outp, &poly_inp).unwrap(); + g_serial + .call_poly(&mut poly_outp_serial, &poly_inp) + .unwrap(); + assert_eq!(poly_outp, poly_outp_serial); + } + } + + // Test that calling g.call_poly() and evaluating the output at a given point is equivalent + // to evaluating each of the inputs at the same point and applying g.call() on the results. + fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) { + let wire_poly_len = (1 + num_calls).next_power_of_two(); + let mut prng = Prng::new().unwrap(); + let mut inp = vec![F::zero(); g.arity()]; + let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)]; + let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()]; + + let r = prng.get(); + for i in 0..g.arity() { + for j in 0..wire_poly_len { + wire_polys[i][j] = prng.get(); + } + inp[i] = poly_eval(&wire_polys[i], r); + } + + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + let want = g.call(&inp).unwrap(); + assert_eq!(got, want); + + // Repeat the call to make sure that the gadget's memory is reset properly between calls. + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + assert_eq!(got, want); + } +} diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs new file mode 100644 index 0000000000..83b9752f69 --- /dev/null +++ b/third_party/rust/prio/src/flp/types.rs @@ -0,0 +1,1199 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of [`Type`](crate::flp::Type) implementations. + +use crate::field::{FieldElement, FieldElementExt}; +use crate::flp::gadgets::{BlindPolyEval, Mul, ParallelSumGadget, PolyEval}; +use crate::flp::{FlpError, Gadget, Type}; +use crate::polynomial::poly_range_check; +use std::convert::TryInto; +use std::marker::PhantomData; + +/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of +/// the measurements (i.e., the total number of `1s`). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Count<F> { + range_checker: Vec<F>, +} + +impl<F: FieldElement> Count<F> { + /// Return a new [`Count`] type instance. + pub fn new() -> Self { + Self { + range_checker: poly_range_check(0, 2), + } + } +} + +impl<F: FieldElement> Default for Count<F> { + fn default() -> Self { + Self::new() + } +} + +impl<F: FieldElement> Type for Count<F> { + const ID: u32 = 0x00000000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> { + let max = F::valid_integer_try_from(1)?; + if *value > max { + return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); + } + + Ok(vec![F::from(*value)]) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(Mul::new(1))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + Ok(g[0].call(&[input[0], input[0]])? - input[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + 1 + } + + fn proof_len(&self) -> usize { + 5 + } + + fn verifier_len(&self) -> usize { + 4 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 0 + } + + fn prove_rand_len(&self) -> usize { + 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of +/// the measurements. +/// +/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3]. +/// +/// [BBCG+19]: https://ia.cr/2019/188 +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Sum<F: FieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Sum<F> { + /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Sum<F> { + const ID: u32 = 0x00000001; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the arithmetic average. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Average<F: FieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Average<F> { + /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Average<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = f64; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> { + // Compute the average from the aggregated sum. + let data = decode_result(data)?; + let data: u64 = data.try_into().map_err(|err| { + FlpError::Decode(format!("failed to convert {:?} to u64: {}", data, err,)) + })?; + let result = (data as f64) / (num_measurements as f64); + Ok(result) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The histogram type. Each measurement is a non-negative integer and the aggregate is a histogram +/// approximating the distribution of the measurements. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Histogram<F: FieldElement> { + buckets: Vec<F::Integer>, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Histogram<F> { + /// Return a new [`Histogram`] type with the given buckets. + pub fn new(buckets: Vec<F::Integer>) -> Result<Self, FlpError> { + if buckets.len() >= u32::MAX as usize { + return Err(FlpError::Encode( + "invalid buckets: number of buckets exceeds maximum permitted".to_string(), + )); + } + + if !buckets.is_empty() { + for i in 0..buckets.len() - 1 { + if buckets[i + 1] <= buckets[i] { + return Err(FlpError::Encode( + "invalid buckets: out-of-order boundary".to_string(), + )); + } + } + } + + Ok(Self { + buckets, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Histogram<F> { + const ID: u32 = 0x00000002; + type Measurement = F::Integer; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + let mut data = vec![F::zero(); self.buckets.len() + 1]; + + let bucket = match self.buckets.binary_search(measurement) { + Ok(i) => i, // on a bucket boundary + Err(i) => i, // smaller than the i-th bucket boundary + }; + + data[bucket] = F::one(); + Ok(data) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.buckets.len() + 1) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.to_vec(), + self.input_len(), + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + // Check that each element of `input` is a 0 or 1. + let range_check = call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])?; + + // Check that the elements of `input` sum to 1. + let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?)); + for val in input.iter() { + sum_check += *val; + } + + // Take a random linear combination of both checks. + let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check; + Ok(out) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + self.buckets.len() + 1 + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.input_len()).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 2 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// A sequence of counters. This type uses a neat trick from [[BBCG+19], Corollary 4.9] to reduce +/// the proof size to roughly the square root of the input size. +/// +/// [BBCG+19]: https://eprint.iacr.org/2019/188 +#[derive(Debug, PartialEq, Eq)] +pub struct CountVec<F, S> { + range_checker: Vec<F>, + len: usize, + chunk_len: usize, + gadget_calls: usize, + phantom: PhantomData<S>, +} + +impl<F: FieldElement, S: ParallelSumGadget<F, BlindPolyEval<F>>> CountVec<F, S> { + /// Returns a new [`CountVec`] with the given length. + pub fn new(len: usize) -> Self { + // The optimal chunk length is the square root of the input length. If the input length is + // not a perfect square, then round down. If the result is 0, then let the chunk length be + // 1 so that the underlying gadget can still be called. + let chunk_len = std::cmp::max(1, (len as f64).sqrt() as usize); + + let mut gadget_calls = len / chunk_len; + if len % chunk_len != 0 { + gadget_calls += 1; + } + + Self { + range_checker: poly_range_check(0, 2), + len, + chunk_len, + gadget_calls, + phantom: PhantomData, + } + } +} + +impl<F: FieldElement, S> Clone for CountVec<F, S> { + fn clone(&self) -> Self { + Self { + range_checker: self.range_checker.clone(), + len: self.len, + chunk_len: self.chunk_len, + gadget_calls: self.gadget_calls, + phantom: PhantomData, + } + } +} + +impl<F, S> Type for CountVec<F, S> +where + F: FieldElement, + S: ParallelSumGadget<F, BlindPolyEval<F>> + Eq + 'static, +{ + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<F::Integer>; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> { + if measurement.len() != self.len { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.len + ))); + } + + let max = F::Integer::from(F::one()); + for value in measurement { + if *value > max { + return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); + } + } + + Ok(measurement.iter().map(|value| F::from(*value)).collect()) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.len) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(S::new( + BlindPolyEval::new(self.range_checker.clone(), self.gadget_calls), + self.chunk_len, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + let s = F::from(F::valid_integer_try_from(num_shares)?).inv(); + let mut r = joint_rand[0]; + let mut outp = F::zero(); + let mut padded_chunk = vec![F::zero(); 2 * self.chunk_len]; + for chunk in input.chunks(self.chunk_len) { + let d = chunk.len(); + for i in 0..self.chunk_len { + if i < d { + padded_chunk[2 * i] = chunk[i]; + } else { + // If the chunk is smaller than the chunk length, then copy the last element of + // the chunk into the remaining slots. + padded_chunk[2 * i] = chunk[d - 1]; + } + padded_chunk[2 * i + 1] = r * s; + r *= joint_rand[0]; + } + + outp += g[0].call(&padded_chunk)?; + } + + Ok(outp) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + self.len + } + + fn proof_len(&self) -> usize { + (self.chunk_len * 2) + 3 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1 + } + + fn verifier_len(&self) -> usize { + 2 + self.chunk_len * 2 + } + + fn output_len(&self) -> usize { + self.len + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + self.chunk_len * 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// Compute a random linear combination of the result of calls of `g` on each element of `input`. +/// +/// # Arguments +/// +/// * `g` - The gadget to be applied elementwise +/// * `input` - The vector on whose elements to apply `g` +/// * `rnd` - The randomness used for the linear combination +pub(crate) fn call_gadget_on_vec_entries<F: FieldElement>( + g: &mut Box<dyn Gadget<F>>, + input: &[F], + rnd: F, +) -> Result<F, FlpError> { + let mut range_check = F::zero(); + let mut r = rnd; + for chunk in input.chunks(1) { + range_check += r * g.call(chunk)?; + r *= rnd; + } + Ok(range_check) +} + +/// Given a vector `data` of field elements which should contain exactly one entry, return the +/// integer representation of that entry. +pub(crate) fn decode_result<F: FieldElement>(data: &[F]) -> Result<F::Integer, FlpError> { + if data.len() != 1 { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(F::Integer::from(data[0])) +} + +/// Given a vector `data` of field elements, return a vector containing the corresponding integer +/// representations, if the number of entries matches `expected_len`. +pub(crate) fn decode_result_vec<F: FieldElement>( + data: &[F], + expected_len: usize, +) -> Result<Vec<F::Integer>, FlpError> { + if data.len() != expected_len { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, split_vector, Field64 as TestField}; + use crate::flp::gadgets::ParallelSum; + #[cfg(feature = "multithreaded")] + use crate::flp::gadgets::ParallelSumMultithreaded; + + // Number of shares to split input and proofs into in `flp_test`. + const NUM_SHARES: usize = 3; + + struct ValidityTestCase<F> { + expect_valid: bool, + expected_output: Option<Vec<F>>, + } + + #[test] + fn test_count() { + let count: Count<TestField> = Count::new(); + let zero = TestField::zero(); + let one = TestField::one(); + + // Round trip + assert_eq!( + count + .decode_result( + &count + .truncate(count.encode_measurement(&1).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 1, + ); + + // Test FLP on valid input. + flp_validity_test( + &count, + &count.encode_measurement(&1).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one]), + }, + ) + .unwrap(); + + flp_validity_test( + &count, + &count.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &count, + &[TestField::from(1337)], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + // Try running the validity circuit on an input that's too short. + count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); + count + .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1) + .unwrap_err(); + } + + #[test] + fn test_sum() { + let sum = Sum::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + // Round trip + assert_eq!( + sum.decode_result( + &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(), + 1 + ) + .unwrap(), + 27, + ); + + // Test FLP on valid input. + flp_validity_test( + &sum, + &sum.encode_measurement(&1337).unwrap(), + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![TestField::from(1337)]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(0).unwrap(), + &[], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(2).unwrap(), + &[one, zero], + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![one]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(9).unwrap(), + &[one, zero, one, one, zero, one, one, one, zero], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![TestField::from(237)]), + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &Sum::new(3).unwrap(), + &[one, nine, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(5).unwrap(), + &[zero, zero, zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + #[test] + fn test_average() { + let average = Average::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let ten = TestField::from(10); + + // Testing that average correctly quotients the sum of the measurements + // by the number of measurements. + assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0); + assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0); + assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5); + assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25); + assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25); + + // round trip of 12 with `num_measurements`=1 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 12.0 + ); + + // round trip of 12 with `num_measurements`=24 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 24 + ) + .unwrap(), + 0.5 + ); + } + + #[test] + fn test_histogram() { + let hist = Histogram::new(vec![10, 20]).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + assert_eq!(&hist.encode_measurement(&7).unwrap(), &[one, zero, zero]); + assert_eq!(&hist.encode_measurement(&10).unwrap(), &[one, zero, zero]); + assert_eq!(&hist.encode_measurement(&17).unwrap(), &[zero, one, zero]); + assert_eq!(&hist.encode_measurement(&20).unwrap(), &[zero, one, zero]); + assert_eq!(&hist.encode_measurement(&27).unwrap(), &[zero, zero, one]); + + // Round trip + assert_eq!( + hist.decode_result( + &hist + .truncate(hist.encode_measurement(&27).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + [0, 0, 1] + ); + + // Invalid bucket boundaries. + Histogram::<TestField>::new(vec![10, 0]).unwrap_err(); + Histogram::<TestField>::new(vec![10, 10]).unwrap_err(); + + // Test valid inputs. + flp_validity_test( + &hist, + &hist.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one, zero, zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&17).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, one, zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&1337).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, zero, one]), + }, + ) + .unwrap(); + + // Test invalid inputs. + flp_validity_test( + &hist, + &[zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[one, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, zero, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + fn test_count_vec<F, S>(f: F) + where + F: Fn(usize) -> CountVec<TestField, S>, + S: 'static + ParallelSumGadget<TestField, BlindPolyEval<TestField>> + Eq, + { + let one = TestField::one(); + let nine = TestField::from(9); + + // Test on valid inputs. + for len in 0..10 { + let count_vec = f(len); + flp_validity_test( + &count_vec, + &count_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + }, + ) + .unwrap(); + } + + let len = 100; + let count_vec = f(len); + flp_validity_test( + &count_vec, + &count_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + }, + ) + .unwrap(); + + // Test on invalid inputs. + for len in 1..10 { + let count_vec = f(len); + flp_validity_test( + &count_vec, + &vec![nine; len], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + // Round trip + let want = vec![1; len]; + assert_eq!( + count_vec + .decode_result( + &count_vec + .truncate(count_vec.encode_measurement(&want).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + want + ); + } + + #[test] + fn test_count_vec_serial() { + test_count_vec(CountVec::<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>>::new) + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_count_vec_parallel() { + test_count_vec(CountVec::<TestField, ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>>::new) + } + + #[test] + fn count_vec_serial_long() { + let typ: CountVec<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>> = + CountVec::new(1000); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn count_vec_parallel_long() { + let typ: CountVec< + TestField, + ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>, + > = CountVec::new(1000); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + fn flp_validity_test<T: Type>( + typ: &T, + input: &[T::Field], + t: &ValidityTestCase<T::Field>, + ) -> Result<(), FlpError> { + let mut gadgets = typ.gadget(); + + if input.len() != typ.input_len() { + return Err(FlpError::Test(format!( + "unexpected input length: got {}; want {}", + input.len(), + typ.input_len() + ))); + } + + if typ.query_rand_len() != gadgets.len() { + return Err(FlpError::Test(format!( + "query rand length: got {}; want {}", + typ.query_rand_len(), + gadgets.len() + ))); + } + + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + + // Run the validity circuit. + let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; + if v != T::Field::zero() && t.expect_valid { + return Err(FlpError::Test(format!( + "expected valid input: valid() returned {}", + v + ))); + } + if v == T::Field::zero() && !t.expect_valid { + return Err(FlpError::Test(format!( + "expected invalid input: valid() returned {}", + v + ))); + } + + // Generate the proof. + let proof = typ.prove(input, &prove_rand, &joint_rand)?; + if proof.len() != typ.proof_len() { + return Err(FlpError::Test(format!( + "unexpected proof length: got {}; want {}", + proof.len(), + typ.proof_len() + ))); + } + + // Query the proof. + let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; + if verifier.len() != typ.verifier_len() { + return Err(FlpError::Test(format!( + "unexpected verifier length: got {}; want {}", + verifier.len(), + typ.verifier_len() + ))); + } + + // Decide if the input is valid. + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Run distributed FLP. + let input_shares: Vec<Vec<T::Field>> = split_vector(input, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let verifier: Vec<T::Field> = (0..NUM_SHARES) + .map(|i| { + typ.query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + NUM_SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "distributed decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Try verifying various proof mutants. + for i in 0..proof.len() { + let mut mutated_proof = proof.clone(); + mutated_proof[i] += T::Field::one(); + let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; + if typ.decide(&verifier)? { + return Err(FlpError::Test(format!( + "decision for proof mutant {} is {}; want {}", + i, true, false, + ))); + } + } + + // Try verifying a proof that is too short. + let mut mutated_proof = proof.clone(); + mutated_proof.truncate(gadgets[0].arity() - 1); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on short proof succeeded; want failure".to_string(), + )); + } + + // Try verifying a proof that is too long. + let mut mutated_proof = proof; + mutated_proof.extend_from_slice(&[T::Field::one(); 17]); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on long proof succeeded; want failure".to_string(), + )); + } + + if let Some(ref want) = t.expected_output { + let got = typ.truncate(input.to_vec())?; + + if got.len() != typ.output_len() { + return Err(FlpError::Test(format!( + "unexpected output length: got {}; want {}", + got.len(), + typ.output_len() + ))); + } + + if &got != want { + return Err(FlpError::Test(format!( + "unexpected output: got {:?}; want {:?}", + got, want + ))); + } + } + + Ok(()) + } +} |