diff options
Diffstat (limited to 'third_party/rust/prio/src/polynomial.rs')
-rw-r--r-- | third_party/rust/prio/src/polynomial.rs | 384 |
1 files changed, 384 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs new file mode 100644 index 0000000000..7c38341e36 --- /dev/null +++ b/third_party/rust/prio/src/polynomial.rs @@ -0,0 +1,384 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Functions for polynomial interpolation and evaluation + +use crate::field::FieldElement; + +use std::convert::TryFrom; + +/// Temporary memory used for FFT +#[derive(Clone, Debug)] +pub struct PolyFFTTempMemory<F> { + fft_tmp: Vec<F>, + fft_y_sub: Vec<F>, + fft_roots_sub: Vec<F>, +} + +impl<F: FieldElement> PolyFFTTempMemory<F> { + fn new(length: usize) -> Self { + PolyFFTTempMemory { + fft_tmp: vec![F::zero(); length], + fft_y_sub: vec![F::zero(); length], + fft_roots_sub: vec![F::zero(); length], + } + } +} + +/// Auxiliary memory for polynomial interpolation and evaluation +#[derive(Clone, Debug)] +pub struct PolyAuxMemory<F> { + pub roots_2n: Vec<F>, + pub roots_2n_inverted: Vec<F>, + pub roots_n: Vec<F>, + pub roots_n_inverted: Vec<F>, + pub coeffs: Vec<F>, + pub fft_memory: PolyFFTTempMemory<F>, +} + +impl<F: FieldElement> PolyAuxMemory<F> { + pub fn new(n: usize) -> Self { + PolyAuxMemory { + roots_2n: fft_get_roots(2 * n, false), + roots_2n_inverted: fft_get_roots(2 * n, true), + roots_n: fft_get_roots(n, false), + roots_n_inverted: fft_get_roots(n, true), + coeffs: vec![F::zero(); 2 * n], + fft_memory: PolyFFTTempMemory::new(2 * n), + } + } +} + +fn fft_recurse<F: FieldElement>( + out: &mut [F], + n: usize, + roots: &[F], + ys: &[F], + tmp: &mut [F], + y_sub: &mut [F], + roots_sub: &mut [F], +) { + if n == 1 { + out[0] = ys[0]; + return; + } + + let half_n = n / 2; + + let (tmp_first, tmp_second) = tmp.split_at_mut(half_n); + let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n); + let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n); + + // Recurse on the first half + for i in 0..half_n { + y_sub_first[i] = ys[i] + ys[i + half_n]; + roots_sub_first[i] = roots[2 * i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i] = tmp_first[i]; + } + + // Recurse on the second half + for i in 0..half_n { + y_sub_first[i] = ys[i] - ys[i + half_n]; + y_sub_first[i] *= roots[i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i + 1] = tmp[i]; + } +} + +/// Calculate `count` number of roots of unity of order `count` +fn fft_get_roots<F: FieldElement>(count: usize, invert: bool) -> Vec<F> { + let mut roots = vec![F::zero(); count]; + let mut gen = F::generator(); + if invert { + gen = gen.inv(); + } + + roots[0] = F::one(); + let step_size = F::generator_order() / F::Integer::try_from(count).unwrap(); + // generator for subgroup of order count + gen = gen.pow(step_size); + + roots[1] = gen; + + for i in 2..count { + roots[i] = gen * roots[i - 1]; + } + + roots +} + +fn fft_interpolate_raw<F: FieldElement>( + out: &mut [F], + ys: &[F], + n_points: usize, + roots: &[F], + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_recurse( + out, + n_points, + roots, + ys, + &mut mem.fft_tmp, + &mut mem.fft_y_sub, + &mut mem.fft_roots_sub, + ); + if invert { + let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); + #[allow(clippy::needless_range_loop)] + for i in 0..n_points { + out[i] *= n_inverse; + } + } +} + +pub fn poly_fft<F: FieldElement>( + points_out: &mut [F], + points_in: &[F], + scaled_roots: &[F], + n_points: usize, + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) +} + +// Evaluate a polynomial using Horner's method. +pub fn poly_eval<F: FieldElement>(poly: &[F], eval_at: F) -> F { + if poly.is_empty() { + return F::zero(); + } + + let mut result = poly[poly.len() - 1]; + for i in (0..poly.len() - 1).rev() { + result *= eval_at; + result += poly[i]; + } + + result +} + +// Returns the degree of polynomial `p`. +pub fn poly_deg<F: FieldElement>(p: &[F]) -> usize { + let mut d = p.len(); + while d > 0 && p[d - 1] == F::zero() { + d -= 1; + } + d.saturating_sub(1) +} + +// Multiplies polynomials `p` and `q` and returns the result. +pub fn poly_mul<F: FieldElement>(p: &[F], q: &[F]) -> Vec<F> { + let p_size = poly_deg(p) + 1; + let q_size = poly_deg(q) + 1; + let mut out = vec![F::zero(); p_size + q_size]; + for i in 0..p_size { + for j in 0..q_size { + out[i + j] += p[i] * q[j]; + } + } + out.truncate(poly_deg(&out) + 1); + out +} + +#[cfg(feature = "prio2")] +pub fn poly_interpret_eval<F: FieldElement>( + points: &[F], + roots: &[F], + eval_at: F, + tmp_coeffs: &mut [F], + fft_memory: &mut PolyFFTTempMemory<F>, +) -> F { + poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory); + poly_eval(&tmp_coeffs[..points.len()], eval_at) +} + +// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise, +// the output is not `0`. +pub(crate) fn poly_range_check<F: FieldElement>(start: usize, end: usize) -> Vec<F> { + let mut p = vec![F::one()]; + let mut q = [F::zero(), F::one()]; + for i in start..end { + q[0] = -F::from(F::Integer::try_from(i).unwrap()); + p = poly_mul(&p, &q); + } + p +} + +#[test] +fn test_roots() { + use crate::field::Field32; + + let count = 128; + let roots = fft_get_roots::<Field32>(count, false); + let roots_inv = fft_get_roots::<Field32>(count, true); + + for i in 0..count { + assert_eq!(roots[i] * roots_inv[i], 1); + assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1); + assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1); + } +} + +#[test] +fn test_eval() { + use crate::field::Field32; + + let mut poly = vec![Field32::from(0); 4]; + poly[0] = 2.into(); + poly[1] = 1.into(); + poly[2] = 5.into(); + // 5*3^2 + 3 + 2 = 50 + assert_eq!(poly_eval(&poly[..3], 3.into()), 50); + poly[3] = 4.into(); + // 4*3^3 + 5*3^2 + 3 + 2 = 158 + assert_eq!(poly_eval(&poly[..4], 3.into()), 158); +} + +#[test] +fn test_poly_deg() { + use crate::field::Field32; + + let zero = Field32::zero(); + let one = Field32::root(0).unwrap(); + assert_eq!(poly_deg(&[zero]), 0); + assert_eq!(poly_deg(&[one]), 0); + assert_eq!(poly_deg(&[zero, one]), 1); + assert_eq!(poly_deg(&[zero, zero, one]), 2); + assert_eq!(poly_deg(&[zero, one, one]), 2); + assert_eq!(poly_deg(&[zero, one, one, one]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3); +} + +#[test] +fn test_poly_mul() { + use crate::field::Field64; + + let p = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + ]; + + let q = [ + Field64::one(), + Field64::zero(), + Field64::from(u64::try_from(5).unwrap()), + ]; + + let want = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + Field64::from(u64::try_from(10).unwrap()), + Field64::from(u64::try_from(15).unwrap()), + ]; + + let got = poly_mul(&p, &q); + assert_eq!(&got, &want); +} + +#[test] +fn test_poly_range_check() { + use crate::field::Field64; + + let start = 74; + let end = 112; + let p = poly_range_check(start, end); + + // Check each number in the range. + for i in start..end { + let x = Field64::from(i as u64); + let y = poly_eval(&p, x); + assert_eq!(y, Field64::zero(), "range check failed for {}", i); + } + + // Check the number below the range. + let x = Field64::from((start - 1) as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); + + // Check a number above the range. + let x = Field64::from(end as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); +} + +#[test] +fn test_fft() { + use crate::field::Field32; + + use rand::prelude::*; + use std::convert::TryFrom; + + let count = 128; + let mut mem = PolyAuxMemory::new(count / 2); + + let mut poly = vec![Field32::from(0); count]; + let mut points2 = vec![Field32::from(0); count]; + + let points = (0..count) + .into_iter() + .map(|_| Field32::from(random::<u32>())) + .collect::<Vec<Field32>>(); + + // From points to coeffs and back + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + poly_fft( + &mut points2, + &poly, + &mem.roots_2n_inverted, + count, + true, + &mut mem.fft_memory, + ); + + assert_eq!(points, points2); + + // interpolation + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + + #[allow(clippy::needless_range_loop)] + for i in 0..count { + let mut should_be = Field32::from(0); + for j in 0..count { + should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be; + } + assert_eq!(should_be, poly[i]); + } +} |