diff options
Diffstat (limited to 'third_party/rust/prio/src/fp.rs')
-rw-r--r-- | third_party/rust/prio/src/fp.rs | 561 |
1 files changed, 561 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/fp.rs b/third_party/rust/prio/src/fp.rs new file mode 100644 index 0000000000..d828fb7daf --- /dev/null +++ b/third_party/rust/prio/src/fp.rs @@ -0,0 +1,561 @@
// SPDX-License-Identifier: MPL-2.0

//! Finite field arithmetic for any field GF(p) for which p < 2^128.

#[cfg(test)]
use rand::{prelude::*, Rng};

/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots
/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This
/// is the largest input size we would ever need for the cryptographic applications in this crate.
pub(crate) const MAX_ROOTS: usize = 20;

/// This structure represents the parameters of a finite field GF(p) for which p < 2^128.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct FieldParameters {
    /// The prime modulus `p`.
    pub p: u128,
    /// `mu = -p^(-1) mod 2^64`. Used by `mul` as the Montgomery reduction constant.
    pub mu: u64,
    /// `r2 = (2^128)^2 mod p`. Multiplying by `r2` with `mul` maps an integer into the Montgomery
    /// domain (see `montgomery`).
    pub r2: u128,
    /// The `2^num_roots`-th principal root of unity. This element is used to generate the
    /// elements of `roots`. It is stored in the Montgomery domain: `check` compares it against
    /// `montgomery(g)` for the expected generator `g`.
    pub g: u128,
    /// The base-2 logarithm of the order of `g`: `check` asserts that the expected order equals
    /// `2^num_roots`. Note that `roots` always holds exactly `MAX_ROOTS + 1` entries, even for
    /// parameter sets whose `num_roots` is larger than `MAX_ROOTS`.
    pub num_roots: usize,
    /// Equal to `2^b - 1`, where `b` is the length of `p` in bits.
    pub bit_mask: u128,
    /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the
    /// multiplicative group. `roots[0]` is equal to one by definition. The entries are stored in
    /// the Montgomery domain.
    pub roots: [u128; MAX_ROOTS + 1],
}

impl FieldParameters {
    /// Addition. The result will be in [0, p), so long as both x and y are as well.
    ///
    /// The reduction is done with a bit mask rather than a branch.
    pub fn add(&self, x: u128, y: u128) -> u128 {
        //   0,x
        // + 0,y
        // =====
        //   c,z
        let (z, carry) = x.overflowing_add(y);
        //     c, z
        // -   0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = z.overflowing_sub(self.p);
        let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        //
        // `m` is all ones when the subtraction borrowed (the sum was already below p) and all
        // zeros otherwise, so exactly one of the two terms below survives.
        let m = 0u128.wrapping_sub(b1 as u128);
        (z & m) | (s0 & !m)
    }

    /// Subtraction. The result will be in [0, p), so long as both x and y are as well.
    ///
    /// When `x - y` borrows, `p` is added back using a bit mask rather than a branch.
    pub fn sub(&self, x: u128, y: u128) -> u128 {
        //     0, x
        // -   0, y
        // ========
        // b1,z1,z0
        let (z0, b0) = x.overflowing_sub(y);
        let (_z1, b1) = 0u128.overflowing_sub(b0 as u128);
        // `m` is all ones if the subtraction borrowed, all zeros otherwise.
        let m = 0u128.wrapping_sub(b1 as u128);
        //   z1,z0
        // +  0, p
        // ========
        //   s1,s0
        z0.wrapping_add(m & self.p)
        // if b1 == 1: return s0
        // else:       return z0
    }

