summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/fft.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/fft.rs')
-rw-r--r--third_party/rust/prio/src/fft.rs226
1 files changed, 226 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/fft.rs b/third_party/rust/prio/src/fft.rs
new file mode 100644
index 0000000000..c7a1dfbb8b
--- /dev/null
+++ b/third_party/rust/prio/src/fft.rs
@@ -0,0 +1,226 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
+//! Transform (DFT) over a slice of field elements.
+
+use crate::field::FieldElement;
+use crate::fp::{log2, MAX_ROOTS};
+
+use std::convert::TryFrom;
+
+/// An error returned by an FFT operation.
+#[derive(Debug, PartialEq, Eq, thiserror::Error)]
+pub enum FftError {
+ /// The output is too small.
+ #[error("output slice is smaller than specified size")]
+ OutputTooSmall,
+ /// The specified size is too large.
+ #[error("size is larger than than maximum permitted")]
+ SizeTooLarge,
+ /// The specified size is not a power of 2.
+ #[error("size is not a power of 2")]
+ SizeInvalid,
+}
+
+/// Sets `outp` to the DFT of `inp`.
+///
+/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
+/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of
+/// unity.
+#[allow(clippy::many_single_char_names)]
+pub fn discrete_fourier_transform<F: FieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?;
+
+ if size > outp.len() {
+ return Err(FftError::OutputTooSmall);
+ }
+
+ if size > 1 << MAX_ROOTS {
+ return Err(FftError::SizeTooLarge);
+ }
+
+ if size != 1 << d {
+ return Err(FftError::SizeInvalid);
+ }
+
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..size {
+ let j = bitrev(d, i);
+ outp[i] = if j < inp.len() { inp[j] } else { F::zero() }
+ }
+
+ let mut w: F;
+ for l in 1..d + 1 {
+ w = F::one();
+ let r = F::root(l).unwrap();
+ let y = 1 << (l - 1);
+ for i in 0..y {
+ for j in 0..(size / y) >> 1 {
+ let x = (1 << l) * j + i;
+ let u = outp[x];
+ let v = w * outp[x + y];
+ outp[x] = u + v;
+ outp[x + y] = u - v;
+ }
+ w *= r;
+ }
+ }
+
+ Ok(())
+}
+
+/// Sets `outp` to the inverse of the DFT of `inp`.
+#[cfg(test)]
+pub(crate) fn discrete_fourier_transform_inv<F: FieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv();
+ discrete_fourier_transform(outp, inp, size)?;
+ discrete_fourier_transform_inv_finish(outp, size, size_inv);
+ Ok(())
+}
+
+/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to
+/// amortize the cost the modular inverse across multiple inverse DFT operations.
+pub(crate) fn discrete_fourier_transform_inv_finish<F: FieldElement>(
+ outp: &mut [F],
+ size: usize,
+ size_inv: F,
+) {
+ let mut tmp: F;
+ outp[0] *= size_inv;
+ outp[size >> 1] *= size_inv;
+ for i in 1..size >> 1 {
+ tmp = outp[i] * size_inv;
+ outp[i] = outp[size - i] * size_inv;
+ outp[size - i] = tmp;
+ }
+}
+
+// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
+fn bitrev(d: usize, x: usize) -> usize {
+ let mut y = 0;
+ for i in 0..d {
+ y += ((x >> i) & 1) << (d - i);
+ }
+ y >> 1
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{
+ random_vector, split_vector, Field128, Field32, Field64, Field96, FieldPrio2,
+ };
+ use crate::polynomial::{poly_fft, PolyAuxMemory};
+
+ fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
+ let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];
+
+ for size in test_sizes.iter() {
+ let mut tmp = vec![F::zero(); *size];
+ let mut got = vec![F::zero(); *size];
+ let want = random_vector(*size).unwrap();
+
+ discrete_fourier_transform(&mut tmp, &want, want.len())?;
+ discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?;
+ assert_eq!(got, want);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_field32() {
+ discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_priov2_field32() {
+ discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field64() {
+ discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field96() {
+ discrete_fourier_transform_then_inv_test::<Field96>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field128() {
+ discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_recursive_fft() {
+ let size = 128;
+ let mut mem = PolyAuxMemory::new(size / 2);
+
+ let inp = random_vector(size).unwrap();
+ let mut want = vec![Field32::zero(); size];
+ let mut got = vec![Field32::zero(); size];
+
+ discrete_fourier_transform::<Field32>(&mut want, &inp, inp.len()).unwrap();
+
+ poly_fft(
+ &mut got,
+ &inp,
+ &mem.roots_2n,
+ size,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(got, want);
+ }
+
+ // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial
+ // over secret shares and summing up the coefficients is equivalent to interpolating a
+ // polynomial over the plaintext data.
+ #[test]
+ fn test_fft_linearity() {
+ let len = 16;
+ let num_shares = 3;
+ let x: Vec<Field64> = random_vector(len).unwrap();
+ let mut x_shares = split_vector(&x, num_shares).unwrap();
+
+ // Just for fun, let's do something different with a subset of the inputs. For the first
+ // share, every odd element is set to the plaintext value. For all shares but the first,
+ // every odd element is set to 0.
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..len {
+ if i % 2 != 0 {
+ x_shares[0][i] = x[i];
+ }
+ for j in 1..num_shares {
+ if i % 2 != 0 {
+ x_shares[j][i] = Field64::zero();
+ }
+ }
+ }
+
+ let mut got = vec![Field64::zero(); len];
+ let mut buf = vec![Field64::zero(); len];
+ for share in x_shares {
+ discrete_fourier_transform_inv(&mut buf, &share, len).unwrap();
+ for i in 0..len {
+ got[i] += buf[i];
+ }
+ }
+
+ let mut want = vec![Field64::zero(); len];
+ discrete_fourier_transform_inv(&mut want, &x, len).unwrap();
+
+ assert_eq!(got, want);
+ }
+}