// SPDX-License-Identifier: MPL-2.0 //! A collection of gadgets. use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; use crate::field::FftFriendlyFieldElement; use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; use crate::polynomial::{poly_deg, poly_eval, poly_mul}; #[cfg(feature = "multithreaded")] use rayon::prelude::*; use std::any::Any; use std::convert::TryFrom; use std::fmt::Debug; use std::marker::PhantomData; /// For input polynomials larger than or equal to this threshold, gadgets will use FFT for /// polynomial multiplication. Otherwise, the gadget uses direct multiplication. const FFT_THRESHOLD: usize = 60; /// An arity-2 gadget that multiples its inputs. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Mul { /// Size of buffer for FFT operations. n: usize, /// Inverse of `n` in `F`. n_inv: F, /// The number of times this gadget will be called. num_calls: usize, } impl Mul { /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be /// called by the validity circuit. pub fn new(num_calls: usize) -> Self { let n = gadget_poly_fft_mem_len(2, num_calls); let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); Self { n, n_inv, num_calls, } } // Multiply input polynomials directly. pub(crate) fn call_poly_direct( &mut self, outp: &mut [F], inp: &[Vec], ) -> Result<(), FlpError> { let v = poly_mul(&inp[0], &inp[1]); outp[..v.len()].clone_from_slice(&v); Ok(()) } // Multiply input polynomials using FFT. pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let mut buf = vec![F::zero(); n]; discrete_fourier_transform(&mut buf, &inp[0], n)?; discrete_fourier_transform(outp, &inp[1], n)?; for i in 0..n { buf[i] *= outp[i]; } discrete_fourier_transform(outp, &buf, n)?; discrete_fourier_transform_inv_finish(outp, n, self.n_inv); Ok(()) } } impl Gadget for Mul { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; Ok(inp[0] * inp[1]) } fn call_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; if inp[0].len() >= FFT_THRESHOLD { self.call_poly_fft(outp, inp) } else { self.call_poly_direct(outp, inp) } } fn arity(&self) -> usize { 2 } fn degree(&self) -> usize { 2 } fn calls(&self) -> usize { self.num_calls } fn as_any(&mut self) -> &mut dyn Any { self } } /// An arity-1 gadget that evaluates its input on some polynomial. // // TODO Make `poly` an array of length determined by a const generic. #[derive(Clone, Debug, Eq, PartialEq)] pub struct PolyEval { poly: Vec, /// Size of buffer for FFT operations. n: usize, /// Inverse of `n` in `F`. n_inv: F, /// The number of times this gadget will be called. num_calls: usize, } impl PolyEval { /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times /// this gadget is called by the validity circuit. pub fn new(poly: Vec, num_calls: usize) -> Self { let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls); let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); Self { poly, n, n_inv, num_calls, } } } impl PolyEval { // Multiply input polynomials directly. fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { outp[0] = self.poly[0]; let mut x = inp[0].to_vec(); for i in 1..self.poly.len() { for j in 0..x.len() { outp[j] += self.poly[i] * x[j]; } if i < self.poly.len() - 1 { x = poly_mul(&x, &inp[0]); } } Ok(()) } // Multiply input polynomials using FFT. fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let inp = &inp[0]; let mut inp_vals = vec![F::zero(); n]; discrete_fourier_transform(&mut inp_vals, inp, n)?; let mut x_vals = inp_vals.clone(); let mut x = vec![F::zero(); n]; x[..inp.len()].clone_from_slice(inp); outp[0] = self.poly[0]; for i in 1..self.poly.len() { for j in 0..n { outp[j] += self.poly[i] * x[j]; } if i < self.poly.len() - 1 { for j in 0..n { x_vals[j] *= inp_vals[j]; } discrete_fourier_transform(&mut x, &x_vals, n)?; discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv); } } Ok(()) } } impl Gadget for PolyEval { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; Ok(poly_eval(&self.poly, inp[0])) } fn call_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; for item in outp.iter_mut() { *item = F::zero(); } if inp[0].len() >= FFT_THRESHOLD { self.call_poly_fft(outp, inp) } else { self.call_poly_direct(outp, inp) } } fn arity(&self) -> usize { 1 } fn degree(&self) -> usize { poly_deg(&self.poly) } fn calls(&self) -> usize { self.num_calls } fn as_any(&mut self) -> &mut dyn Any { self } } /// Trait for abstracting over [`ParallelSum`]. pub trait ParallelSumGadget: Gadget + Debug { /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts. fn new(inner: G, chunks: usize) -> Self; } /// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the /// outputs. The arity is equal to the arity of the inner gadget times the number of times it is /// called. #[derive(Clone, Debug, Eq, PartialEq)] pub struct ParallelSum> { inner: G, chunks: usize, phantom: PhantomData, } impl> ParallelSumGadget for ParallelSum { fn new(inner: G, chunks: usize) -> Self { Self { inner, chunks, phantom: PhantomData, } } } impl> Gadget for ParallelSum { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; let mut outp = F::zero(); for chunk in inp.chunks(self.inner.arity()) { outp += self.inner.call(chunk)?; } Ok(outp) } fn call_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; for x in outp.iter_mut() { *x = F::zero(); } let mut partial_outp = vec![F::zero(); outp.len()]; for chunk in inp.chunks(self.inner.arity()) { self.inner.call_poly(&mut partial_outp, chunk)?; for i in 0..outp.len() { outp[i] += partial_outp[i] } } Ok(()) } fn arity(&self) -> usize { self.chunks * self.inner.arity() } fn degree(&self) -> usize { self.inner.degree() } fn calls(&self) -> usize { self.inner.calls() } fn as_any(&mut self) -> &mut dyn Any { self } } /// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the /// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum /// evaluation is multithreaded. #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] #[derive(Clone, Debug, Eq, PartialEq)] pub struct ParallelSumMultithreaded> { serial_sum: ParallelSum, } #[cfg(feature = "multithreaded")] impl ParallelSumGadget for ParallelSumMultithreaded where F: FftFriendlyFieldElement + Sync + Send, G: 'static + Gadget + Clone + Sync + Send, { fn new(inner: G, chunks: usize) -> Self { Self { serial_sum: ParallelSum::new(inner, chunks), } } } /// Data structures passed between fold operations in [`ParallelSumMultithreaded`]. #[cfg(feature = "multithreaded")] struct ParallelSumFoldState { /// Inner gadget. inner: G, /// Output buffer for `call_poly()`. partial_output: Vec, /// Sum accumulator. partial_sum: Vec, } #[cfg(feature = "multithreaded")] impl ParallelSumFoldState { fn new(gadget: &G, length: usize) -> ParallelSumFoldState where G: Clone, F: FftFriendlyFieldElement, { ParallelSumFoldState { inner: gadget.clone(), partial_output: vec![F::zero(); length], partial_sum: vec![F::zero(); length], } } } #[cfg(feature = "multithreaded")] impl Gadget for ParallelSumMultithreaded where F: FftFriendlyFieldElement + Sync + Send, G: 'static + Gadget + Clone + Sync + Send, { fn call(&mut self, inp: &[F]) -> Result { self.serial_sum.call(inp) } fn call_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the // gadget on each input polynomial, using the first temporary buffer as an output buffer. // Then accumulate that result into the second temporary buffer, which acts as a running // sum. Then, discard everything but the partial sums, add them, and finally copy the sum // to the output parameter. This is equivalent to the single threaded calculation in // ParallelSum, since we only rearrange additions, and field addition is associative. let res = inp .par_chunks(self.serial_sum.inner.arity()) .fold( || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()), |mut state, chunk| { state .inner .call_poly(&mut state.partial_output, chunk) .unwrap(); for (sum_elem, output_elem) in state .partial_sum .iter_mut() .zip(state.partial_output.iter()) { *sum_elem += *output_elem; } state }, ) .map(|state| state.partial_sum) .reduce( || vec![F::zero(); outp.len()], |mut x, y| { for (xi, yi) in x.iter_mut().zip(y.iter()) { *xi += *yi; } x }, ); outp.copy_from_slice(&res[..]); Ok(()) } fn arity(&self) -> usize { self.serial_sum.arity() } fn degree(&self) -> usize { self.serial_sum.degree() } fn calls(&self) -> usize { self.serial_sum.calls() } fn as_any(&mut self) -> &mut dyn Any { self } } // Check that the input parameters of g.call() are well-formed. fn gadget_call_check>( gadget: &G, in_len: usize, ) -> Result<(), FlpError> { if in_len != gadget.arity() { return Err(FlpError::Gadget(format!( "unexpected number of inputs: got {}; want {}", in_len, gadget.arity() ))); } if in_len == 0 { return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string())); } Ok(()) } // Check that the input parameters of g.call_poly() are well-formed. fn gadget_call_poly_check>( gadget: &G, outp: &[F], inp: &[Vec], ) -> Result<(), FlpError> where G: Gadget, { gadget_call_check(gadget, inp.len())?; for i in 1..inp.len() { if inp[i].len() != inp[0].len() { return Err(FlpError::Gadget( "gadget called on wire polynomials with different lengths".to_string(), )); } } let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two(); if outp.len() != expected { return Err(FlpError::Gadget(format!( "incorrect output length: got {}; want {}", outp.len(), expected ))); } Ok(()) } #[inline] fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize { gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two() } #[cfg(test)] mod tests { use super::*; #[cfg(feature = "multithreaded")] use crate::field::FieldElement; use crate::field::{random_vector, Field64 as TestField}; use crate::prng::Prng; #[test] fn test_mul() { // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the // naive multiplication code path. let num_calls = FFT_THRESHOLD / 2; let mut g: Mul = Mul::new(num_calls); gadget_test(&mut g, num_calls); // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises // FFT-based polynomial multiplication. let num_calls = FFT_THRESHOLD; let mut g: Mul = Mul::new(num_calls); gadget_test(&mut g, num_calls); } #[test] fn test_poly_eval() { let poly: Vec = random_vector(10).unwrap(); let num_calls = FFT_THRESHOLD / 2; let mut g: PolyEval = PolyEval::new(poly.clone(), num_calls); gadget_test(&mut g, num_calls); let num_calls = FFT_THRESHOLD; let mut g: PolyEval = PolyEval::new(poly, num_calls); gadget_test(&mut g, num_calls); } #[test] fn test_parallel_sum() { let num_calls = 10; let chunks = 23; let mut g = ParallelSum::new(Mul::::new(num_calls), chunks); gadget_test(&mut g, num_calls); } #[test] #[cfg(feature = "multithreaded")] fn test_parallel_sum_multithreaded() { use std::iter; for num_calls in [1, 10, 100] { let chunks = 23; let mut g = ParallelSumMultithreaded::new(Mul::new(num_calls), chunks); gadget_test(&mut g, num_calls); // Test that the multithreaded version has the same output as the normal version. let mut g_serial = ParallelSum::new(Mul::new(num_calls), chunks); assert_eq!(g.arity(), g_serial.arity()); assert_eq!(g.degree(), g_serial.degree()); assert_eq!(g.calls(), g_serial.calls()); let arity = g.arity(); let degree = g.degree(); // Test that both gadgets evaluate to the same value when run on scalar inputs. let inp: Vec = random_vector(arity).unwrap(); let result = g.call(&inp).unwrap(); let result_serial = g_serial.call(&inp).unwrap(); assert_eq!(result, result_serial); // Test that both gadgets evaluate to the same value when run on polynomial inputs. let mut poly_outp = vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; let mut poly_outp_serial = vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; let mut prng: Prng = Prng::new().unwrap(); let poly_inp: Vec<_> = iter::repeat_with(|| { iter::repeat_with(|| prng.get()) .take(1 + num_calls) .collect::>() }) .take(arity) .collect(); g.call_poly(&mut poly_outp, &poly_inp).unwrap(); g_serial .call_poly(&mut poly_outp_serial, &poly_inp) .unwrap(); assert_eq!(poly_outp, poly_outp_serial); } } // Test that calling g.call_poly() and evaluating the output at a given point is equivalent // to evaluating each of the inputs at the same point and applying g.call() on the results. fn gadget_test>(g: &mut G, num_calls: usize) { let wire_poly_len = (1 + num_calls).next_power_of_two(); let mut prng = Prng::new().unwrap(); let mut inp = vec![F::zero(); g.arity()]; let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)]; let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()]; let r = prng.get(); for i in 0..g.arity() { for j in 0..wire_poly_len { wire_polys[i][j] = prng.get(); } inp[i] = poly_eval(&wire_polys[i], r); } g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); let got = poly_eval(&gadget_poly, r); let want = g.call(&inp).unwrap(); assert_eq!(got, want); // Repeat the call to make sure that the gadget's memory is reset properly between calls. g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); let got = poly_eval(&gadget_poly, r); assert_eq!(got, want); } }