summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp/gadgets.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/prio/src/flp/gadgets.rs
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prio/src/flp/gadgets.rs')
-rw-r--r--third_party/rust/prio/src/flp/gadgets.rs715
1 files changed, 715 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs
new file mode 100644
index 0000000000..fd2be84eaa
--- /dev/null
+++ b/third_party/rust/prio/src/flp/gadgets.rs
@@ -0,0 +1,715 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of gadgets.
+
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
+use crate::field::FieldElement;
+use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
+use crate::polynomial::{poly_deg, poly_eval, poly_mul};
+
+#[cfg(feature = "multithreaded")]
+use rayon::prelude::*;
+
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
+/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
+const FFT_THRESHOLD: usize = 60;
+
+/// An arity-2 gadget that multiples its inputs.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Mul<F: FieldElement> {
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> Mul<F> {
+ /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
+ /// called by the validity circuit.
+ pub fn new(num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(2, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ // Multiply input polynomials directly.
+ pub(crate) fn call_poly_direct(
+ &mut self,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+ ) -> Result<(), FlpError> {
+ let v = poly_mul(&inp[0], &inp[1]);
+ outp[..v.len()].clone_from_slice(&v);
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let mut buf = vec![F::zero(); n];
+
+ discrete_fourier_transform(&mut buf, &inp[0], n)?;
+ discrete_fourier_transform(outp, &inp[1], n)?;
+
+ for i in 0..n {
+ buf[i] *= outp[i];
+ }
+
+ discrete_fourier_transform(outp, &buf, n)?;
+ discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for Mul<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[0] * inp[1])
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ 2
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-1 gadget that evaluates its input on some polynomial.
+//
+// TODO Make `poly` an array of length determined by a const generic.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct PolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
+ /// this gadget is called by the validity circuit.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ // Multiply input polynomials directly.
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ outp[0] = self.poly[0];
+ let mut x = inp[0].to_vec();
+ for i in 1..self.poly.len() {
+ for j in 0..x.len() {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ x = poly_mul(&x, &inp[0]);
+ }
+ }
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let inp = &inp[0];
+
+ let mut inp_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut inp_vals, inp, n)?;
+
+ let mut x_vals = inp_vals.clone();
+ let mut x = vec![F::zero(); n];
+ x[..inp.len()].clone_from_slice(inp);
+
+ outp[0] = self.poly[0];
+ for i in 1..self.poly.len() {
+ for j in 0..n {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ x_vals[j] *= inp_vals[j];
+ }
+
+ discrete_fourier_transform(&mut x, &x_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for PolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for item in outp.iter_mut() {
+ *item = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 1
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly)
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct BlindPolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for the outer FFT multiplication.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> BlindPolyEval<F> {
+ /// Returns a `BlindPolyEval` gadget for polynomial `poly`.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut z = y.to_vec();
+ for i in 0..self.poly.len() {
+ for j in 0..z.len() {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ z = poly_mul(&z, x);
+ }
+ }
+ Ok(())
+ }
+
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut x_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut x_vals, x, n)?;
+
+ let mut z_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut z_vals, y, n)?;
+
+ let mut z = vec![F::zero(); n];
+ let mut z_len = y.len();
+ z[..y.len()].clone_from_slice(y);
+
+ for i in 0..self.poly.len() {
+ for j in 0..z_len {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ z_vals[j] *= x_vals[j];
+ }
+
+ discrete_fourier_transform(&mut z, &z_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv);
+ z_len += x.len();
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[1] * poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly) + 1
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Marker trait for abstracting over [`ParallelSum`].
+pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug {
+ /// Wraps `inner` into a sum gadget with `chunks` chunks
+ fn new(inner: G, chunks: usize) -> Self;
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSum<F: FieldElement, G: Gadget<F>> {
+ inner: G,
+ chunks: usize,
+ phantom: PhantomData<F>,
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> {
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ inner,
+ chunks,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ let mut outp = F::zero();
+ for chunk in inp.chunks(self.inner.arity()) {
+ outp += self.inner.call(chunk)?;
+ }
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ let mut partial_outp = vec![F::zero(); outp.len()];
+
+ for chunk in inp.chunks(self.inner.arity()) {
+ self.inner.call_poly(&mut partial_outp, chunk)?;
+ for i in 0..outp.len() {
+ outp[i] += partial_outp[i]
+ }
+ }
+
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.chunks * self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
+/// evaluation is multithreaded.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> {
+ serial_sum: ParallelSum<F, G>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ serial_sum: ParallelSum::new(inner, chunks),
+ }
+ }
+}
+
+/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
+#[cfg(feature = "multithreaded")]
+struct ParallelSumFoldState<F, G> {
+ /// Inner gadget.
+ inner: G,
+ /// Output buffer for `call_poly()`.
+ partial_output: Vec<F>,
+ /// Sum accumulator.
+ partial_sum: Vec<F>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumFoldState<F, G> {
+ fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
+ where
+ G: Clone,
+ F: FieldElement,
+ {
+ ParallelSumFoldState {
+ inner: gadget.clone(),
+ partial_output: vec![F::zero(); length],
+ partial_sum: vec![F::zero(); length],
+ }
+ }
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ self.serial_sum.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
+ // gadget on each input polynomial, using the first temporary buffer as an output buffer.
+ // Then accumulate that result into the second temporary buffer, which acts as a running
+ // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
+ // to the output parameter. This is equivalent to the single threaded calculation in
+ // ParallelSum, since we only rearrange additions, and field addition is associative.
+ let res = inp
+ .par_chunks(self.serial_sum.inner.arity())
+ .fold(
+ || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
+ |mut state, chunk| {
+ state
+ .inner
+ .call_poly(&mut state.partial_output, chunk)
+ .unwrap();
+ for (sum_elem, output_elem) in state
+ .partial_sum
+ .iter_mut()
+ .zip(state.partial_output.iter())
+ {
+ *sum_elem += *output_elem;
+ }
+ state
+ },
+ )
+ .map(|state| state.partial_sum)
+ .reduce(
+ || vec![F::zero(); outp.len()],
+ |mut x, y| {
+ for (xi, yi) in x.iter_mut().zip(y.iter()) {
+ *xi += *yi;
+ }
+ x
+ },
+ );
+
+ outp.copy_from_slice(&res[..]);
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.serial_sum.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.serial_sum.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.serial_sum.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// Check that the input parameters of g.call() are well-formed.
+fn gadget_call_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ in_len: usize,
+) -> Result<(), FlpError> {
+ if in_len != gadget.arity() {
+ return Err(FlpError::Gadget(format!(
+ "unexpected number of inputs: got {}; want {}",
+ in_len,
+ gadget.arity()
+ )));
+ }
+
+ if in_len == 0 {
+ return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
+ }
+
+ Ok(())
+}
+
+// Check that the input parameters of g.call_poly() are well-formed.
+fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ outp: &[F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError>
+where
+ G: Gadget<F>,
+{
+ gadget_call_check(gadget, inp.len())?;
+
+ for i in 1..inp.len() {
+ if inp[i].len() != inp[0].len() {
+ return Err(FlpError::Gadget(
+ "gadget called on wire polynomials with different lengths".to_string(),
+ ));
+ }
+ }
+
+ let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
+ if outp.len() != expected {
+ return Err(FlpError::Gadget(format!(
+ "incorrect output length: got {}; want {}",
+ outp.len(),
+ expected
+ )));
+ }
+
+ Ok(())
+}
+
+#[inline]
+fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
+ gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::field::{random_vector, Field96 as TestField};
+ use crate::prng::Prng;
+
+ #[test]
+ fn test_mul() {
+ // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
+ // naive multiplication code path.
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+
+ // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
+ // FFT-based polynomial multiplication.
+ let num_calls = FFT_THRESHOLD;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_blind_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_parallel_sum() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let num_calls = 10;
+ let chunks = 23;
+
+ let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_parallel_sum_multithreaded() {
+ use std::iter;
+
+ for num_calls in [1, 10, 100] {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let chunks = 23;
+
+ let mut g =
+ ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+
+ // Test that the multithreaded version has the same output as the normal version.
+ let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ assert_eq!(g.arity(), g_serial.arity());
+ assert_eq!(g.degree(), g_serial.degree());
+ assert_eq!(g.calls(), g_serial.calls());
+
+ let arity = g.arity();
+ let degree = g.degree();
+
+ // Test that both gadgets evaluate to the same value when run on scalar inputs.
+ let inp: Vec<TestField> = random_vector(arity).unwrap();
+ let result = g.call(&inp).unwrap();
+ let result_serial = g_serial.call(&inp).unwrap();
+ assert_eq!(result, result_serial);
+
+ // Test that both gadgets evaluate to the same value when run on polynomial inputs.
+ let mut poly_outp =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut poly_outp_serial =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut prng: Prng<TestField, _> = Prng::new().unwrap();
+ let poly_inp: Vec<_> = iter::repeat_with(|| {
+ iter::repeat_with(|| prng.get())
+ .take(1 + num_calls)
+ .collect::<Vec<_>>()
+ })
+ .take(arity)
+ .collect();
+
+ g.call_poly(&mut poly_outp, &poly_inp).unwrap();
+ g_serial
+ .call_poly(&mut poly_outp_serial, &poly_inp)
+ .unwrap();
+ assert_eq!(poly_outp, poly_outp_serial);
+ }
+ }
+
+ // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
+ // to evaluating each of the inputs at the same point and applying g.call() on the results.
+ fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
+ let wire_poly_len = (1 + num_calls).next_power_of_two();
+ let mut prng = Prng::new().unwrap();
+ let mut inp = vec![F::zero(); g.arity()];
+ let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
+ let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];
+
+ let r = prng.get();
+ for i in 0..g.arity() {
+ for j in 0..wire_poly_len {
+ wire_polys[i][j] = prng.get();
+ }
+ inp[i] = poly_eval(&wire_polys[i], r);
+ }
+
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ let want = g.call(&inp).unwrap();
+ assert_eq!(got, want);
+
+ // Repeat the call to make sure that the gadget's memory is reset properly between calls.
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ assert_eq!(got, want);
+ }
+}