summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/flp')
-rw-r--r--third_party/rust/prio/src/flp/gadgets.rs715
-rw-r--r--third_party/rust/prio/src/flp/types.rs1199
2 files changed, 1914 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs
new file mode 100644
index 0000000000..fd2be84eaa
--- /dev/null
+++ b/third_party/rust/prio/src/flp/gadgets.rs
@@ -0,0 +1,715 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of gadgets.
+
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
+use crate::field::FieldElement;
+use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
+use crate::polynomial::{poly_deg, poly_eval, poly_mul};
+
+#[cfg(feature = "multithreaded")]
+use rayon::prelude::*;
+
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
+/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
+const FFT_THRESHOLD: usize = 60;
+
+/// An arity-2 gadget that multiples its inputs.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Mul<F: FieldElement> {
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> Mul<F> {
+ /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
+ /// called by the validity circuit.
+ pub fn new(num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(2, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ // Multiply input polynomials directly.
+ pub(crate) fn call_poly_direct(
+ &mut self,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+ ) -> Result<(), FlpError> {
+ let v = poly_mul(&inp[0], &inp[1]);
+ outp[..v.len()].clone_from_slice(&v);
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let mut buf = vec![F::zero(); n];
+
+ discrete_fourier_transform(&mut buf, &inp[0], n)?;
+ discrete_fourier_transform(outp, &inp[1], n)?;
+
+ for i in 0..n {
+ buf[i] *= outp[i];
+ }
+
+ discrete_fourier_transform(outp, &buf, n)?;
+ discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for Mul<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[0] * inp[1])
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ 2
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-1 gadget that evaluates its input on some polynomial.
+//
+// TODO Make `poly` an array of length determined by a const generic.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct PolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
+ /// this gadget is called by the validity circuit.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ // Multiply input polynomials directly.
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ outp[0] = self.poly[0];
+ let mut x = inp[0].to_vec();
+ for i in 1..self.poly.len() {
+ for j in 0..x.len() {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ x = poly_mul(&x, &inp[0]);
+ }
+ }
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let inp = &inp[0];
+
+ let mut inp_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut inp_vals, inp, n)?;
+
+ let mut x_vals = inp_vals.clone();
+ let mut x = vec![F::zero(); n];
+ x[..inp.len()].clone_from_slice(inp);
+
+ outp[0] = self.poly[0];
+ for i in 1..self.poly.len() {
+ for j in 0..n {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ x_vals[j] *= inp_vals[j];
+ }
+
+ discrete_fourier_transform(&mut x, &x_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for PolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for item in outp.iter_mut() {
+ *item = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 1
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly)
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct BlindPolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for the outer FFT multiplication.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> BlindPolyEval<F> {
+ /// Returns a `BlindPolyEval` gadget for polynomial `poly`.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut z = y.to_vec();
+ for i in 0..self.poly.len() {
+ for j in 0..z.len() {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ z = poly_mul(&z, x);
+ }
+ }
+ Ok(())
+ }
+
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut x_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut x_vals, x, n)?;
+
+ let mut z_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut z_vals, y, n)?;
+
+ let mut z = vec![F::zero(); n];
+ let mut z_len = y.len();
+ z[..y.len()].clone_from_slice(y);
+
+ for i in 0..self.poly.len() {
+ for j in 0..z_len {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ z_vals[j] *= x_vals[j];
+ }
+
+ discrete_fourier_transform(&mut z, &z_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv);
+ z_len += x.len();
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[1] * poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly) + 1
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Marker trait for abstracting over [`ParallelSum`].
+pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug {
+ /// Wraps `inner` into a sum gadget with `chunks` chunks
+ fn new(inner: G, chunks: usize) -> Self;
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSum<F: FieldElement, G: Gadget<F>> {
+ inner: G,
+ chunks: usize,
+ phantom: PhantomData<F>,
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> {
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ inner,
+ chunks,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ let mut outp = F::zero();
+ for chunk in inp.chunks(self.inner.arity()) {
+ outp += self.inner.call(chunk)?;
+ }
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ let mut partial_outp = vec![F::zero(); outp.len()];
+
+ for chunk in inp.chunks(self.inner.arity()) {
+ self.inner.call_poly(&mut partial_outp, chunk)?;
+ for i in 0..outp.len() {
+ outp[i] += partial_outp[i]
+ }
+ }
+
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.chunks * self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
+/// evaluation is multithreaded.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> {
+ serial_sum: ParallelSum<F, G>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ serial_sum: ParallelSum::new(inner, chunks),
+ }
+ }
+}
+
+/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
+#[cfg(feature = "multithreaded")]
+struct ParallelSumFoldState<F, G> {
+ /// Inner gadget.
+ inner: G,
+ /// Output buffer for `call_poly()`.
+ partial_output: Vec<F>,
+ /// Sum accumulator.
+ partial_sum: Vec<F>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumFoldState<F, G> {
+ fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
+ where
+ G: Clone,
+ F: FieldElement,
+ {
+ ParallelSumFoldState {
+ inner: gadget.clone(),
+ partial_output: vec![F::zero(); length],
+ partial_sum: vec![F::zero(); length],
+ }
+ }
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ self.serial_sum.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
+ // gadget on each input polynomial, using the first temporary buffer as an output buffer.
+ // Then accumulate that result into the second temporary buffer, which acts as a running
+ // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
+ // to the output parameter. This is equivalent to the single threaded calculation in
+ // ParallelSum, since we only rearrange additions, and field addition is associative.
+ let res = inp
+ .par_chunks(self.serial_sum.inner.arity())
+ .fold(
+ || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
+ |mut state, chunk| {
+ state
+ .inner
+ .call_poly(&mut state.partial_output, chunk)
+ .unwrap();
+ for (sum_elem, output_elem) in state
+ .partial_sum
+ .iter_mut()
+ .zip(state.partial_output.iter())
+ {
+ *sum_elem += *output_elem;
+ }
+ state
+ },
+ )
+ .map(|state| state.partial_sum)
+ .reduce(
+ || vec![F::zero(); outp.len()],
+ |mut x, y| {
+ for (xi, yi) in x.iter_mut().zip(y.iter()) {
+ *xi += *yi;
+ }
+ x
+ },
+ );
+
+ outp.copy_from_slice(&res[..]);
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.serial_sum.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.serial_sum.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.serial_sum.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// Check that the input parameters of g.call() are well-formed.
+fn gadget_call_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ in_len: usize,
+) -> Result<(), FlpError> {
+ if in_len != gadget.arity() {
+ return Err(FlpError::Gadget(format!(
+ "unexpected number of inputs: got {}; want {}",
+ in_len,
+ gadget.arity()
+ )));
+ }
+
+ if in_len == 0 {
+ return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
+ }
+
+ Ok(())
+}
+
+// Check that the input parameters of g.call_poly() are well-formed.
+fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ outp: &[F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError>
+where
+ G: Gadget<F>,
+{
+ gadget_call_check(gadget, inp.len())?;
+
+ for i in 1..inp.len() {
+ if inp[i].len() != inp[0].len() {
+ return Err(FlpError::Gadget(
+ "gadget called on wire polynomials with different lengths".to_string(),
+ ));
+ }
+ }
+
+ let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
+ if outp.len() != expected {
+ return Err(FlpError::Gadget(format!(
+ "incorrect output length: got {}; want {}",
+ outp.len(),
+ expected
+ )));
+ }
+
+ Ok(())
+}
+
+#[inline]
+fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
+ gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::field::{random_vector, Field96 as TestField};
+ use crate::prng::Prng;
+
+ #[test]
+ fn test_mul() {
+ // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
+ // naive multiplication code path.
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+
+ // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
+ // FFT-based polynomial multiplication.
+ let num_calls = FFT_THRESHOLD;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_blind_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_parallel_sum() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let num_calls = 10;
+ let chunks = 23;
+
+ let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_parallel_sum_multithreaded() {
+ use std::iter;
+
+ for num_calls in [1, 10, 100] {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let chunks = 23;
+
+ let mut g =
+ ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+
+ // Test that the multithreaded version has the same output as the normal version.
+ let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ assert_eq!(g.arity(), g_serial.arity());
+ assert_eq!(g.degree(), g_serial.degree());
+ assert_eq!(g.calls(), g_serial.calls());
+
+ let arity = g.arity();
+ let degree = g.degree();
+
+ // Test that both gadgets evaluate to the same value when run on scalar inputs.
+ let inp: Vec<TestField> = random_vector(arity).unwrap();
+ let result = g.call(&inp).unwrap();
+ let result_serial = g_serial.call(&inp).unwrap();
+ assert_eq!(result, result_serial);
+
+ // Test that both gadgets evaluate to the same value when run on polynomial inputs.
+ let mut poly_outp =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut poly_outp_serial =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut prng: Prng<TestField, _> = Prng::new().unwrap();
+ let poly_inp: Vec<_> = iter::repeat_with(|| {
+ iter::repeat_with(|| prng.get())
+ .take(1 + num_calls)
+ .collect::<Vec<_>>()
+ })
+ .take(arity)
+ .collect();
+
+ g.call_poly(&mut poly_outp, &poly_inp).unwrap();
+ g_serial
+ .call_poly(&mut poly_outp_serial, &poly_inp)
+ .unwrap();
+ assert_eq!(poly_outp, poly_outp_serial);
+ }
+ }
+
+ // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
+ // to evaluating each of the inputs at the same point and applying g.call() on the results.
+ fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
+ let wire_poly_len = (1 + num_calls).next_power_of_two();
+ let mut prng = Prng::new().unwrap();
+ let mut inp = vec![F::zero(); g.arity()];
+ let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
+ let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];
+
+ let r = prng.get();
+ for i in 0..g.arity() {
+ for j in 0..wire_poly_len {
+ wire_polys[i][j] = prng.get();
+ }
+ inp[i] = poly_eval(&wire_polys[i], r);
+ }
+
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ let want = g.call(&inp).unwrap();
+ assert_eq!(got, want);
+
+ // Repeat the call to make sure that the gadget's memory is reset properly between calls.
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ assert_eq!(got, want);
+ }
+}
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs
new file mode 100644
index 0000000000..83b9752f69
--- /dev/null
+++ b/third_party/rust/prio/src/flp/types.rs
@@ -0,0 +1,1199 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of [`Type`](crate::flp::Type) implementations.
+
+use crate::field::{FieldElement, FieldElementExt};
+use crate::flp::gadgets::{BlindPolyEval, Mul, ParallelSumGadget, PolyEval};
+use crate::flp::{FlpError, Gadget, Type};
+use crate::polynomial::poly_range_check;
+use std::convert::TryInto;
+use std::marker::PhantomData;
+
+/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of
+/// the measurements (i.e., the total number of `1s`).
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Count<F> {
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Count<F> {
+ /// Return a new [`Count`] type instance.
+ pub fn new() -> Self {
+ Self {
+ range_checker: poly_range_check(0, 2),
+ }
+ }
+}
+
+impl<F: FieldElement> Default for Count<F> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<F: FieldElement> Type for Count<F> {
+ const ID: u32 = 0x00000000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let max = F::valid_integer_try_from(1)?;
+ if *value > max {
+ return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
+ }
+
+ Ok(vec![F::from(*value)])
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(Mul::new(1))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ Ok(g[0].call(&[input[0], input[0]])? - input[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ 1
+ }
+
+ fn proof_len(&self) -> usize {
+ 5
+ }
+
+ fn verifier_len(&self) -> usize {
+ 4
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 0
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of
+/// the measurements.
+///
+/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3].
+///
+/// [BBCG+19]: https://ia.cr/2019/188
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Sum<F: FieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Sum<F> {
+ /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Sum<F> {
+ const ID: u32 = 0x00000001;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the arithmetic average.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Average<F: FieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Average<F> {
+ /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Average<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = f64;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> {
+ // Compute the average from the aggregated sum.
+ let data = decode_result(data)?;
+ let data: u64 = data.try_into().map_err(|err| {
+ FlpError::Decode(format!("failed to convert {:?} to u64: {}", data, err,))
+ })?;
+ let result = (data as f64) / (num_measurements as f64);
+ Ok(result)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The histogram type. Each measurement is a non-negative integer and the aggregate is a histogram
+/// approximating the distribution of the measurements.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Histogram<F: FieldElement> {
+ buckets: Vec<F::Integer>,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Histogram<F> {
+ /// Return a new [`Histogram`] type with the given buckets.
+ pub fn new(buckets: Vec<F::Integer>) -> Result<Self, FlpError> {
+ if buckets.len() >= u32::MAX as usize {
+ return Err(FlpError::Encode(
+ "invalid buckets: number of buckets exceeds maximum permitted".to_string(),
+ ));
+ }
+
+ if !buckets.is_empty() {
+ for i in 0..buckets.len() - 1 {
+ if buckets[i + 1] <= buckets[i] {
+ return Err(FlpError::Encode(
+ "invalid buckets: out-of-order boundary".to_string(),
+ ));
+ }
+ }
+ }
+
+ Ok(Self {
+ buckets,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Histogram<F> {
+ const ID: u32 = 0x00000002;
+ type Measurement = F::Integer;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let mut data = vec![F::zero(); self.buckets.len() + 1];
+
+ let bucket = match self.buckets.binary_search(measurement) {
+ Ok(i) => i, // on a bucket boundary
+ Err(i) => i, // smaller than the i-th bucket boundary
+ };
+
+ data[bucket] = F::one();
+ Ok(data)
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.buckets.len() + 1)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.to_vec(),
+ self.input_len(),
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ // Check that each element of `input` is a 0 or 1.
+ let range_check = call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])?;
+
+ // Check that the elements of `input` sum to 1.
+ let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?));
+ for val in input.iter() {
+ sum_check += *val;
+ }
+
+ // Take a random linear combination of both checks.
+ let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check;
+ Ok(out)
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ self.buckets.len() + 1
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.input_len()).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// A sequence of counters. This type uses a neat trick from [[BBCG+19], Corollary 4.9] to reduce
+/// the proof size to roughly the square root of the input size.
+///
+/// [BBCG+19]: https://eprint.iacr.org/2019/188
+#[derive(Debug, PartialEq, Eq)]
+pub struct CountVec<F, S> {
+ range_checker: Vec<F>,
+ len: usize,
+ chunk_len: usize,
+ gadget_calls: usize,
+ phantom: PhantomData<S>,
+}
+
+impl<F: FieldElement, S: ParallelSumGadget<F, BlindPolyEval<F>>> CountVec<F, S> {
+ /// Returns a new [`CountVec`] with the given length.
+ pub fn new(len: usize) -> Self {
+ // The optimal chunk length is the square root of the input length. If the input length is
+ // not a perfect square, then round down. If the result is 0, then let the chunk length be
+ // 1 so that the underlying gadget can still be called.
+ let chunk_len = std::cmp::max(1, (len as f64).sqrt() as usize);
+
+ let mut gadget_calls = len / chunk_len;
+ if len % chunk_len != 0 {
+ gadget_calls += 1;
+ }
+
+ Self {
+ range_checker: poly_range_check(0, 2),
+ len,
+ chunk_len,
+ gadget_calls,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FieldElement, S> Clone for CountVec<F, S> {
+ fn clone(&self) -> Self {
+ Self {
+ range_checker: self.range_checker.clone(),
+ len: self.len,
+ chunk_len: self.chunk_len,
+ gadget_calls: self.gadget_calls,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F, S> Type for CountVec<F, S>
+where
+ F: FieldElement,
+ S: ParallelSumGadget<F, BlindPolyEval<F>> + Eq + 'static,
+{
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<F::Integer>;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> {
+ if measurement.len() != self.len {
+ return Err(FlpError::Encode(format!(
+ "unexpected measurement length: got {}; want {}",
+ measurement.len(),
+ self.len
+ )));
+ }
+
+ let max = F::Integer::from(F::one());
+ for value in measurement {
+ if *value > max {
+ return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
+ }
+ }
+
+ Ok(measurement.iter().map(|value| F::from(*value)).collect())
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.len)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(S::new(
+ BlindPolyEval::new(self.range_checker.clone(), self.gadget_calls),
+ self.chunk_len,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ let s = F::from(F::valid_integer_try_from(num_shares)?).inv();
+ let mut r = joint_rand[0];
+ let mut outp = F::zero();
+ let mut padded_chunk = vec![F::zero(); 2 * self.chunk_len];
+ for chunk in input.chunks(self.chunk_len) {
+ let d = chunk.len();
+ for i in 0..self.chunk_len {
+ if i < d {
+ padded_chunk[2 * i] = chunk[i];
+ } else {
+ // If the chunk is smaller than the chunk length, then copy the last element of
+ // the chunk into the remaining slots.
+ padded_chunk[2 * i] = chunk[d - 1];
+ }
+ padded_chunk[2 * i + 1] = r * s;
+ r *= joint_rand[0];
+ }
+
+ outp += g[0].call(&padded_chunk)?;
+ }
+
+ Ok(outp)
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ self.len
+ }
+
+ fn proof_len(&self) -> usize {
+ (self.chunk_len * 2) + 3 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
+ }
+
+ fn verifier_len(&self) -> usize {
+ 2 + self.chunk_len * 2
+ }
+
+ fn output_len(&self) -> usize {
+ self.len
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ self.chunk_len * 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// Compute a random linear combination of the result of calls of `g` on each element of `input`.
+///
+/// # Arguments
+///
+/// * `g` - The gadget to be applied elementwise
+/// * `input` - The vector on whose elements to apply `g`
+/// * `rnd` - The randomness used for the linear combination
+pub(crate) fn call_gadget_on_vec_entries<F: FieldElement>(
+ g: &mut Box<dyn Gadget<F>>,
+ input: &[F],
+ rnd: F,
+) -> Result<F, FlpError> {
+ let mut range_check = F::zero();
+ let mut r = rnd;
+ for chunk in input.chunks(1) {
+ range_check += r * g.call(chunk)?;
+ r *= rnd;
+ }
+ Ok(range_check)
+}
+
+/// Given a vector `data` of field elements which should contain exactly one entry, return the
+/// integer representation of that entry.
+pub(crate) fn decode_result<F: FieldElement>(data: &[F]) -> Result<F::Integer, FlpError> {
+ if data.len() != 1 {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(F::Integer::from(data[0]))
+}
+
+/// Given a vector `data` of field elements, return a vector containing the corresponding integer
+/// representations, if the number of entries matches `expected_len`.
+pub(crate) fn decode_result_vec<F: FieldElement>(
+ data: &[F],
+ expected_len: usize,
+) -> Result<Vec<F::Integer>, FlpError> {
+ if data.len() != expected_len {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, split_vector, Field64 as TestField};
+ use crate::flp::gadgets::ParallelSum;
+ #[cfg(feature = "multithreaded")]
+ use crate::flp::gadgets::ParallelSumMultithreaded;
+
+ // Number of shares to split input and proofs into in `flp_test`.
+ const NUM_SHARES: usize = 3;
+
+ struct ValidityTestCase<F> {
+ expect_valid: bool,
+ expected_output: Option<Vec<F>>,
+ }
+
+ #[test]
+ fn test_count() {
+ let count: Count<TestField> = Count::new();
+ let zero = TestField::zero();
+ let one = TestField::one();
+
+ // Round trip
+ assert_eq!(
+ count
+ .decode_result(
+ &count
+ .truncate(count.encode_measurement(&1).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 1,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&1).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &count,
+ &[TestField::from(1337)],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ // Try running the validity circuit on an input that's too short.
+ count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err();
+ count
+ .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1)
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_sum() {
+ let sum = Sum::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Round trip
+ assert_eq!(
+ sum.decode_result(
+ &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(),
+ 1
+ )
+ .unwrap(),
+ 27,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &sum,
+ &sum.encode_measurement(&1337).unwrap(),
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(1337)]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(0).unwrap(),
+ &[],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(2).unwrap(),
+ &[one, zero],
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(9).unwrap(),
+ &[one, zero, one, one, zero, one, one, one, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(237)]),
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &Sum::new(3).unwrap(),
+ &[one, nine, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(5).unwrap(),
+ &[zero, zero, zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_average() {
+ let average = Average::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let ten = TestField::from(10);
+
+ // Testing that average correctly quotients the sum of the measurements
+ // by the number of measurements.
+ assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0);
+ assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0);
+ assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5);
+ assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25);
+ assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25);
+
+ // round trip of 12 with `num_measurements`=1
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 12.0
+ );
+
+ // round trip of 12 with `num_measurements`=24
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 24
+ )
+ .unwrap(),
+ 0.5
+ );
+ }
+
+ #[test]
+ fn test_histogram() {
+ let hist = Histogram::new(vec![10, 20]).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ assert_eq!(&hist.encode_measurement(&7).unwrap(), &[one, zero, zero]);
+ assert_eq!(&hist.encode_measurement(&10).unwrap(), &[one, zero, zero]);
+ assert_eq!(&hist.encode_measurement(&17).unwrap(), &[zero, one, zero]);
+ assert_eq!(&hist.encode_measurement(&20).unwrap(), &[zero, one, zero]);
+ assert_eq!(&hist.encode_measurement(&27).unwrap(), &[zero, zero, one]);
+
+ // Round trip
+ assert_eq!(
+ hist.decode_result(
+ &hist
+ .truncate(hist.encode_measurement(&27).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ [0, 0, 1]
+ );
+
+ // Invalid bucket boundaries.
+ Histogram::<TestField>::new(vec![10, 0]).unwrap_err();
+ Histogram::<TestField>::new(vec![10, 10]).unwrap_err();
+
+ // Test valid inputs.
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one, zero, zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&17).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, one, zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&1337).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, zero, one]),
+ },
+ )
+ .unwrap();
+
+ // Test invalid inputs.
+ flp_validity_test(
+ &hist,
+ &[zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[one, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, zero, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ fn test_count_vec<F, S>(f: F)
+ where
+ F: Fn(usize) -> CountVec<TestField, S>,
+ S: 'static + ParallelSumGadget<TestField, BlindPolyEval<TestField>> + Eq,
+ {
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Test on valid inputs.
+ for len in 0..10 {
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &count_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ },
+ )
+ .unwrap();
+ }
+
+ let len = 100;
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &count_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ },
+ )
+ .unwrap();
+
+ // Test on invalid inputs.
+ for len in 1..10 {
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &vec![nine; len],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ // Round trip
+ let want = vec![1; len];
+ assert_eq!(
+ count_vec
+ .decode_result(
+ &count_vec
+ .truncate(count_vec.encode_measurement(&want).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ want
+ );
+ }
+
+ #[test]
+ fn test_count_vec_serial() {
+ test_count_vec(CountVec::<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>>::new)
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_count_vec_parallel() {
+ test_count_vec(CountVec::<TestField, ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>>::new)
+ }
+
+ #[test]
+ fn count_vec_serial_long() {
+ let typ: CountVec<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>> =
+ CountVec::new(1000);
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn count_vec_parallel_long() {
+ let typ: CountVec<
+ TestField,
+ ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>,
+ > = CountVec::new(1000);
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ fn flp_validity_test<T: Type>(
+ typ: &T,
+ input: &[T::Field],
+ t: &ValidityTestCase<T::Field>,
+ ) -> Result<(), FlpError> {
+ let mut gadgets = typ.gadget();
+
+ if input.len() != typ.input_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ typ.input_len()
+ )));
+ }
+
+ if typ.query_rand_len() != gadgets.len() {
+ return Err(FlpError::Test(format!(
+ "query rand length: got {}; want {}",
+ typ.query_rand_len(),
+ gadgets.len()
+ )));
+ }
+
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+
+ // Run the validity circuit.
+ let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?;
+ if v != T::Field::zero() && t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected valid input: valid() returned {}",
+ v
+ )));
+ }
+ if v == T::Field::zero() && !t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected invalid input: valid() returned {}",
+ v
+ )));
+ }
+
+ // Generate the proof.
+ let proof = typ.prove(input, &prove_rand, &joint_rand)?;
+ if proof.len() != typ.proof_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected proof length: got {}; want {}",
+ proof.len(),
+ typ.proof_len()
+ )));
+ }
+
+ // Query the proof.
+ let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?;
+ if verifier.len() != typ.verifier_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected verifier length: got {}; want {}",
+ verifier.len(),
+ typ.verifier_len()
+ )));
+ }
+
+ // Decide if the input is valid.
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Run distributed FLP.
+ let input_shares: Vec<Vec<T::Field>> = split_vector(input, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let verifier: Vec<T::Field> = (0..NUM_SHARES)
+ .map(|i| {
+ typ.query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ NUM_SHARES,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "distributed decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Try verifying various proof mutants.
+ for i in 0..proof.len() {
+ let mut mutated_proof = proof.clone();
+ mutated_proof[i] += T::Field::one();
+ let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?;
+ if typ.decide(&verifier)? {
+ return Err(FlpError::Test(format!(
+ "decision for proof mutant {} is {}; want {}",
+ i, true, false,
+ )));
+ }
+ }
+
+ // Try verifying a proof that is too short.
+ let mut mutated_proof = proof.clone();
+ mutated_proof.truncate(gadgets[0].arity() - 1);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on short proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ // Try verifying a proof that is too long.
+ let mut mutated_proof = proof;
+ mutated_proof.extend_from_slice(&[T::Field::one(); 17]);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on long proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ if let Some(ref want) = t.expected_output {
+ let got = typ.truncate(input.to_vec())?;
+
+ if got.len() != typ.output_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected output length: got {}; want {}",
+ got.len(),
+ typ.output_len()
+ )));
+ }
+
+ if &got != want {
+ return Err(FlpError::Test(format!(
+ "unexpected output: got {:?}; want {:?}",
+ got, want
+ )));
+ }
+ }
+
+ Ok(())
+ }
+}