// SPDX-License-Identifier: MPL-2.0

//! Finite field arithmetic for any field GF(p) for which p < 2^128.

#[cfg(test)]
use rand::{prelude::*, Rng};

/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots
/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This
/// is the largest input size we would ever need for the cryptographic applications in this crate.
pub(crate) const MAX_ROOTS: usize = 20;

/// This structure represents the parameters of a finite field GF(p) for which p < 2^128.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct FieldParameters {
    /// The prime modulus `p`.
    pub p: u128,
    /// `mu = -p^(-1) mod 2^64`.
    pub mu: u64,
    /// `r2 = (2^128)^2 mod p`.
    pub r2: u128,
    /// The `2^num_roots`-th principal root of unity. This element is used to generate the
    /// elements of `roots`.
    pub g: u128,
    /// The number of principal roots of unity in `roots`.
    pub num_roots: usize,
    /// Equal to `2^b - 1`, where `b` is the length of `p` in bits.
    pub bit_mask: u128,
    /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the
    /// multiplicative group. `roots[0]` is equal to one by definition.
    pub roots: [u128; MAX_ROOTS + 1],
}

impl FieldParameters {
    /// Addition. The result will be in [0, p), so long as both x and y are as well.
    #[inline(always)]
    pub fn add(&self, x: u128, y: u128) -> u128 {
        //   0,x
        // + 0,y
        // =====
        //   c,z
        let (z, carry) = x.overflowing_add(y);
        //     c, z
        // -   0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = z.overflowing_sub(self.p);
        let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        let m = 0u128.wrapping_sub(b1 as u128);
        (z & m) | (s0 & !m)
    }

    /// Subtraction. The result will be in [0, p), so long as both x and y are as well.
    #[inline(always)]
    pub fn sub(&self, x: u128, y: u128) -> u128 {
        //     0, x
        // -   0, y
        // ========
        // b1,z1,z0
        let (z0, b0) = x.overflowing_sub(y);
        let (_z1, b1) = 0u128.overflowing_sub(b0 as u128);
        // `m` is all ones when the subtraction borrowed and zero otherwise.
        let m = 0u128.wrapping_sub(b1 as u128);
        //   z1,z0
        // +  0, p
        // ========
        //   s1,s0
        z0.wrapping_add(m & self.p)
        // if b1 == 1: return s0
        // else:       return z0
    }