    /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm
    /// described
    /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf).
    /// The result will be in [0, p).
    ///
    /// Both operands are split into 64-bit limbs that are held in `u128` variables, so that every
    /// limb product and limb sum below fits in a `u128` without overflow.
    ///
    /// # Example usage
    /// ```text
    /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46);
    /// ```
    pub fn mul(&self, x: u128, y: u128) -> u128 {
        let x = [lo64(x), hi64(x)];
        let y = [lo64(y), hi64(y)];
        let p = [lo64(self.p), hi64(self.p)];
        // Four 64-bit limbs of the 256-bit product, least significant first.
        let mut zz = [0; 4];

        // Integer multiplication
        // z = x * y

        //       x1,x0
        // *     y1,y0
        // ===========
        // z3,z2,z1,z0
        let mut result = x[0] * y[0];
        let mut carry = hi64(result);
        zz[0] = lo64(result);
        result = x[0] * y[1];
        let mut hi = hi64(result);
        let mut lo = lo64(result);
        result = lo + carry;
        zz[1] = lo64(result);
        let mut cc = hi64(result);
        result = hi + cc;
        zz[2] = lo64(result);

        result = x[1] * y[0];
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = x[1] * y[1];
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        zz[3] = lo64(result);

        // Montgomery Reduction
        // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64.

        // z3,z2,z1,z0
        // +     p1,p0
        // *         w = mu*z0
        // ===========
        // z3,z2,z1, 0
        let w = self.mu.wrapping_mul(zz[0] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[0] + lo;
        // Per the diagram above the low limb becomes zero here; it is not read again.
        zz[0] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = zz[2] + hi + cc;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + cc;
        zz[3] = lo64(result);

        //    z3,z2,z1
        // +     p1,p0
        // *         w = mu*z1
        // ===========
        //    z3,z2, 0
        let w = self.mu.wrapping_mul(zz[1] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        // Per the diagram above this limb becomes zero here; it is not read again.
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + hi + cc;
        zz[3] = lo64(result);
        // `cc` now holds the carry out of the top limb; it takes part in the final subtraction.
        cc = hi64(result);

        // z = (z3,z2)
        let prod = zz[2] | (zz[3] << 64);

        // Final subtraction
        // If z >= p, then z = z - p

        //     0, z
        // -   0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = prod.overflowing_sub(self.p);
        let (_s1, b1) = (cc as u128).overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        let mask = 0u128.wrapping_sub(b1 as u128);
        (prod & mask) | (s0 & !mask)
    }

    /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the
    /// runtime of this algorithm is linear in the bit length of `exp`.
    ///
    /// `x` is expected to be in the Montgomery domain and the result is returned in the Montgomery
    /// domain: the accumulator starts at `montgomery(1)` and is only ever updated with `mul`.
    /// `exp` is a plain integer whose bits are scanned from most to least significant.
    pub fn pow(&self, x: u128, exp: u128) -> u128 {
        let mut t = self.montgomery(1);
        // Square-and-multiply over the significant bits of `exp`.
        for i in (0..128 - exp.leading_zeros()).rev() {
            t = self.mul(t, t);
            if (exp >> i) & 1 != 0 {
                t = self.mul(t, x);
            }
        }
        t
    }

    /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of
    /// this algorithm is linear in the bit length of `p`.
    ///
    /// Computed as `pow(x, p - 2)`, so `x` and the result are both in the Montgomery domain.
    pub fn inv(&self, x: u128) -> u128 {
        self.pow(x, self.p - 2)
    }

    /// Negation, i.e., `-x (mod p)` where `p` is the modulus. Computed as `sub(0, x)`.
    pub fn neg(&self, x: u128) -> u128 {
        self.sub(0, x)
    }

    /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery
    /// domain in order to carry out field arithmetic. The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// let integer = 1; // Standard integer representation
    /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain
    /// assert_eq!(elem, 2564090464);
    /// ```
    pub fn montgomery(&self, x: u128) -> u128 {
        modp(self.mul(x, self.r2), self.p)
    }

    /// Returns a field element sampled uniformly from [0, p) and mapped to the Montgomery domain.
    /// Only available in test builds.
    #[cfg(test)]
    pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 {
        let uniform = rand::distributions::Uniform::from(0..self.p);
        self.montgomery(uniform.sample(rng))
    }

    /// Maps a field element to its representation as an integer. The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// let elem = 2564090464; // Internal representation in the Montgomery domain
    /// let integer = fp.residue(elem); // Standard integer representation
    /// assert_eq!(integer, 1);
    /// ```
    pub fn residue(&self, x: u128) -> u128 {
        modp(self.mul(x, 1), self.p)
    }

    /// Test-only consistency check of a set of field parameters.
    ///
    /// Recomputes `mu`, `r2`, `roots` and `bit_mask` from the expected modulus `p`, the expected
    /// generator `g` (given as a plain integer) and the expected `order` of `g`, and compares them
    /// with the stored values. Panics with a descriptive message on the first mismatch.
    #[cfg(test)]
    pub fn check(&self, p: u128, g: u128, order: u128) {
        use modinverse::modinverse;
        use num_bigint::{BigInt, ToBigInt};
        use std::cmp::max;

        assert_eq!(self.p, p, "p mismatch");

        // mu = (-p)^(-1) mod 2^64
        let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) {
            Some(mu) => mu as u64,
            None => panic!("inverse of -p (mod 2^64) is undefined"),
        };
        assert_eq!(self.mu, mu, "mu mismatch");

        // r2 = (2^128 mod p)^2 mod p, reassembled from (at most) two 64-bit digits.
        let big_p = &p.to_bigint().unwrap();
        let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p);
        let big_r2: &BigInt = &(&(big_r * big_r) % big_p);
        let mut it = big_r2.iter_u64_digits();
        let mut r2 = 0;
        r2 |= it.next().unwrap() as u128;
        if let Some(x) = it.next() {
            r2 |= (x as u128) << 64;
        }
        assert_eq!(self.r2, r2, "r2 mismatch");

        assert_eq!(self.g, self.montgomery(g), "g mismatch");
        assert_eq!(
            self.residue(self.pow(self.g, order)),
            1,
            "g order incorrect"
        );

        let num_roots = log2(order) as usize;
        assert_eq!(order, 1 << num_roots, "order not a power of 2");
        assert_eq!(self.num_roots, num_roots, "num_roots mismatch");

        // Starting from `g` at index `num_roots`, each lower root is the square of the one above
        // it. Only the first `MAX_ROOTS + 1` entries are compared with the stored table.
        let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1];
        roots[num_roots] = self.montgomery(g);
        for i in (0..num_roots).rev() {
            roots[i] = self.mul(roots[i + 1], roots[i + 1]);
        }
        assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch");
        assert_eq!(self.residue(self.roots[0]), 1, "first root is not one");

        let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1);
        assert_eq!(
            self.bit_mask.to_bigint().unwrap(),
            bit_mask,
            "bit_mask mismatch"
        );
    }
}

