summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/polynomial.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/polynomial.rs')
-rw-r--r--third_party/rust/prio/src/polynomial.rs384
1 files changed, 384 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs
new file mode 100644
index 0000000000..7c38341e36
--- /dev/null
+++ b/third_party/rust/prio/src/polynomial.rs
@@ -0,0 +1,384 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Functions for polynomial interpolation and evaluation
+
+use crate::field::FieldElement;
+
+use std::convert::TryFrom;
+
+/// Temporary memory used for FFT
+#[derive(Clone, Debug)]
+pub struct PolyFFTTempMemory<F> {
+ fft_tmp: Vec<F>,
+ fft_y_sub: Vec<F>,
+ fft_roots_sub: Vec<F>,
+}
+
+impl<F: FieldElement> PolyFFTTempMemory<F> {
+ fn new(length: usize) -> Self {
+ PolyFFTTempMemory {
+ fft_tmp: vec![F::zero(); length],
+ fft_y_sub: vec![F::zero(); length],
+ fft_roots_sub: vec![F::zero(); length],
+ }
+ }
+}
+
+/// Auxiliary memory for polynomial interpolation and evaluation
+#[derive(Clone, Debug)]
+pub struct PolyAuxMemory<F> {
+ pub roots_2n: Vec<F>,
+ pub roots_2n_inverted: Vec<F>,
+ pub roots_n: Vec<F>,
+ pub roots_n_inverted: Vec<F>,
+ pub coeffs: Vec<F>,
+ pub fft_memory: PolyFFTTempMemory<F>,
+}
+
+impl<F: FieldElement> PolyAuxMemory<F> {
+ pub fn new(n: usize) -> Self {
+ PolyAuxMemory {
+ roots_2n: fft_get_roots(2 * n, false),
+ roots_2n_inverted: fft_get_roots(2 * n, true),
+ roots_n: fft_get_roots(n, false),
+ roots_n_inverted: fft_get_roots(n, true),
+ coeffs: vec![F::zero(); 2 * n],
+ fft_memory: PolyFFTTempMemory::new(2 * n),
+ }
+ }
+}
+
+fn fft_recurse<F: FieldElement>(
+ out: &mut [F],
+ n: usize,
+ roots: &[F],
+ ys: &[F],
+ tmp: &mut [F],
+ y_sub: &mut [F],
+ roots_sub: &mut [F],
+) {
+ if n == 1 {
+ out[0] = ys[0];
+ return;
+ }
+
+ let half_n = n / 2;
+
+ let (tmp_first, tmp_second) = tmp.split_at_mut(half_n);
+ let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n);
+ let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n);
+
+ // Recurse on the first half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] + ys[i + half_n];
+ roots_sub_first[i] = roots[2 * i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i] = tmp_first[i];
+ }
+
+ // Recurse on the second half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] - ys[i + half_n];
+ y_sub_first[i] *= roots[i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i + 1] = tmp[i];
+ }
+}
+
+/// Calculate `count` number of roots of unity of order `count`
+fn fft_get_roots<F: FieldElement>(count: usize, invert: bool) -> Vec<F> {
+ let mut roots = vec![F::zero(); count];
+ let mut gen = F::generator();
+ if invert {
+ gen = gen.inv();
+ }
+
+ roots[0] = F::one();
+ let step_size = F::generator_order() / F::Integer::try_from(count).unwrap();
+ // generator for subgroup of order count
+ gen = gen.pow(step_size);
+
+ roots[1] = gen;
+
+ for i in 2..count {
+ roots[i] = gen * roots[i - 1];
+ }
+
+ roots
+}
+
+fn fft_interpolate_raw<F: FieldElement>(
+ out: &mut [F],
+ ys: &[F],
+ n_points: usize,
+ roots: &[F],
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_recurse(
+ out,
+ n_points,
+ roots,
+ ys,
+ &mut mem.fft_tmp,
+ &mut mem.fft_y_sub,
+ &mut mem.fft_roots_sub,
+ );
+ if invert {
+ let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv();
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..n_points {
+ out[i] *= n_inverse;
+ }
+ }
+}
+
+pub fn poly_fft<F: FieldElement>(
+ points_out: &mut [F],
+ points_in: &[F],
+ scaled_roots: &[F],
+ n_points: usize,
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem)
+}
+
+// Evaluate a polynomial using Horner's method.
+pub fn poly_eval<F: FieldElement>(poly: &[F], eval_at: F) -> F {
+ if poly.is_empty() {
+ return F::zero();
+ }
+
+ let mut result = poly[poly.len() - 1];
+ for i in (0..poly.len() - 1).rev() {
+ result *= eval_at;
+ result += poly[i];
+ }
+
+ result
+}
+
+// Returns the degree of polynomial `p`.
+pub fn poly_deg<F: FieldElement>(p: &[F]) -> usize {
+ let mut d = p.len();
+ while d > 0 && p[d - 1] == F::zero() {
+ d -= 1;
+ }
+ d.saturating_sub(1)
+}
+
+// Multiplies polynomials `p` and `q` and returns the result.
+pub fn poly_mul<F: FieldElement>(p: &[F], q: &[F]) -> Vec<F> {
+ let p_size = poly_deg(p) + 1;
+ let q_size = poly_deg(q) + 1;
+ let mut out = vec![F::zero(); p_size + q_size];
+ for i in 0..p_size {
+ for j in 0..q_size {
+ out[i + j] += p[i] * q[j];
+ }
+ }
+ out.truncate(poly_deg(&out) + 1);
+ out
+}
+
+#[cfg(feature = "prio2")]
+pub fn poly_interpret_eval<F: FieldElement>(
+ points: &[F],
+ roots: &[F],
+ eval_at: F,
+ tmp_coeffs: &mut [F],
+ fft_memory: &mut PolyFFTTempMemory<F>,
+) -> F {
+ poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory);
+ poly_eval(&tmp_coeffs[..points.len()], eval_at)
+}
+
+// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise,
+// the output is not `0`.
+pub(crate) fn poly_range_check<F: FieldElement>(start: usize, end: usize) -> Vec<F> {
+ let mut p = vec![F::one()];
+ let mut q = [F::zero(), F::one()];
+ for i in start..end {
+ q[0] = -F::from(F::Integer::try_from(i).unwrap());
+ p = poly_mul(&p, &q);
+ }
+ p
+}
+
+#[test]
+fn test_roots() {
+ use crate::field::Field32;
+
+ let count = 128;
+ let roots = fft_get_roots::<Field32>(count, false);
+ let roots_inv = fft_get_roots::<Field32>(count, true);
+
+ for i in 0..count {
+ assert_eq!(roots[i] * roots_inv[i], 1);
+ assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1);
+ assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1);
+ }
+}
+
+#[test]
+fn test_eval() {
+ use crate::field::Field32;
+
+ let mut poly = vec![Field32::from(0); 4];
+ poly[0] = 2.into();
+ poly[1] = 1.into();
+ poly[2] = 5.into();
+ // 5*3^2 + 3 + 2 = 50
+ assert_eq!(poly_eval(&poly[..3], 3.into()), 50);
+ poly[3] = 4.into();
+ // 4*3^3 + 5*3^2 + 3 + 2 = 158
+ assert_eq!(poly_eval(&poly[..4], 3.into()), 158);
+}
+
+#[test]
+fn test_poly_deg() {
+ use crate::field::Field32;
+
+ let zero = Field32::zero();
+ let one = Field32::root(0).unwrap();
+ assert_eq!(poly_deg(&[zero]), 0);
+ assert_eq!(poly_deg(&[one]), 0);
+ assert_eq!(poly_deg(&[zero, one]), 1);
+ assert_eq!(poly_deg(&[zero, zero, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one, one]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3);
+}
+
+#[test]
+fn test_poly_mul() {
+ use crate::field::Field64;
+
+ let p = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ ];
+
+ let q = [
+ Field64::one(),
+ Field64::zero(),
+ Field64::from(u64::try_from(5).unwrap()),
+ ];
+
+ let want = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ Field64::from(u64::try_from(10).unwrap()),
+ Field64::from(u64::try_from(15).unwrap()),
+ ];
+
+ let got = poly_mul(&p, &q);
+ assert_eq!(&got, &want);
+}
+
+#[test]
+fn test_poly_range_check() {
+ use crate::field::Field64;
+
+ let start = 74;
+ let end = 112;
+ let p = poly_range_check(start, end);
+
+ // Check each number in the range.
+ for i in start..end {
+ let x = Field64::from(i as u64);
+ let y = poly_eval(&p, x);
+ assert_eq!(y, Field64::zero(), "range check failed for {}", i);
+ }
+
+ // Check the number below the range.
+ let x = Field64::from((start - 1) as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+
+ // Check a number above the range.
+ let x = Field64::from(end as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+}
+
+#[test]
+fn test_fft() {
+ use crate::field::Field32;
+
+ use rand::prelude::*;
+ use std::convert::TryFrom;
+
+ let count = 128;
+ let mut mem = PolyAuxMemory::new(count / 2);
+
+ let mut poly = vec![Field32::from(0); count];
+ let mut points2 = vec![Field32::from(0); count];
+
+ let points = (0..count)
+ .into_iter()
+ .map(|_| Field32::from(random::<u32>()))
+ .collect::<Vec<Field32>>();
+
+ // From points to coeffs and back
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+ poly_fft(
+ &mut points2,
+ &poly,
+ &mem.roots_2n_inverted,
+ count,
+ true,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(points, points2);
+
+ // interpolation
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..count {
+ let mut should_be = Field32::from(0);
+ for j in 0..count {
+ should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be;
+ }
+ assert_eq!(should_be, poly[i]);
+ }
+}