    /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm
    /// described
    /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf).
    /// The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46);
    /// ```
    #[inline(always)]
    pub fn mul(&self, x: u128, y: u128) -> u128 {
        // Operands are split into 64-bit limbs, each stored in a `u128` so that limb products
        // and limb sums (with carries) do not overflow.
        let x = [lo64(x), hi64(x)];
        let y = [lo64(y), hi64(y)];
        let p = [lo64(self.p), hi64(self.p)];
        let mut zz = [0; 4];

        // Integer multiplication
        // z = x * y

        //       x1,x0
        // *     y1,y0
        // ===========
        // z3,z2,z1,z0
        let mut result = x[0] * y[0];
        let mut carry = hi64(result);
        zz[0] = lo64(result);
        result = x[0] * y[1];
        let mut hi = hi64(result);
        let mut lo = lo64(result);
        result = lo + carry;
        zz[1] = lo64(result);
        let mut cc = hi64(result);
        result = hi + cc;
        zz[2] = lo64(result);

        result = x[1] * y[0];
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = x[1] * y[1];
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        zz[3] = lo64(result);

        // Montgomery Reduction
        // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64.

        // z3,z2,z1,z0
        // +     p1,p0
        // *         w = mu*z0
        // ===========
        // z3,z2,z1, 0
        let w = self.mu.wrapping_mul(zz[0] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[0] + lo;
        zz[0] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = zz[2] + hi + cc;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + cc;
        zz[3] = lo64(result);

        //    z3,z2,z1
        // +     p1,p0
        // *         w = mu*z1
        // ===========
        //    z3,z2, 0
        let w = self.mu.wrapping_mul(zz[1] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + hi + cc;
        zz[3] = lo64(result);
        cc = hi64(result);

        // z = (z3,z2)
        let prod = zz[2] | (zz[3] << 64);

        // Final subtraction
        // If z >= p, then z = z - p

        //     0, z
        // -   0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = prod.overflowing_sub(self.p);
        let (_s1, b1) = cc.overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        let mask = 0u128.wrapping_sub(b1 as u128);
        (prod & mask) | (s0 & !mask)
    }

    /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the
    /// runtime of this algorithm is linear in the bit length of `exp`.
    ///
    /// The accumulator starts at `montgomery(1)` and is combined with `x` using [`Self::mul`], so
    /// `x` is taken, and the result is returned, in the Montgomery domain.
    pub fn pow(&self, x: u128, exp: u128) -> u128 {
        let mut t = self.montgomery(1);
        // Square-and-multiply, scanning the significant bits of `exp` from most to least
        // significant.
        for i in (0..128 - exp.leading_zeros()).rev() {
            t = self.mul(t, t);
            if (exp >> i) & 1 != 0 {
                t = self.mul(t, x);
            }
        }
        t
    }

    /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of
    /// this algorithm is linear in the bit length of `p`.
    #[inline(always)]
    pub fn inv(&self, x: u128) -> u128 {
        self.pow(x, self.p - 2)
    }

    /// Negation, i.e., `-x (mod p)` where `p` is the modulus.
    #[inline(always)]
    pub fn neg(&self, x: u128) -> u128 {
        self.sub(0, x)
    }

    /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery
    /// domain in order to carry out field arithmetic. The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// let integer = 1; // Standard integer representation
    /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain
    /// assert_eq!(elem, 2564090464);
    /// ```
    #[inline(always)]
    pub fn montgomery(&self, x: u128) -> u128 {
        modp(self.mul(x, self.r2), self.p)
    }

    /// Returns a random field element mapped to the Montgomery domain. The underlying integer is
    /// sampled from the uniform distribution over `0..p` using `rng`.
    #[cfg(test)]
    pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 {
        let uniform = rand::distributions::Uniform::from(0..self.p);
        self.montgomery(uniform.sample(rng))
    }

    /// Maps a field element to its representation as an integer. The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// let elem = 2564090464; // Internal representation in the Montgomery domain
    /// let integer = fp.residue(elem); // Standard integer representation
    /// assert_eq!(integer, 1);
    /// ```
    #[inline(always)]
    pub fn residue(&self, x: u128) -> u128 {
        modp(self.mul(x, 1), self.p)
    }

    /// Test-only consistency check of the pre-computed parameters.
    ///
    /// Recomputes `mu`, `r2`, `num_roots`, `roots` and `bit_mask` from the modulus `p`, the
    /// generator `g` (given in standard integer representation) and the generator's claimed
    /// multiplicative `order`, and compares them against the stored fields.
    ///
    /// # Panics
    /// Panics with a descriptive message on the first mismatch, if `order` is not a power of two,
    /// or if `-p` has no inverse modulo 2^64.
    #[cfg(test)]
    pub fn check(&self, p: u128, g: u128, order: u128) {
        use modinverse::modinverse;
        use num_bigint::{BigInt, ToBigInt};
        use std::cmp::max;

        assert_eq!(self.p, p, "p mismatch");

        // mu = (-p)^(-1) mod 2^64
        let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) {
            Some(mu) => mu as u64,
            None => panic!("inverse of -p (mod 2^64) is undefined"),
        };
        assert_eq!(self.mu, mu, "mu mismatch");

        // r2 = (2^128)^2 mod p, reassembled from (at most) two 64-bit digits.
        let big_p = &p.to_bigint().unwrap();
        let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p);
        let big_r2: &BigInt = &(&(big_r * big_r) % big_p);
        let mut it = big_r2.iter_u64_digits();
        let mut r2 = 0;
        r2 |= it.next().unwrap() as u128;
        if let Some(x) = it.next() {
            r2 |= (x as u128) << 64;
        }
        assert_eq!(self.r2, r2, "r2 mismatch");

        assert_eq!(self.g, self.montgomery(g), "g mismatch");
        assert_eq!(
            self.residue(self.pow(self.g, order)),
            1,
            "g order incorrect"
        );

        let num_roots = log2(order) as usize;
        assert_eq!(order, 1 << num_roots, "order not a power of 2");
        assert_eq!(self.num_roots, num_roots, "num_roots mismatch");

        // Starting from the generator, repeated squaring yields the roots of order
        // 2^(num_roots-1), ..., 2, 1. Only the first MAX_ROOTS + 1 entries are stored in `roots`.
        let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1];
        roots[num_roots] = self.montgomery(g);
        for i in (0..num_roots).rev() {
            roots[i] = self.mul(roots[i + 1], roots[i + 1]);
        }
        assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch");
        assert_eq!(self.residue(self.roots[0]), 1, "first root is not one");

        let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1);
        assert_eq!(
            self.bit_mask.to_bigint().unwrap(),
            bit_mask,
            "bit_mask mismatch"
        );
    }
}