/// Returns the low 64 bits of `x`.
fn lo64(x: u128) -> u128 {
    x & ((1 << 64) - 1)
}

/// Returns the high 64 bits of `x`, shifted down into the low half.
fn hi64(x: u128) -> u128 {
    x >> 64
}

/// Returns `x - p` if `x >= p` and `x` otherwise, selected with a bit mask rather than a branch.
/// Only a single subtraction is performed, so the result is in [0, p) only when `x < 2p`.
fn modp(x: u128, p: u128) -> u128 {
    let (z, carry) = x.overflowing_sub(p);
    // `m` is all ones if the subtraction borrowed (x < p), in which case `p` is added back.
    let m = 0u128.wrapping_sub(carry as u128);
    z.wrapping_add(m & p)
}

/// Parameters for GF(p) with a 32-bit prime modulus.
pub(crate) const FP32: FieldParameters = FieldParameters {
    p: 4293918721, // 32-bit prime
    mu: 17302828673139736575,
    r2: 1676699750,
    g: 1074114499,
    num_roots: 20,
    bit_mask: 4294967295,
    roots: [
        2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825,
        2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415,
        3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499,
    ],
};

/// Parameters for GF(p) with a 64-bit prime modulus.
pub(crate) const FP64: FieldParameters = FieldParameters {
    p: 18446744069414584321, // 64-bit prime
    mu: 18446744069414584319,
    r2: 4294967295,
    g: 959634606461954525,
    num_roots: 32,
    bit_mask: 18446744073709551615,
    roots: [
        18446744065119617025,
        4294967296,
        18446462594437939201,
        72057594037927936,
        1152921504338411520,
        16384,
        18446743519658770561,
        18446735273187346433,
        6519596376689022014,
        9996039020351967275,
        15452408553935940313,
        15855629130643256449,
        8619522106083987867,
        13036116919365988132,
        1033106119984023956,
        16593078884869787648,
        16980581328500004402,
        12245796497946355434,
        8709441440702798460,
        8611358103550827629,
        8120528636261052110,
    ],
};

/// Parameters for GF(p) with a 96-bit prime modulus.
pub(crate) const FP96: FieldParameters = FieldParameters {
    p: 79228148845226978974766202881, // 96-bit prime
    mu: 18446744073709551615,
    r2: 69162923446439011319006025217,
    g: 11329412859948499305522312170,
    num_roots: 64,
    bit_mask: 79228162514264337593543950335,
    roots: [
        10128756682736510015896859,
        79218020088544242464750306022,
        9188608122889034248261485869,
        10170869429050723924726258983,
        36379376833245035199462139324,
        20898601228930800484072244511,
        2845758484723985721473442509,
        71302585629145191158180162028,
        76552499132904394167108068662,
        48651998692455360626769616967,
        36570983454832589044179852640,
        72716740645782532591407744342,
        73296872548531908678227377531,
        14831293153408122430659535205,
        61540280632476003580389854060,
        42256269782069635955059793151,
        51673352890110285959979141934,
        43102967204983216507957944322,
        3990455111079735553382399289,
        68042997008257313116433801954,
        44344622755749285146379045633,
    ],
};