/// Returns the low 64 bits of `x`.
#[inline(always)]
fn lo64(x: u128) -> u128 {
    x & ((1 << 64) - 1)
}

/// Returns the high 64 bits of `x`, shifted down into the low half.
#[inline(always)]
fn hi64(x: u128) -> u128 {
    x >> 64
}

/// Conditional subtraction of `p`: returns `x - p` if `x >= p` and `x` otherwise. The selection is
/// done with a mask derived from the borrow rather than with a branch.
#[inline(always)]
fn modp(x: u128, p: u128) -> u128 {
    let (z, carry) = x.overflowing_sub(p);
    // `m` is all ones if the subtraction borrowed (x < p), in which case `p` is added back.
    let m = 0u128.wrapping_sub(carry as u128);
    z.wrapping_add(m & p)
}

/// Field parameters for the 32-bit prime modulus 4293918721.
pub(crate) const FP32: FieldParameters = FieldParameters {
    p: 4293918721, // 32-bit prime
    mu: 17302828673139736575,
    r2: 1676699750,
    g: 1074114499,
    num_roots: 20,
    bit_mask: 4294967295,
    roots: [
        2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825,
        2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415,
        3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499,
    ],
};

/// Field parameters for the 64-bit prime modulus 18446744069414584321.
pub(crate) const FP64: FieldParameters = FieldParameters {
    p: 18446744069414584321, // 64-bit prime
    mu: 18446744069414584319,
    r2: 4294967295,
    g: 959634606461954525,
    num_roots: 32,
    bit_mask: 18446744073709551615,
    roots: [
        18446744065119617025,
        4294967296,
        18446462594437939201,
        72057594037927936,
        1152921504338411520,
        16384,
        18446743519658770561,
        18446735273187346433,
        6519596376689022014,
        9996039020351967275,
        15452408553935940313,
        15855629130643256449,
        8619522106083987867,
        13036116919365988132,
        1033106119984023956,
        16593078884869787648,
        16980581328500004402,
        12245796497946355434,
        8709441440702798460,
        8611358103550827629,
        8120528636261052110,
    ],
};

/// Field parameters for the 128-bit prime modulus 340282366920938462946865773367900766209.
pub(crate) const FP128: FieldParameters = FieldParameters {
    p: 340282366920938462946865773367900766209, // 128-bit prime
    mu: 18446744073709551615,
    r2: 403909908237944342183153,
    g: 107630958476043550189608038630704257141,
    num_roots: 66,
    bit_mask: 340282366920938463463374607431768211455,
    roots: [
        516508834063867445247,
        340282366920938462430356939304033320962,
        129526470195413442198896969089616959958,
        169031622068548287099117778531474117974,
        81612939378432101163303892927894236156,
        122401220764524715189382260548353967708,
        199453575871863981432000940507837456190,
        272368408887745135168960576051472383806,
        24863773656265022616993900367764287617,
        257882853788779266319541142124730662203,
        323732363244658673145040701829006542956,
        57532865270871759635014308631881743007,
        149571414409418047452773959687184934208,
        177018931070866797456844925926211239962,
        268896136799800963964749917185333891349,
        244556960591856046954834420512544511831,
        118945432085812380213390062516065622346,
        202007153998709986841225284843501908420,
        332677126194796691532164818746739771387,
        258279638927684931537542082169183965856,
        148221243758794364405224645520862378432,
    ],
};

/// Compute the ceiling of the base-2 logarithm of `x`.
///
/// `x` must be nonzero: for `x == 0`, `leading_zeros()` is 128 and `127 - 128` underflows.
pub(crate) fn log2(x: u128) -> u128 {
    // `y` is the floor of log2(x); round up when `x` is not an exact power of two.
    let y = (127 - x.leading_zeros()) as u128;
    y + ((x > 1 << y) as u128)
}

#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::ToBigInt;

    #[test]
    fn test_log2() {
        assert_eq!(log2(1), 0);
        assert_eq!(log2(2), 1);
        assert_eq!(log2(3), 2);
        assert_eq!(log2(4), 2);
        assert_eq!(log2(15), 4);
        assert_eq!(log2(16), 4);
        assert_eq!(log2(30), 5);
        assert_eq!(log2(32), 5);
        assert_eq!(log2(1 << 127), 127);
        assert_eq!(log2((1 << 127) + 13), 128);
    }

    struct TestFieldParametersData {
        fp: FieldParameters,  // The parameters being tested
        expected_p: u128,     // Expected fp.p
        expected_g: u128,     // Expected fp.residue(fp.g)
        expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1
    }

    #[test]
    fn test_fp() {
        let test_fps = vec![
            TestFieldParametersData {
                fp: FP32,
                expected_p: 4293918721,
                expected_g: 3925978153,
                expected_order: 1 << 20,
            },
            TestFieldParametersData {
                fp: FP64,
                expected_p: 18446744069414584321,
                expected_g: 1753635133440165772,
                expected_order: 1 << 32,
            },
            TestFieldParametersData {
                fp: FP128,
                expected_p: 340282366920938462946865773367900766209,
                expected_g: 145091266659756586618791329697897684742,
                expected_order: 1 << 66,
            },
        ];

        for t in test_fps.into_iter() {
            // Check that the field parameters have been constructed properly.
            t.fp.check(t.expected_p, t.expected_g, t.expected_order);

            // Check that the generator has the correct order.
            assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1);
            assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1);

            // Test arithmetic using the field parameters.
            arithmetic_test(&t.fp);
        }
    }

    fn arithmetic_test(fp: &FieldParameters) {
        let mut rng = rand::thread_rng();
        let big_p = &fp.p.to_bigint().unwrap();

        for _ in 0..100 {
            let x = fp.rand_elem(&mut rng);
            let y = fp.rand_elem(&mut rng);
            let big_x = &fp.residue(x).to_bigint().unwrap();
            let big_y = &fp.residue(y).to_bigint().unwrap();

            // Test addition.
            let got = fp.add(x, y);
            let want = (big_x + big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test subtraction.
            let got = fp.sub(x, y);
            let want = if big_x >= big_y {
                big_x - big_y
            } else {
                big_p - big_y + big_x
            };
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test multiplication.
            let got = fp.mul(x, y);
            let want = (big_x * big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test inversion.
            let got = fp.inv(x);
            let want = big_x.modpow(&(big_p - 2u128), big_p);
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.mul(got, x)), 1);

            // Test negation.
            let got = fp.neg(x);
            let want = (big_p - big_x) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.add(got, x)), 0);
        }
    }
}