/// Parameters for GF(p) with a 128-bit prime modulus.
pub(crate) const FP128: FieldParameters = FieldParameters {
    p: 340282366920938462946865773367900766209, // 128-bit prime
    mu: 18446744073709551615,
    r2: 403909908237944342183153,
    g: 107630958476043550189608038630704257141,
    num_roots: 66,
    bit_mask: 340282366920938463463374607431768211455,
    roots: [
        516508834063867445247,
        340282366920938462430356939304033320962,
        129526470195413442198896969089616959958,
        169031622068548287099117778531474117974,
        81612939378432101163303892927894236156,
        122401220764524715189382260548353967708,
        199453575871863981432000940507837456190,
        272368408887745135168960576051472383806,
        24863773656265022616993900367764287617,
        257882853788779266319541142124730662203,
        323732363244658673145040701829006542956,
        57532865270871759635014308631881743007,
        149571414409418047452773959687184934208,
        177018931070866797456844925926211239962,
        268896136799800963964749917185333891349,
        244556960591856046954834420512544511831,
        118945432085812380213390062516065622346,
        202007153998709986841225284843501908420,
        332677126194796691532164818746739771387,
        258279638927684931537542082169183965856,
        148221243758794364405224645520862378432,
    ],
};

// Compute the ceiling of the base-2 logarithm of `x`.
/// Returns the ceiling of the base-2 logarithm of `x`, i.e., the smallest `l` for which
/// `x <= 2^l`.
///
/// # Panics
///
/// Panics if `x` is zero, for which the logarithm is undefined. Without this check the
/// subtraction below would overflow: a debug build would panic with an unrelated arithmetic
/// message and a release build would silently return a meaningless value.
pub(crate) fn log2(x: u128) -> u128 {
    assert!(x > 0, "log2 of zero is undefined");
    // `y` is the floor of the logarithm: the index of the most significant set bit.
    let y = (127 - x.leading_zeros()) as u128;
    // Round up unless `x` is exactly `2^y`.
    y + ((x > 1 << y) as u128)
}

#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::ToBigInt;

    #[test]
    fn test_log2() {
        assert_eq!(log2(1), 0);
        assert_eq!(log2(2), 1);
        assert_eq!(log2(3), 2);
        assert_eq!(log2(4), 2);
        assert_eq!(log2(15), 4);
        assert_eq!(log2(16), 4);
        assert_eq!(log2(30), 5);
        assert_eq!(log2(32), 5);
        assert_eq!(log2(1 << 127), 127);
        assert_eq!(log2((1 << 127) + 13), 128);
    }

    /// Zero has no logarithm; `log2` must reject it explicitly in every build profile.
    #[test]
    #[should_panic(expected = "log2 of zero is undefined")]
    fn test_log2_zero() {
        log2(0);
    }

    struct TestFieldParametersData {
        fp: FieldParameters,  // The parameters being tested
        expected_p: u128,     // Expected fp.p
        expected_g: u128,     // Expected fp.residue(fp.g)
        expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1
    }

    #[test]
    fn test_fp() {
        let test_fps = vec![
            TestFieldParametersData {
                fp: FP32,
                expected_p: 4293918721,
                expected_g: 3925978153,
                expected_order: 1 << 20,
            },
            TestFieldParametersData {
                fp: FP64,
                expected_p: 18446744069414584321,
                expected_g: 1753635133440165772,
                expected_order: 1 << 32,
            },
            TestFieldParametersData {
                fp: FP96,
                expected_p: 79228148845226978974766202881,
                expected_g: 34233996298771126927060021012,
                expected_order: 1 << 64,
            },
            TestFieldParametersData {
                fp: FP128,
                expected_p: 340282366920938462946865773367900766209,
                expected_g: 145091266659756586618791329697897684742,
                expected_order: 1 << 66,
            },
        ];

        for t in test_fps {
            // Check that the field parameters have been constructed properly.
            t.fp.check(t.expected_p, t.expected_g, t.expected_order);

            // Check that the generator has the correct order.
            assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1);
            assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1);

            // Test arithmetic using the field parameters.
            arithmetic_test(&t.fp);
        }
    }

    /// Compares the field operations against big-integer arithmetic on 100 random pairs of
    /// elements.
    fn arithmetic_test(fp: &FieldParameters) {
        let mut rng = rand::thread_rng();
        let big_p = &fp.p.to_bigint().unwrap();

        for _ in 0..100 {
            let x = fp.rand_elem(&mut rng);
            let y = fp.rand_elem(&mut rng);
            let big_x = &fp.residue(x).to_bigint().unwrap();
            let big_y = &fp.residue(y).to_bigint().unwrap();

            // Test addition.
            let got = fp.add(x, y);
            let want = (big_x + big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test subtraction.
            let got = fp.sub(x, y);
            let want = if big_x >= big_y {
                big_x - big_y
            } else {
                big_p - big_y + big_x
            };
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test multiplication.
            let got = fp.mul(x, y);
            let want = (big_x * big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test inversion.
            let got = fp.inv(x);
            let want = big_x.modpow(&(big_p - 2u128), big_p);
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.mul(got, x)), 1);

            // Test negation.
            let got = fp.neg(x);
            let want = (big_p - big_x) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.add(got, x)), 0);
        }
    }
}