diff options
Diffstat (limited to 'third_party/rust/prio/src')
26 files changed, 11069 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/benchmarked.rs b/third_party/rust/prio/src/benchmarked.rs new file mode 100644 index 0000000000..8811250f9a --- /dev/null +++ b/third_party/rust/prio/src/benchmarked.rs @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module provides wrappers around internal components of this crate that we want to +//! benchmark, but which we don't want to expose in the public API. + +#[cfg(feature = "prio2")] +use crate::client::Client; +use crate::fft::discrete_fourier_transform; +use crate::field::FieldElement; +use crate::flp::gadgets::Mul; +use crate::flp::FlpError; +use crate::polynomial::{poly_fft, PolyAuxMemory}; + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. +pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) { + discrete_fourier_transform(outp, inp, inp.len()).unwrap(); +} + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. +pub fn benchmarked_recursive_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) { + let mut mem = PolyAuxMemory::new(inp.len() / 2); + poly_fft( + outp, + inp, + &mem.roots_2n, + inp.len(), + false, + &mut mem.fft_memory, + ) +} + +/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function +/// uses FFT for multiplication. +pub fn benchmarked_gadget_mul_call_poly_fft<F: FieldElement>( + g: &mut Mul<F>, + outp: &mut [F], + inp: &[Vec<F>], +) -> Result<(), FlpError> { + g.call_poly_fft(outp, inp) +} + +/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function +/// does the multiplication directly. +pub fn benchmarked_gadget_mul_call_poly_direct<F: FieldElement>( + g: &mut Mul<F>, + outp: &mut [F], + inp: &[Vec<F>], +) -> Result<(), FlpError> { + g.call_poly_direct(outp, inp) +} + +/// Returns a Prio v2 proof that `data` is a valid boolean vector. +#[cfg(feature = "prio2")] +pub fn benchmarked_v2_prove<F: FieldElement>(data: &[F], client: &mut Client<F>) -> Vec<F> { + client.gen_proof(data) +} diff --git a/third_party/rust/prio/src/client.rs b/third_party/rust/prio/src/client.rs new file mode 100644 index 0000000000..3ed5ee66c7 --- /dev/null +++ b/third_party/rust/prio/src/client.rs @@ -0,0 +1,264 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! The Prio v2 client. Only 0 / 1 vectors are supported for now. + +use crate::{ + encrypt::{encrypt_share, EncryptError, PublicKey}, + field::FieldElement, + polynomial::{poly_fft, PolyAuxMemory}, + prng::{Prng, PrngError}, + util::{proof_length, unpack_proof_mut}, + vdaf::{prg::SeedStreamAes128, VdafError}, +}; + +use std::convert::TryFrom; + +/// The main object that can be used to create Prio shares +/// +/// Client is used to create Prio shares. +#[derive(Debug)] +pub struct Client<F: FieldElement> { + dimension: usize, + mem: ClientMemory<F>, + public_key1: PublicKey, + public_key2: PublicKey, +} + +/// Errors that might be emitted by the client. +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + /// Encryption/decryption error + #[error("encryption/decryption error")] + Encrypt(#[from] EncryptError), + /// PRNG error + #[error("prng error: {0}")] + Prng(#[from] PrngError), + /// VDAF error + #[error("vdaf error: {0}")] + Vdaf(#[from] VdafError), +} + +impl<F: FieldElement> Client<F> { + /// Construct a new Prio client + pub fn new( + dimension: usize, + public_key1: PublicKey, + public_key2: PublicKey, + ) -> Result<Self, ClientError> { + Ok(Client { + dimension, + mem: ClientMemory::new(dimension)?, + public_key1, + public_key2, + }) + } + + /// Construct a pair of encrypted shares based on the input data. + pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec<u8>, Vec<u8>), ClientError> { + let copy_data = |share_data: &mut [F]| { + share_data[..].clone_from_slice(data); + }; + Ok(self.encode_with(copy_data)?) + } + + /// Construct a pair of encrypted shares using a initilization function. + /// + /// This might be slightly more efficient on large vectors, because one can + /// avoid copying the input data. + pub fn encode_with<G>(&mut self, init_function: G) -> Result<(Vec<u8>, Vec<u8>), EncryptError> + where + G: FnOnce(&mut [F]), + { + let mut proof = self.mem.prove_with(self.dimension, init_function); + + // use prng to share the proof: share2 is the PRNG seed, and proof is mutated + // in-place + let mut share2 = [0; 32]; + getrandom::getrandom(&mut share2)?; + let share2_prng = Prng::from_prio2_seed(&share2); + for (s1, d) in proof.iter_mut().zip(share2_prng.into_iter()) { + *s1 -= d; + } + let share1 = F::slice_into_byte_vec(&proof); + // encrypt shares with respective keys + let encrypted_share1 = encrypt_share(&share1, &self.public_key1)?; + let encrypted_share2 = encrypt_share(&share2, &self.public_key2)?; + Ok((encrypted_share1, encrypted_share2)) + } + + /// Generate a proof of the input's validity. The output is the encoded input and proof. + pub(crate) fn gen_proof(&mut self, input: &[F]) -> Vec<F> { + let copy_data = |share_data: &mut [F]| { + share_data[..].clone_from_slice(input); + }; + self.mem.prove_with(self.dimension, copy_data) + } +} + +#[derive(Debug)] +pub(crate) struct ClientMemory<F> { + prng: Prng<F, SeedStreamAes128>, + points_f: Vec<F>, + points_g: Vec<F>, + evals_f: Vec<F>, + evals_g: Vec<F>, + poly_mem: PolyAuxMemory<F>, +} + +impl<F: FieldElement> ClientMemory<F> { + pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> { + let n = (dimension + 1).next_power_of_two(); + if let Ok(size) = F::Integer::try_from(2 * n) { + if size > F::generator_order() { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + } else { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + + Ok(Self { + prng: Prng::new()?, + points_f: vec![F::zero(); n], + points_g: vec![F::zero(); n], + evals_f: vec![F::zero(); 2 * n], + evals_g: vec![F::zero(); 2 * n], + poly_mem: PolyAuxMemory::new(n), + }) + } +} + +impl<F: FieldElement> ClientMemory<F> { + pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F> + where + G: FnOnce(&mut [F]), + { + let mut proof = vec![F::zero(); proof_length(dimension)]; + // unpack one long vector to different subparts + let unpacked = unpack_proof_mut(&mut proof, dimension).unwrap(); + // initialize the data part + init_function(unpacked.data); + // fill in the rest + construct_proof( + unpacked.data, + dimension, + unpacked.f0, + unpacked.g0, + unpacked.h0, + unpacked.points_h_packed, + self, + ); + + proof + } +} + +/// Convenience function if one does not want to reuse +/// [`Client`](struct.Client.html). +pub fn encode_simple<F: FieldElement>( + data: &[F], + public_key1: PublicKey, + public_key2: PublicKey, +) -> Result<(Vec<u8>, Vec<u8>), ClientError> { + let dimension = data.len(); + let mut client_memory = Client::new(dimension, public_key1, public_key2)?; + client_memory.encode_simple(data) +} + +fn interpolate_and_evaluate_at_2n<F: FieldElement>( + n: usize, + points_in: &[F], + evals_out: &mut [F], + mem: &mut PolyAuxMemory<F>, +) { + // interpolate through roots of unity + poly_fft( + &mut mem.coeffs, + points_in, + &mem.roots_n_inverted, + n, + true, + &mut mem.fft_memory, + ); + // evaluate at 2N roots of unity + poly_fft( + evals_out, + &mem.coeffs, + &mem.roots_2n, + 2 * n, + false, + &mut mem.fft_memory, + ); +} + +/// Proof construction +/// +/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation +/// This constructs the output \pi by doing the necessesary calculations +fn construct_proof<F: FieldElement>( + data: &[F], + dimension: usize, + f0: &mut F, + g0: &mut F, + h0: &mut F, + points_h_packed: &mut [F], + mem: &mut ClientMemory<F>, +) { + let n = (dimension + 1).next_power_of_two(); + + // set zero terms to random + *f0 = mem.prng.get(); + *g0 = mem.prng.get(); + mem.points_f[0] = *f0; + mem.points_g[0] = *g0; + + // set zero term for the proof polynomial + *h0 = *f0 * *g0; + + // set f_i = data_(i - 1) + // set g_i = f_i - 1 + #[allow(clippy::needless_range_loop)] + for i in 0..dimension { + mem.points_f[i + 1] = data[i]; + mem.points_g[i + 1] = data[i] - F::one(); + } + + // interpolate and evaluate at roots of unity + interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem); + interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem); + + // calculate the proof polynomial as evals_f(r) * evals_g(r) + // only add non-zero points + let mut j: usize = 0; + let mut i: usize = 1; + while i < 2 * n { + points_h_packed[j] = mem.evals_f[i] * mem.evals_g[i]; + j += 1; + i += 2; + } +} + +#[test] +fn test_encode() { + use crate::field::Field32; + let pub_key1 = PublicKey::from_base64( + "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=", + ) + .unwrap(); + let pub_key2 = PublicKey::from_base64( + "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=", + ) + .unwrap(); + + let data_u32 = [0u32, 1, 0, 1, 1, 0, 0, 0, 1]; + let data = data_u32 + .iter() + .map(|x| Field32::from(*x)) + .collect::<Vec<Field32>>(); + let encoded_shares = encode_simple(&data, pub_key1, pub_key2); + assert!(encoded_shares.is_ok()); +} diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs new file mode 100644 index 0000000000..c3ee8e9db7 --- /dev/null +++ b/third_party/rust/prio/src/codec.rs @@ -0,0 +1,658 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Module `codec` provides support for encoding and decoding messages to or from the TLS wire +//! encoding, as specified in [RFC 8446, Section 3][1]. It provides traits that can be implemented +//! on values that need to be encoded or decoded, as well as utility functions for encoding +//! sequences of values. +//! +//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3 + +use byteorder::{BigEndian, ReadBytesExt}; +use std::{ + convert::TryInto, + error::Error, + io::{Cursor, Read}, + mem::size_of, + num::TryFromIntError, +}; + +#[allow(missing_docs)] +#[derive(Debug, thiserror::Error)] +pub enum CodecError { + #[error("I/O error")] + Io(#[from] std::io::Error), + #[error("{0} bytes left in buffer after decoding value")] + BytesLeftOver(usize), + #[error("length prefix of encoded vector overflows buffer: {0}")] + LengthPrefixTooBig(usize), + #[error("other error: {0}")] + Other(#[source] Box<dyn Error + 'static + Send + Sync>), + #[error("unexpected value")] + UnexpectedValue, +} + +/// Describes how to decode an object from a byte sequence. +pub trait Decode: Sized { + /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned + /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned + /// and no further attempt to read from `bytes` should be made. + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>; + + /// Convenience method to get decoded value. Returns an error if [`Self::decode`] fails, or if + /// there are any bytes left in `bytes` after decoding a value. + fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> { + Self::get_decoded_with_param(&(), bytes) + } +} + +/// Describes how to decode an object from a byte sequence, with a decoding parameter provided to +/// provide additional data used in decoding. +pub trait ParameterizedDecode<P>: Sized { + /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the + /// wire encoding such as lengths of different portions of the message. On success, the decoded + /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an + /// error is returned and no further attempt to read from `bytes` should be made. + fn decode_with_param( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError>; + + /// Convenience method to get decoded value. Returns an error if [`Self::decode_with_param`] + /// fails, or if there are any bytes left in `bytes` after decoding a value. + fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> { + let mut cursor = Cursor::new(bytes); + let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?; + if cursor.position() as usize != bytes.len() { + return Err(CodecError::BytesLeftOver( + bytes.len() - cursor.position() as usize, + )); + } + + Ok(decoded) + } +} + +// Provide a blanket implementation so that any Decode can be used as a ParameterizedDecode<T> for +// any T. +impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D { + fn decode_with_param( + _decoding_parameter: &T, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Self::decode(bytes) + } +} + +/// Describes how to encode objects into a byte sequence. +pub trait Encode { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + fn encode(&self, bytes: &mut Vec<u8>); + + /// Convenience method to get encoded value. + fn get_encoded(&self) -> Vec<u8> { + self.get_encoded_with_param(&()) + } +} + +/// Describes how to encode objects into a byte sequence. +pub trait ParameterizedEncode<P> { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + /// `encoding_parameter` provides details of the wire encoding, used to control how the value + /// is encoded. + fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>); + + /// Convenience method to get encoded value. + fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> { + let mut ret = Vec::new(); + self.encode_with_param(encoding_parameter, &mut ret); + ret + } +} + +// Provide a blanket implementation so that any Encode can be used as a ParameterizedEncode<T> for +// any T. +impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E { + fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) { + self.encode(bytes) + } +} + +impl Decode for () { + fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(()) + } +} + +impl Encode for () { + fn encode(&self, _bytes: &mut Vec<u8>) {} +} + +impl Decode for u8 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut value = [0u8; size_of::<u8>()]; + bytes.read_exact(&mut value)?; + Ok(value[0]) + } +} + +impl Encode for u8 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.push(*self); + } +} + +impl Decode for u16 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u16::<BigEndian>()?) + } +} + +impl Encode for u16 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u16::to_be_bytes(*self)); + } +} + +/// 24 bit integer, per +/// [RFC 8443, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct U24(pub u32); + +impl Decode for U24 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(U24(bytes.read_u24::<BigEndian>()?)) + } +} + +impl Encode for U24 { + fn encode(&self, bytes: &mut Vec<u8>) { + // Encode lower three bytes of the u32 as u24 + bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]); + } +} + +impl Decode for u32 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u32::<BigEndian>()?) + } +} + +impl Encode for u32 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u32::to_be_bytes(*self)); + } +} + +impl Decode for u64 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u64::<BigEndian>()?) + } +} + +impl Encode for u64 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u64::to_be_bytes(*self)); + } +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u8_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + bytes.push(0); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 1; + assert!(len <= u8::MAX.into()); + bytes[len_offset] = len as u8; +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u8_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read one byte to get length of opaque byte vector + let length = usize::from(u8::decode(bytes)?); + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u16_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + 0u16.encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 2; + assert!(len <= u16::MAX.into()); + for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u16_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read two bytes to get length of opaque byte vector + let length = usize::from(u16::decode(bytes)?); + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of +/// `0xffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u24_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + U24(0).encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 3; + assert!(len <= 0xffffff); + for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u24_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read three bytes to get length of opaque byte vector + let length = U24::decode(bytes)?.0 as usize; + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of +/// `0xffffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u32_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + 0u32.encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 4; + let len: u32 = len.try_into().expect("Length too large"); + for (offset, byte) in len.to_be_bytes().iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u32_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read four bytes to get length of opaque byte vector. + let len: usize = u32::decode(bytes)? + .try_into() + .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?; + + decode_items(len, decoding_parameter, bytes) +} + +/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible. +fn decode_items<P, D: ParameterizedDecode<P>>( + length: usize, + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + let mut decoded = Vec::new(); + let initial_position = bytes.position() as usize; + + // Create cursor over specified portion of provided cursor to ensure we can't read past length. + let inner = bytes.get_ref(); + + // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer. + let (items_end, overflowed) = initial_position.overflowing_add(length); + if overflowed || items_end > inner.len() { + return Err(CodecError::LengthPrefixTooBig(length)); + } + + let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]); + + while sub.position() < length as u64 { + decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?); + } + + // Advance outer cursor by the amount read in the inner cursor + bytes.set_position(initial_position as u64 + sub.position()); + + Ok(decoded) +} + +#[cfg(test)] +mod tests { + + use super::*; + use assert_matches::assert_matches; + + #[test] + fn encode_nothing() { + let mut bytes = vec![]; + ().encode(&mut bytes); + assert_eq!(bytes.len(), 0); + } + + #[test] + fn roundtrip_u8() { + let value = 100u8; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 1); + + let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u16() { + let value = 1000u16; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 2); + // Check endianness of encoding + assert_eq!(bytes, vec![3, 232]); + + let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u24() { + let value = U24(1_000_000u32); + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 3); + // Check endianness of encoding + assert_eq!(bytes, vec![15, 66, 64]); + + let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u32() { + let value = 134_217_728u32; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 4); + // Check endianness of encoding + assert_eq!(bytes, vec![8, 0, 0, 0]); + + let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u64() { + let value = 137_438_953_472u64; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 8); + // Check endianness of encoding + assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]); + + let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[derive(Debug, Eq, PartialEq)] + struct TestMessage { + field_u8: u8, + field_u16: u16, + field_u24: U24, + field_u32: u32, + field_u64: u64, + } + + impl Encode for TestMessage { + fn encode(&self, bytes: &mut Vec<u8>) { + self.field_u8.encode(bytes); + self.field_u16.encode(bytes); + self.field_u24.encode(bytes); + self.field_u32.encode(bytes); + self.field_u64.encode(bytes); + } + } + + impl Decode for TestMessage { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let field_u8 = u8::decode(bytes)?; + let field_u16 = u16::decode(bytes)?; + let field_u24 = U24::decode(bytes)?; + let field_u32 = u32::decode(bytes)?; + let field_u64 = u64::decode(bytes)?; + + Ok(TestMessage { + field_u8, + field_u16, + field_u24, + field_u32, + field_u64, + }) + } + } + + impl TestMessage { + fn encoded_length() -> usize { + // u8 field + 1 + + // u16 field + 2 + + // u24 field + 3 + + // u32 field + 4 + + // u64 field + 8 + } + } + + #[test] + fn roundtrip_message() { + let value = TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), TestMessage::encoded_length()); + + let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + fn messages_vec() -> Vec<TestMessage> { + vec![ + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + ] + } + + #[test] + fn roundtrip_variable_length_u8() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u8_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 1 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u16() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u16_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 2 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + // Check endianness of encoded length + assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]); + + let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u24() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u24_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 3 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + // Check endianness of encoded length + assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]); + + let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u32() { + let values = messages_vec(); + let mut bytes = Vec::new(); + encode_u32_items(&mut bytes, &(), &values); + + assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length()); + + // Check endianness of encoded length. + assert_eq!( + bytes[0..4], + [0, 0, 0, 3 * TestMessage::encoded_length() as u8] + ); + + let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn decode_items_overflow() { + let encoded = vec![1u8]; + + let mut cursor = Cursor::new(encoded.as_slice()); + cursor.set_position(1); + + assert_matches!( + decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(), + CodecError::LengthPrefixTooBig(usize::MAX) + ); + } + + #[test] + fn decode_items_too_big() { + let encoded = vec![1u8]; + + let mut cursor = Cursor::new(encoded.as_slice()); + cursor.set_position(1); + + assert_matches!( + decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(), + CodecError::LengthPrefixTooBig(2) + ); + } +} diff --git a/third_party/rust/prio/src/encrypt.rs b/third_party/rust/prio/src/encrypt.rs new file mode 100644 index 0000000000..9429c0f2ce --- /dev/null +++ b/third_party/rust/prio/src/encrypt.rs @@ -0,0 +1,232 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Utilities for ECIES encryption / decryption used by the Prio client and server. + +use crate::prng::PrngError; + +use aes_gcm::aead::generic_array::typenum::U16; +use aes_gcm::aead::generic_array::GenericArray; +use aes_gcm::{AeadInPlace, NewAead}; +use ring::agreement; +type Aes128 = aes_gcm::AesGcm<aes_gcm::aes::Aes128, U16>; + +/// Length of the EC public key (X9.62 format) +pub const PUBLICKEY_LENGTH: usize = 65; +/// Length of the AES-GCM tag +pub const TAG_LENGTH: usize = 16; +/// Length of the symmetric AES-GCM key +const KEY_LENGTH: usize = 16; + +/// Possible errors from encryption / decryption. +#[derive(Debug, thiserror::Error)] +pub enum EncryptError { + /// Base64 decoding error + #[error("base64 decoding error")] + DecodeBase64(#[from] base64::DecodeError), + /// Error in ECDH + #[error("error in ECDH")] + KeyAgreement, + /// Buffer for ciphertext was not large enough + #[error("buffer for ciphertext was not large enough")] + Encryption, + /// Authentication tags did not match. + #[error("authentication tags did not match")] + Decryption, + /// Input ciphertext was too small + #[error("input ciphertext was too small")] + DecryptionLength, + /// PRNG error + #[error("prng error: {0}")] + Prng(#[from] PrngError), + /// failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), +} + +/// NIST P-256, public key in X9.62 uncompressed format +#[derive(Debug, Clone)] +pub struct PublicKey(Vec<u8>); + +/// NIST P-256, private key +/// +/// X9.62 uncompressed public key concatenated with the secret scalar. +#[derive(Debug, Clone)] +pub struct PrivateKey(Vec<u8>); + +impl PublicKey { + /// Load public key from a base64 encoded X9.62 uncompressed representation. + pub fn from_base64(key: &str) -> Result<Self, EncryptError> { + let keydata = base64::decode(key)?; + Ok(PublicKey(keydata)) + } +} + +/// Copy public key from a private key +impl std::convert::From<&PrivateKey> for PublicKey { + fn from(pk: &PrivateKey) -> Self { + PublicKey(pk.0[..PUBLICKEY_LENGTH].to_owned()) + } +} + +impl PrivateKey { + /// Load private key from a base64 encoded string. + pub fn from_base64(key: &str) -> Result<Self, EncryptError> { + let keydata = base64::decode(key)?; + Ok(PrivateKey(keydata)) + } +} + +/// Encrypt a bytestring using the public key +/// +/// This uses ECIES with X9.63 key derivation function and AES-GCM for the +/// symmetic encryption and MAC. +pub fn encrypt_share(share: &[u8], key: &PublicKey) -> Result<Vec<u8>, EncryptError> { + let rng = ring::rand::SystemRandom::new(); + let ephemeral_priv = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &rng) + .map_err(|_| EncryptError::KeyAgreement)?; + let peer_public = agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, &key.0); + let ephemeral_pub = ephemeral_priv + .compute_public_key() + .map_err(|_| EncryptError::KeyAgreement)?; + + let symmetric_key_bytes = agreement::agree_ephemeral( + ephemeral_priv, + &peer_public, + EncryptError::KeyAgreement, + |material| Ok(x963_kdf(material, ephemeral_pub.as_ref())), + )?; + + let in_out = share.to_owned(); + let encrypted = encrypt_aes_gcm( + &symmetric_key_bytes[..16], + &symmetric_key_bytes[16..], + in_out, + )?; + + let mut output = Vec::with_capacity(encrypted.len() + ephemeral_pub.as_ref().len()); + output.extend_from_slice(ephemeral_pub.as_ref()); + output.extend_from_slice(&encrypted); + + Ok(output) +} + +/// Decrypt a bytestring using the private key +/// +/// This uses ECIES with X9.63 key derivation function and AES-GCM for the +/// symmetic encryption and MAC. +pub fn decrypt_share(share: &[u8], key: &PrivateKey) -> Result<Vec<u8>, EncryptError> { + if share.len() < PUBLICKEY_LENGTH + TAG_LENGTH { + return Err(EncryptError::DecryptionLength); + } + let empheral_pub_bytes: &[u8] = &share[0..PUBLICKEY_LENGTH]; + + let ephemeral_pub = + agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, empheral_pub_bytes); + + let fake_rng = ring::test::rand::FixedSliceRandom { + // private key consists of the public key + private scalar + bytes: &key.0[PUBLICKEY_LENGTH..], + }; + + let private_key = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &fake_rng) + .map_err(|_| EncryptError::KeyAgreement)?; + + let symmetric_key_bytes = agreement::agree_ephemeral( + private_key, + &ephemeral_pub, + EncryptError::KeyAgreement, + |material| Ok(x963_kdf(material, empheral_pub_bytes)), + )?; + + // in_out is the AES-GCM ciphertext+tag, wihtout the ephemeral EC pubkey + let in_out = share[PUBLICKEY_LENGTH..].to_owned(); + decrypt_aes_gcm( + &symmetric_key_bytes[..KEY_LENGTH], + &symmetric_key_bytes[KEY_LENGTH..], + in_out, + ) +} + +fn x963_kdf(z: &[u8], shared_info: &[u8]) -> [u8; 32] { + let mut hasher = ring::digest::Context::new(&ring::digest::SHA256); + hasher.update(z); + hasher.update(&1u32.to_be_bytes()); + hasher.update(shared_info); + let digest = hasher.finish(); + use std::convert::TryInto; + // unwrap never fails because SHA256 output len is 32 + digest.as_ref().try_into().unwrap() +} + +fn decrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> { + let cipher = Aes128::new(GenericArray::from_slice(key)); + cipher + .decrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data) + .map_err(|_| EncryptError::Decryption)?; + Ok(data) +} + +fn encrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> { + let cipher = Aes128::new(GenericArray::from_slice(key)); + cipher + .encrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data) + .map_err(|_| EncryptError::Encryption)?; + Ok(data) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encrypt_decrypt() -> Result<(), EncryptError> { + let pub_key = PublicKey::from_base64( + "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9\ + HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=", + )?; + let priv_key = PrivateKey::from_base64( + "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgN\ + t9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==", + )?; + let data = (0..100).map(|_| rand::random::<u8>()).collect::<Vec<u8>>(); + let encrypted = encrypt_share(&data, &pub_key)?; + + let decrypted = decrypt_share(&encrypted, &priv_key)?; + assert_eq!(decrypted, data); + Ok(()) + } + + #[test] + fn test_interop() { + let share1 = base64::decode("Kbnd2ZWrsfLfcpuxHffMrJ1b7sCrAsNqlb6Y1eAMfwCVUNXt").unwrap(); + let share2 = base64::decode("hu+vT3+8/taHP7B/dWXh/g==").unwrap(); + let encrypted_share1 = base64::decode( + "BEWObg41JiMJglSEA6Ebk37xOeflD2a1t2eiLmX0OPccJhAER5NmOI+4r4Cfm7aJn141sGKnTbCuIB9+AeVuw\ + MAQnzjsGPu5aNgkdpp+6VowAcVAV1DlzZvtwlQkCFlX4f3xmafTPFTPOokYi2a+H1n8GKwd", + ) + .unwrap(); + let encrypted_share2 = base64::decode( + "BNRzQ6TbqSc4pk0S8aziVRNjWm4DXQR5yCYTK2w22iSw4XAPW4OB9RxBpWVa1C/3ywVBT/3yLArOMXEsCEMOG\ + 1+d2CiEvtuU1zADH2MVaCnXL/dVXkDchYZsvPWPkDcjQA==", + ) + .unwrap(); + + let priv_key1 = PrivateKey::from_base64( + "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOg\ + Nt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==", + ) + .unwrap(); + let priv_key2 = PrivateKey::from_base64( + "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhF\ + LMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==", + ) + .unwrap(); + + let decrypted1 = decrypt_share(&encrypted_share1, &priv_key1).unwrap(); + let decrypted2 = decrypt_share(&encrypted_share2, &priv_key2).unwrap(); + + assert_eq!(decrypted1, share1); + assert_eq!(decrypted2, share2); + } +} diff --git a/third_party/rust/prio/src/fft.rs b/third_party/rust/prio/src/fft.rs new file mode 100644 index 0000000000..c7a1dfbb8b --- /dev/null +++ b/third_party/rust/prio/src/fft.rs @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier +//! Transform (DFT) over a slice of field elements. + +use crate::field::FieldElement; +use crate::fp::{log2, MAX_ROOTS}; + +use std::convert::TryFrom; + +/// An error returned by an FFT operation. +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum FftError { + /// The output is too small. + #[error("output slice is smaller than specified size")] + OutputTooSmall, + /// The specified size is too large. + #[error("size is larger than than maximum permitted")] + SizeTooLarge, + /// The specified size is not a power of 2. + #[error("size is not a power of 2")] + SizeInvalid, +} + +/// Sets `outp` to the DFT of `inp`. +/// +/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input +/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of +/// unity. +#[allow(clippy::many_single_char_names)] +pub fn discrete_fourier_transform<F: FieldElement>( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), FftError> { + let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?; + + if size > outp.len() { + return Err(FftError::OutputTooSmall); + } + + if size > 1 << MAX_ROOTS { + return Err(FftError::SizeTooLarge); + } + + if size != 1 << d { + return Err(FftError::SizeInvalid); + } + + #[allow(clippy::needless_range_loop)] + for i in 0..size { + let j = bitrev(d, i); + outp[i] = if j < inp.len() { inp[j] } else { F::zero() } + } + + let mut w: F; + for l in 1..d + 1 { + w = F::one(); + let r = F::root(l).unwrap(); + let y = 1 << (l - 1); + for i in 0..y { + for j in 0..(size / y) >> 1 { + let x = (1 << l) * j + i; + let u = outp[x]; + let v = w * outp[x + y]; + outp[x] = u + v; + outp[x + y] = u - v; + } + w *= r; + } + } + + Ok(()) +} + +/// Sets `outp` to the inverse of the DFT of `inp`. +#[cfg(test)] +pub(crate) fn discrete_fourier_transform_inv<F: FieldElement>( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), FftError> { + let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv(); + discrete_fourier_transform(outp, inp, size)?; + discrete_fourier_transform_inv_finish(outp, size, size_inv); + Ok(()) +} + +/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to +/// amortize the cost the modular inverse across multiple inverse DFT operations. +pub(crate) fn discrete_fourier_transform_inv_finish<F: FieldElement>( + outp: &mut [F], + size: usize, + size_inv: F, +) { + let mut tmp: F; + outp[0] *= size_inv; + outp[size >> 1] *= size_inv; + for i in 1..size >> 1 { + tmp = outp[i] * size_inv; + outp[i] = outp[size - i] * size_inv; + outp[size - i] = tmp; + } +} + +// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109) +fn bitrev(d: usize, x: usize) -> usize { + let mut y = 0; + for i in 0..d { + y += ((x >> i) & 1) << (d - i); + } + y >> 1 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{ + random_vector, split_vector, Field128, Field32, Field64, Field96, FieldPrio2, + }; + use crate::polynomial::{poly_fft, PolyAuxMemory}; + + fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> { + let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; + + for size in test_sizes.iter() { + let mut tmp = vec![F::zero(); *size]; + let mut got = vec![F::zero(); *size]; + let want = random_vector(*size).unwrap(); + + discrete_fourier_transform(&mut tmp, &want, want.len())?; + discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?; + assert_eq!(got, want); + } + + Ok(()) + } + + #[test] + fn test_field32() { + discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error"); + } + + #[test] + fn test_priov2_field32() { + discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error"); + } + + #[test] + fn test_field64() { + discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error"); + } + + #[test] + fn test_field96() { + discrete_fourier_transform_then_inv_test::<Field96>().expect("unexpected error"); + } + + #[test] + fn test_field128() { + discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error"); + } + + #[test] + fn test_recursive_fft() { + let size = 128; + let mut mem = PolyAuxMemory::new(size / 2); + + let inp = random_vector(size).unwrap(); + let mut want = vec![Field32::zero(); size]; + let mut got = vec![Field32::zero(); size]; + + discrete_fourier_transform::<Field32>(&mut want, &inp, inp.len()).unwrap(); + + poly_fft( + &mut got, + &inp, + &mem.roots_2n, + size, + false, + &mut mem.fft_memory, + ); + + assert_eq!(got, want); + } + + // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial + // over secret shares and summing up the coefficients is equivalent to interpolating a + // polynomial over the plaintext data. + #[test] + fn test_fft_linearity() { + let len = 16; + let num_shares = 3; + let x: Vec<Field64> = random_vector(len).unwrap(); + let mut x_shares = split_vector(&x, num_shares).unwrap(); + + // Just for fun, let's do something different with a subset of the inputs. For the first + // share, every odd element is set to the plaintext value. For all shares but the first, + // every odd element is set to 0. + #[allow(clippy::needless_range_loop)] + for i in 0..len { + if i % 2 != 0 { + x_shares[0][i] = x[i]; + } + for j in 1..num_shares { + if i % 2 != 0 { + x_shares[j][i] = Field64::zero(); + } + } + } + + let mut got = vec![Field64::zero(); len]; + let mut buf = vec![Field64::zero(); len]; + for share in x_shares { + discrete_fourier_transform_inv(&mut buf, &share, len).unwrap(); + for i in 0..len { + got[i] += buf[i]; + } + } + + let mut want = vec![Field64::zero(); len]; + discrete_fourier_transform_inv(&mut want, &x, len).unwrap(); + + assert_eq!(got, want); + } +} diff --git a/third_party/rust/prio/src/field.rs b/third_party/rust/prio/src/field.rs new file mode 100644 index 0000000000..cbfe40826e --- /dev/null +++ b/third_party/rust/prio/src/field.rs @@ -0,0 +1,960 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic. +//! +//! Each field has an associated parameter called the "generator" that generates a multiplicative +//! subgroup of order `2^n` for some `n`. + +#[cfg(feature = "crypto-dependencies")] +use crate::prng::{Prng, PrngError}; +use crate::{ + codec::{CodecError, Decode, Encode}, + fp::{FP128, FP32, FP64, FP96}, +}; +use serde::{ + de::{DeserializeOwned, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::{ + cmp::min, + convert::{TryFrom, TryInto}, + fmt::{self, Debug, Display, Formatter}, + hash::{Hash, Hasher}, + io::{Cursor, Read}, + marker::PhantomData, + ops::{Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub, SubAssign}, +}; + +/// Possible errors from finite field operations. +#[derive(Debug, thiserror::Error)] +pub enum FieldError { + /// Input sizes do not match. + #[error("input sizes do not match")] + InputSizeMismatch, + /// Returned when decoding a `FieldElement` from a short byte string. + #[error("short read from bytes")] + ShortRead, + /// Returned when decoding a `FieldElement` from a byte string encoding an integer larger than + /// or equal to the field modulus. + #[error("read from byte slice exceeds modulus")] + ModulusOverflow, + /// Error while performing I/O. + #[error("I/O error")] + Io(#[from] std::io::Error), + /// Error encoding or decoding a field. + #[error("Codec error")] + Codec(#[from] CodecError), + /// Error converting to `FieldElement::Integer`. + #[error("Integer TryFrom error")] + IntegerTryFrom, + /// Error converting `FieldElement::Integer` into something else. + #[error("Integer TryInto error")] + IntegerTryInto, +} + +/// Byte order for encoding FieldElement values into byte sequences. +#[derive(Clone, Copy, Debug)] +enum ByteOrder { + /// Big endian byte order. + BigEndian, + /// Little endian byte order. + LittleEndian, +} + +/// Objects with this trait represent an element of `GF(p)` for some prime `p`. +pub trait FieldElement: + Sized + + Debug + + Copy + + PartialEq + + Eq + + Add<Output = Self> + + AddAssign + + Sub<Output = Self> + + SubAssign + + Mul<Output = Self> + + MulAssign + + Div<Output = Self> + + DivAssign + + Neg<Output = Self> + + Display + + From<<Self as FieldElement>::Integer> + + for<'a> TryFrom<&'a [u8], Error = FieldError> + // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`, + // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that + // isn't possible yet[1]. However we can provide the impl on FieldElement implementations. + // [1]: https://github.com/rust-lang/rust/issues/60551 + + Into<Vec<u8>> + + Serialize + + DeserializeOwned + + Encode + + Decode + + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type. +{ + /// Size in bytes of the encoding of a value. + const ENCODED_SIZE: usize; + + /// The error returned if converting `usize` to an `Integer` fails. + type IntegerTryFromError: std::error::Error; + + /// The error returend if converting an `Integer` to a `u64` fails. + type TryIntoU64Error: std::error::Error; + + /// The integer representation of the field element. + type Integer: Copy + + Debug + + Eq + + Ord + + BitAnd<Output = <Self as FieldElement>::Integer> + + Div<Output = <Self as FieldElement>::Integer> + + Shl<Output = <Self as FieldElement>::Integer> + + Shr<Output = <Self as FieldElement>::Integer> + + Add<Output = <Self as FieldElement>::Integer> + + Sub<Output = <Self as FieldElement>::Integer> + + From<Self> + + TryFrom<usize, Error = Self::IntegerTryFromError> + + TryInto<u64, Error = Self::TryIntoU64Error>; + + /// Modular exponentation, i.e., `self^exp (mod p)`. + fn pow(&self, exp: Self::Integer) -> Self; + + /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined. + fn inv(&self) -> Self; + + /// Returns the prime modulus `p`. + fn modulus() -> Self::Integer; + + /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the + /// field. The `m` most significant bits are cleared, where `m` is equal to the length of + /// [`Self::Integer`] in bits minus the length of the modulus in bits. + /// + /// # Errors + /// + /// An error is returned if the provided slice is too small to encode a field element or if the + /// result encodes an integer larger than or equal to the field modulus. + /// + /// # Warnings + /// + /// This function should only be used within [`prng::Prng`] to convert a random byte string into + /// a field element. Use [`Self::decode`] to deserialize field elements. Use + /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements. + #[doc(hidden)] + fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>; + + /// Returns the size of the multiplicative subgroup generated by `generator()`. + fn generator_order() -> Self::Integer; + + /// Returns the generator of the multiplicative subgroup of size `generator_order()`. + fn generator() -> Self; + + /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th + /// prinicpal root of unity is 1 by definition. + fn root(l: usize) -> Option<Self>; + + /// Returns the additive identity. + fn zero() -> Self; + + /// Returns the multiplicative identity. + fn one() -> Self; + + /// Convert a slice of field elements into a vector of bytes. + /// + /// # Notes + /// + /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding + /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this + /// impossible. + fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> { + let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE); + for elem in values { + vec.append(&mut (*elem).into()); + } + vec + } + + /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a + /// sequence of [`Self::ENCODED_SIZE`]-byte sequences. + /// + /// # Errors + /// + /// Returns an error if the length of the provided byte slice is not a multiple of the size of a + /// field element, or if any of the values in the byte slice are invalid encodings of a field + /// element, because the encoded integer is larger than or equal to the field modulus. + /// + /// # Notes + /// + /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding + /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this + /// impossible. + fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> { + if bytes.len() % Self::ENCODED_SIZE != 0 { + return Err(FieldError::ShortRead); + } + let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE); + for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) { + vec.push(Self::get_decoded(chunk)?); + } + Ok(vec) + } +} + +/// Methods common to all `FieldElement` implementations that are private to the crate. +pub(crate) trait FieldElementExt: FieldElement { + /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough + /// to be represented with that many bits. + /// + /// # Arguments + /// + /// * `input` - The field element to encode + /// * `bits` - The number of bits to use for the encoding + fn encode_into_bitvector_representation( + input: &Self::Integer, + bits: usize, + ) -> Result<Vec<Self>, FieldError> { + // Create a mutable copy of `input`. In each iteration of the following loop we take the + // least significant bit, and shift input to the right by one bit. + let mut i = *input; + + let one = Self::Integer::from(Self::one()); + let mut encoded = Vec::with_capacity(bits); + for _ in 0..bits { + let w = Self::from(i & one); + encoded.push(w); + i = i >> one; + } + + // If `i` is still not zero, this means that it cannot be encoded by `bits` bits. + if i != Self::Integer::from(Self::zero()) { + return Err(FieldError::InputSizeMismatch); + } + + Ok(encoded) + } + + /// Decode the bitvector-represented value `input` into a simple representation as a single + /// field element. + /// + /// # Errors + /// + /// This function errors if `2^input.len() - 1` does not fit into the field `Self`. + fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> { + if !Self::valid_integer_bitlength(input.len()) { + return Err(FieldError::ModulusOverflow); + } + + let mut decoded = Self::zero(); + for (l, bit) in input.iter().enumerate() { + let w = Self::Integer::try_from(1 << l).map_err(|_| FieldError::IntegerTryFrom)?; + decoded += Self::from(w) * *bit; + } + Ok(decoded) + } + + /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the + /// field modulus. + fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError> + where + Self::Integer: TryFrom<N>, + { + let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?; + if Self::modulus() <= i_int { + return Err(FieldError::ModulusOverflow); + } + Ok(i_int) + } + + /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is + /// representable in this field. + fn valid_integer_bitlength(bits: usize) -> bool { + if let Ok(bits_int) = Self::Integer::try_from(bits) { + if Self::modulus() >> bits_int != Self::Integer::from(Self::zero()) { + return true; + } + } + false + } +} + +impl<F: FieldElement> FieldElementExt for F {} + +/// serde Visitor implementation used to generically deserialize `FieldElement` +/// values from byte arrays. +struct FieldElementVisitor<F: FieldElement> { + phantom: PhantomData<F>, +} + +impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> { + type Value = F; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE)) + } + + fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Self::Value::try_from(v).map_err(E::custom) + } + + fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let mut bytes = vec![]; + while let Some(byte) = seq.next_element()? { + bytes.push(byte); + } + + self.visit_bytes(&bytes) + } +} + +macro_rules! make_field { + ( + $(#[$meta:meta])* + $elem:ident, $int:ident, $fp:ident, $encoding_size:literal, $encoding_order:expr, + ) => { + $(#[$meta])* + /// + /// This structure represents a field element in a prime order field. The concrete + /// representation of the element is via the Montgomery domain. For an element n in GF(p), + /// we store n * R^-1 mod p (where R is a given power of two). This representation enables + /// using a more efficient (and branchless) multiplication algorithm, at the expense of + /// having to convert elements between their Montgomery domain representation and natural + /// representation. For calculations with many multiplications or exponentiations, this is + /// worthwhile. + /// + /// As an invariant, this integer representing the field element in the Montgomery domain + /// must be less than the prime p. + #[derive(Clone, Copy, PartialOrd, Ord, Default)] + pub struct $elem(u128); + + impl $elem { + /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the + /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing + /// it using the field modulus. + /// + /// # Errors + /// + /// An error is returned if the provided slice is not long enough to encode a field + /// element or if the decoded value is greater than the field prime. + /// + /// # Notes + /// + /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions + /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is + /// more compact, and does not have to correspond to the size of an integer type. For + /// instance,`Field96`'s encoding is 12 bytes, even though it is a 16 byte `u128` in + /// memory. + fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> { + if Self::ENCODED_SIZE > bytes.len() { + return Err(FieldError::ShortRead); + } + + let mut int = 0; + for i in 0..Self::ENCODED_SIZE { + let j = match $encoding_order { + ByteOrder::LittleEndian => i, + ByteOrder::BigEndian => Self::ENCODED_SIZE - i - 1, + }; + + int |= (bytes[j] as u128) << (i << 3); + } + + int &= mask; + + if int >= $fp.p { + return Err(FieldError::ModulusOverflow); + } + // FieldParameters::montgomery() will return a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Ok(Self($fp.montgomery(int))) + } + } + + impl PartialEq for $elem { + fn eq(&self, rhs: &Self) -> bool { + // The fields included in this comparison MUST match the fields + // used in Hash::hash + // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq + + // Check the invariant that the integer representation is fully reduced. + debug_assert!(self.0 < $fp.p); + debug_assert!(rhs.0 < $fp.p); + + self.0 == rhs.0 + } + } + + impl Hash for $elem { + fn hash<H: Hasher>(&self, state: &mut H) { + // The fields included in this hash MUST match the fields used + // in PartialEq::eq + // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq + + // Check the invariant that the integer representation is fully reduced. + debug_assert!(self.0 < $fp.p); + + self.0.hash(state); + } + } + + impl Eq for $elem {} + + impl Add for $elem { + type Output = $elem; + fn add(self, rhs: Self) -> Self { + // FieldParameters::add() returns a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Self($fp.add(self.0, rhs.0)) + } + } + + impl Add for &$elem { + type Output = $elem; + fn add(self, rhs: Self) -> $elem { + *self + *rhs + } + } + + impl AddAssign for $elem { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + + impl Sub for $elem { + type Output = $elem; + fn sub(self, rhs: Self) -> Self { + // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub() + // returns a value less than p, satisfying the invariant on Self. + Self($fp.sub(self.0, rhs.0)) + } + } + + impl Sub for &$elem { + type Output = $elem; + fn sub(self, rhs: Self) -> $elem { + *self - *rhs + } + } + + impl SubAssign for $elem { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + + impl Mul for $elem { + type Output = $elem; + fn mul(self, rhs: Self) -> Self { + // FieldParameters::mul() always returns a value less than p, so the invariant on + // Self is satisfied. + Self($fp.mul(self.0, rhs.0)) + } + } + + impl Mul for &$elem { + type Output = $elem; + fn mul(self, rhs: Self) -> $elem { + *self * *rhs + } + } + + impl MulAssign for $elem { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + + impl Div for $elem { + type Output = $elem; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self { + self * rhs.inv() + } + } + + impl Div for &$elem { + type Output = $elem; + fn div(self, rhs: Self) -> $elem { + *self / *rhs + } + } + + impl DivAssign for $elem { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } + } + + impl Neg for $elem { + type Output = $elem; + fn neg(self) -> Self { + // FieldParameters::neg() will return a value less than p because self.0 is less + // than p, and neg() dispatches to sub(). + Self($fp.neg(self.0)) + } + } + + impl Neg for &$elem { + type Output = $elem; + fn neg(self) -> $elem { + -(*self) + } + } + + impl From<$int> for $elem { + fn from(x: $int) -> Self { + // FieldParameters::montgomery() will return a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Self($fp.montgomery(u128::try_from(x).unwrap())) + } + } + + impl From<$elem> for $int { + fn from(x: $elem) -> Self { + $int::try_from($fp.residue(x.0)).unwrap() + } + } + + impl PartialEq<$int> for $elem { + fn eq(&self, rhs: &$int) -> bool { + $fp.residue(self.0) == u128::try_from(*rhs).unwrap() + } + } + + impl<'a> TryFrom<&'a [u8]> for $elem { + type Error = FieldError; + + fn try_from(bytes: &[u8]) -> Result<Self, FieldError> { + Self::try_from_bytes(bytes, u128::MAX) + } + } + + impl From<$elem> for [u8; $elem::ENCODED_SIZE] { + fn from(elem: $elem) -> Self { + let int = $fp.residue(elem.0); + let mut slice = [0; $elem::ENCODED_SIZE]; + for i in 0..$elem::ENCODED_SIZE { + let j = match $encoding_order { + ByteOrder::LittleEndian => i, + ByteOrder::BigEndian => $elem::ENCODED_SIZE - i - 1, + }; + + slice[j] = ((int >> (i << 3)) & 0xff) as u8; + } + slice + } + } + + impl From<$elem> for Vec<u8> { + fn from(elem: $elem) -> Self { + <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec() + } + } + + impl Display for $elem { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{}", $fp.residue(self.0)) + } + } + + impl Debug for $elem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $fp.residue(self.0)) + } + } + + // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because + // the derived implementations would represent `FieldElement` values as the backing `u128`, + // which is not what we want because (1) we can be more efficient in all cases and (2) in + // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625). + impl Serialize for $elem { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into(); + serializer.serialize_bytes(&bytes) + } + } + + impl<'de> Deserialize<'de> for $elem { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> { + deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData }) + } + } + + impl Encode for $elem { + fn encode(&self, bytes: &mut Vec<u8>) { + let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self); + bytes.extend_from_slice(&slice); + } + } + + impl Decode for $elem { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut value = [0u8; $elem::ENCODED_SIZE]; + bytes.read_exact(&mut value)?; + $elem::try_from_bytes(&value, u128::MAX).map_err(|e| { + CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>) + }) + } + } + + impl FieldElement for $elem { + const ENCODED_SIZE: usize = $encoding_size; + type Integer = $int; + type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error; + type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error; + + fn pow(&self, exp: Self::Integer) -> Self { + // FieldParameters::pow() relies on mul(), and will always return a value less + // than p. + Self($fp.pow(self.0, u128::try_from(exp).unwrap())) + } + + fn inv(&self) -> Self { + // FieldParameters::inv() ultimately relies on mul(), and will always return a + // value less than p. + Self($fp.inv(self.0)) + } + + fn modulus() -> Self::Integer { + $fp.p as $int + } + + fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> { + $elem::try_from_bytes(bytes, $fp.bit_mask) + } + + fn generator() -> Self { + Self($fp.g) + } + + fn generator_order() -> Self::Integer { + 1 << (Self::Integer::try_from($fp.num_roots).unwrap()) + } + + fn root(l: usize) -> Option<Self> { + if l < min($fp.roots.len(), $fp.num_roots+1) { + Some(Self($fp.roots[l])) + } else { + None + } + } + + fn zero() -> Self { + Self(0) + } + + fn one() -> Self { + Self($fp.roots[0]) + } + } + }; +} + +make_field!( + /// `GF(4293918721)`, a 32-bit field. + Field32, + u32, + FP32, + 4, + ByteOrder::BigEndian, +); + +make_field!( + /// Same as Field32, but encoded in little endian for compatibility with Prio v2. + FieldPrio2, + u32, + FP32, + 4, + ByteOrder::LittleEndian, +); + +make_field!( + /// `GF(18446744069414584321)`, a 64-bit field. + Field64, + u64, + FP64, + 8, + ByteOrder::BigEndian, +); + +make_field!( + /// `GF(79228148845226978974766202881)`, a 96-bit field. + Field96, + u128, + FP96, + 12, + ByteOrder::BigEndian, +); + +make_field!( + /// `GF(340282366920938462946865773367900766209)`, a 128-bit field. + Field128, + u128, + FP128, + 16, + ByteOrder::BigEndian, +); + +/// Merge two vectors of fields by summing other_vector into accumulator. +/// +/// # Errors +/// +/// Fails if the two vectors do not have the same length. +#[cfg(any(test, feature = "prio2"))] +pub(crate) fn merge_vector<F: FieldElement>( + accumulator: &mut [F], + other_vector: &[F], +) -> Result<(), FieldError> { + if accumulator.len() != other_vector.len() { + return Err(FieldError::InputSizeMismatch); + } + for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) { + *a += *o; + } + + Ok(()) +} + +/// Outputs an additive secret sharing of the input. +#[cfg(feature = "crypto-dependencies")] +pub(crate) fn split_vector<F: FieldElement>( + inp: &[F], + num_shares: usize, +) -> Result<Vec<Vec<F>>, PrngError> { + if num_shares == 0 { + return Ok(vec![]); + } + + let mut outp = Vec::with_capacity(num_shares); + outp.push(inp.to_vec()); + + for _ in 1..num_shares { + let share: Vec<F> = random_vector(inp.len())?; + for (x, y) in outp[0].iter_mut().zip(&share) { + *x -= *y; + } + outp.push(share); + } + + Ok(outp) +} + +/// Generate a vector of uniform random field elements. +#[cfg(feature = "crypto-dependencies")] +pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> { + Ok(Prng::new()?.take(len).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fp::MAX_ROOTS; + use crate::prng::Prng; + use assert_matches::assert_matches; + use std::collections::hash_map::DefaultHasher; + + #[test] + fn test_endianness() { + let little_endian_encoded: [u8; FieldPrio2::ENCODED_SIZE] = + FieldPrio2(0x12_34_56_78).into(); + + let mut big_endian_encoded: [u8; Field32::ENCODED_SIZE] = Field32(0x12_34_56_78).into(); + big_endian_encoded.reverse(); + + assert_eq!(little_endian_encoded, big_endian_encoded); + } + + #[test] + fn test_accumulate() { + let mut lhs = vec![Field32(1); 10]; + let rhs = vec![Field32(2); 10]; + + merge_vector(&mut lhs, &rhs).unwrap(); + + lhs.iter().for_each(|f| assert_eq!(*f, Field32(3))); + rhs.iter().for_each(|f| assert_eq!(*f, Field32(2))); + + let wrong_len = vec![Field32::zero(); 9]; + let result = merge_vector(&mut lhs, &wrong_len); + assert_matches!(result, Err(FieldError::InputSizeMismatch)); + } + + fn hash_helper<H: Hash>(input: H) -> u64 { + let mut hasher = DefaultHasher::new(); + input.hash(&mut hasher); + hasher.finish() + } + + // Some of the checks in this function, like `assert_eq!(one - one, zero)` + // or `assert_eq!(two / two, one)` trip this clippy lint for tautological + // comparisons, but we have a legitimate need to verify these basics. We put + // the #[allow] on the whole function since "attributes on expressions are + // experimental" https://github.com/rust-lang/rust/issues/15701 + #[allow(clippy::eq_op)] + fn field_element_test<F: FieldElement + Hash>() { + let mut prng: Prng<F, _> = Prng::new().unwrap(); + let int_modulus = F::modulus(); + let int_one = F::Integer::try_from(1).unwrap(); + let zero = F::zero(); + let one = F::one(); + let two = F::from(F::Integer::try_from(2).unwrap()); + let four = F::from(F::Integer::try_from(4).unwrap()); + + // add + assert_eq!(F::from(int_modulus - int_one) + one, zero); + assert_eq!(one + one, two); + assert_eq!(two + F::from(int_modulus), two); + + // sub + assert_eq!(zero - one, F::from(int_modulus - int_one)); + assert_eq!(one - one, zero); + assert_eq!(two - F::from(int_modulus), two); + assert_eq!(one - F::from(int_modulus - int_one), two); + + // add + sub + for _ in 0..100 { + let f = prng.get(); + let g = prng.get(); + assert_eq!(f + g - f - g, zero); + assert_eq!(f + g - g, f); + assert_eq!(f + g - f, g); + } + + // mul + assert_eq!(two * two, four); + assert_eq!(two * one, two); + assert_eq!(two * zero, zero); + assert_eq!(one * F::from(int_modulus), zero); + + // div + assert_eq!(four / two, two); + assert_eq!(two / two, one); + assert_eq!(zero / two, zero); + assert_eq!(two / zero, zero); // Undefined behavior + assert_eq!(zero.inv(), zero); // Undefined behavior + + // mul + div + for _ in 0..100 { + let f = prng.get(); + if f == zero { + continue; + } + assert_eq!(f * f.inv(), one); + assert_eq!(f.inv() * f, one); + } + + // pow + assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one); + assert_eq!(two.pow(int_one), two); + assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four); + assert_eq!(two.pow(int_modulus - int_one), one); + assert_eq!(two.pow(int_modulus), two); + + // roots + let mut int_order = F::generator_order(); + for l in 0..MAX_ROOTS + 1 { + assert_eq!( + F::generator().pow(int_order), + F::root(l).unwrap(), + "failure for F::root({})", + l + ); + int_order = int_order >> int_one; + } + + // serialization + let test_inputs = vec![zero, one, prng.get(), F::from(int_modulus - int_one)]; + for want in test_inputs.iter() { + let mut bytes = vec![]; + want.encode(&mut bytes); + + assert_eq!(bytes.len(), F::ENCODED_SIZE); + + let got = F::get_decoded(&bytes).unwrap(); + assert_eq!(got, *want); + } + + let serialized_vec = F::slice_into_byte_vec(&test_inputs); + let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap(); + assert_eq!(deserialized, test_inputs); + + // equality and hash: Generate many elements, confirm they are not equal, and confirm + // various products that should be equal have the same hash. Three is chosen as a generator + // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four + // primes. + let three = F::from(F::Integer::try_from(3).unwrap()); + let mut powers_of_three = Vec::with_capacity(500); + let mut power = one; + for _ in 0..500 { + powers_of_three.push(power); + power *= three; + } + // Check all these elements are mutually not equal. + for i in 0..powers_of_three.len() { + let first = &powers_of_three[i]; + for second in &powers_of_three[0..i] { + assert_ne!(first, second); + } + } + + // Check that 3^i is the same whether it's calculated with pow() or repeated + // multiplication, with both equality and hash equality. + for (i, power) in powers_of_three.iter().enumerate() { + let result = three.pow(F::Integer::try_from(i).unwrap()); + assert_eq!(result, *power); + let hash1 = hash_helper(power); + let hash2 = hash_helper(result); + assert_eq!(hash1, hash2); + } + + // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality. + let expected_product = powers_of_three[powers_of_three.len() - 1]; + let expected_hash = hash_helper(expected_product); + for i in 0..powers_of_three.len() { + let a = powers_of_three[i]; + let b = powers_of_three[powers_of_three.len() - 1 - i]; + let product = a * b; + assert_eq!(product, expected_product); + assert_eq!(hash_helper(product), expected_hash); + } + + // Construct an element from a number that needs to be reduced, and test comparisons on it, + // confirming that FieldParameters::montgomery() reduced it correctly. + let p = F::from(int_modulus); + assert_eq!(p, zero); + assert_eq!(hash_helper(p), hash_helper(zero)); + let p_plus_one = F::from(int_modulus + F::Integer::try_from(1).unwrap()); + assert_eq!(p_plus_one, one); + assert_eq!(hash_helper(p_plus_one), hash_helper(one)); + } + + #[test] + fn test_field32() { + field_element_test::<Field32>(); + } + + #[test] + fn test_field_priov2() { + field_element_test::<FieldPrio2>(); + } + + #[test] + fn test_field64() { + field_element_test::<Field64>(); + } + + #[test] + fn test_field96() { + field_element_test::<Field96>(); + } + + #[test] + fn test_field128() { + field_element_test::<Field128>(); + } +} diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs new file mode 100644 index 0000000000..7f37347ca3 --- /dev/null +++ b/third_party/rust/prio/src/flp.rs @@ -0,0 +1,1035 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the generic Fully Linear Proof (FLP) system specified in +//! [[draft-irtf-cfrg-vdaf-03]]. This is the main building block of [`Prio3`](crate::vdaf::prio3). +//! +//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation +//! specifies a validity circuit that defines the set of valid measurements, as well as the finite +//! field in which the validity circuit is evaluated. It also determines how raw measurements are +//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of +//! measurements. +//! +//! # Overview +//! +//! The proof system is comprised of three algorithms. The first, `prove`, is run by the prover in +//! order to generate a proof of a statement's validity. The second and third, `query` and +//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input +//! is an element of a language recognized by the arithmetic circuit. If an input is _not_ valid, +//! then the verification step will fail with high probability: +//! +//! ``` +//! use prio::flp::types::Count; +//! use prio::flp::Type; +//! use prio::field::{random_vector, FieldElement, Field64}; +//! +//! // The prover chooses a measurement. +//! let count = Count::new(); +//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap(); +//! +//! // The prover and verifier agree on "joint randomness" used to generate and +//! // check the proof. The application needs to ensure that the prover +//! // "commits" to the input before this point. In Prio3, the joint +//! // randomness is derived from additive shares of the input. +//! let joint_rand = random_vector(count.joint_rand_len()).unwrap(); +//! +//! // The prover generates the proof. +//! let prove_rand = random_vector(count.prove_rand_len()).unwrap(); +//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap(); +//! +//! // The verifier checks the proof. In the first step, the verifier "queries" +//! // the input and proof, getting the "verifier message" in response. It then +//! // inspects the verifier to decide if the input is valid. +//! let query_rand = random_vector(count.query_rand_len()).unwrap(); +//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap(); +//! assert!(count.decide(&verifier).unwrap()); +//! ``` +//! +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError}; +use crate::field::{FieldElement, FieldError}; +use crate::fp::log2; +use crate::polynomial::poly_eval; +use std::any::Any; +use std::convert::TryFrom; +use std::fmt::Debug; + +pub mod gadgets; +pub mod types; + +/// Errors propagated by methods in this module. +#[derive(Debug, thiserror::Error)] +pub enum FlpError { + /// Calling [`Type::prove`] returned an error. + #[error("prove error: {0}")] + Prove(String), + + /// Calling [`Type::query`] returned an error. + #[error("query error: {0}")] + Query(String), + + /// Calling [`Type::decide`] returned an error. + #[error("decide error: {0}")] + Decide(String), + + /// Calling a gadget returned an error. + #[error("gadget error: {0}")] + Gadget(String), + + /// Calling the validity circuit returned an error. + #[error("validity circuit error: {0}")] + Valid(String), + + /// Calling [`Type::encode_measurement`] returned an error. + #[error("value error: {0}")] + Encode(String), + + /// Calling [`Type::decode_result`] returned an error. + #[error("value error: {0}")] + Decode(String), + + /// Calling [`Type::truncate`] returned an error. + #[error("truncate error: {0}")] + Truncate(String), + + /// Returned if an FFT operation propagates an error. + #[error("FFT error: {0}")] + Fft(#[from] FftError), + + /// Returned if a field operation encountered an error. + #[error("Field error: {0}")] + Field(#[from] FieldError), + + /// Unit test error. + #[cfg(test)] + #[error("test failed: {0}")] + Test(String), +} + +/// A type. Implementations of this trait specify how a particular kind of measurement is encoded +/// as a vector of field elements and how validity of the encoded measurement is determined. +/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement. +pub trait Type: Sized + Eq + Clone + Debug { + /// The Prio3 VDAF identifier corresponding to this type. + const ID: u32; + + /// The type of raw measurement to be encoded. + type Measurement: Clone + Debug; + + /// The type of aggregate result for this type. + type AggregateResult: Clone + Debug; + + /// The finite field used for this type. + type Field: FieldElement; + + /// Encodes a measurement as a vector of [`Self::input_len`] field elements. + fn encode_measurement( + &self, + measurement: &Self::Measurement, + ) -> Result<Vec<Self::Field>, FlpError>; + + /// Decode an aggregate result. + fn decode_result( + &self, + data: &[Self::Field], + num_measurements: usize, + ) -> Result<Self::AggregateResult, FlpError>; + + /// Returns the sequence of gadgets associated with the validity circuit. + /// + /// # Notes + /// + /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many. The + /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in + /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here + /// requires security analysis. + /// + /// [BBCG+19]: https://ia.cr/2019/188 + fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>; + + /// Evaluates the validity circuit on an input and returns the output. + /// + /// # Parameters + /// + /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`]. + /// * `input` is the input to be validated. + /// * `joint_rand` is the joint randomness shared by the prover and verifier. + /// * `num_shares` is the number of input shares. + /// + /// # Example usage + /// + /// Applications typically do not call this method directly. It is used internally by + /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively. + /// + /// ``` + /// use prio::flp::types::Count; + /// use prio::flp::Type; + /// use prio::field::{random_vector, FieldElement, Field64}; + /// + /// let count = Count::new(); + /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap(); + /// let joint_rand = random_vector(count.joint_rand_len()).unwrap(); + /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap(); + /// assert_eq!(v, Field64::zero()); + /// ``` + fn valid( + &self, + gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>, + input: &[Self::Field], + joint_rand: &[Self::Field], + num_shares: usize, + ) -> Result<Self::Field, FlpError>; + + /// Constructs an aggregatable output from an encoded input. Calling this method is only safe + /// once `input` has been validated. + fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>; + + /// The length in field elements of the encoded input returned by [`Self::encode_measurement`]. + fn input_len(&self) -> usize; + + /// The length in field elements of the proof generated for this type. + fn proof_len(&self) -> usize; + + /// The length in field elements of the verifier message constructed by [`Self::query`]. + fn verifier_len(&self) -> usize; + + /// The length of the truncated output (i.e., the output of [`Type::truncate`]). + fn output_len(&self) -> usize; + + /// The length of the joint random input. + fn joint_rand_len(&self) -> usize; + + /// The length in field elements of the random input consumed by the prover to generate a + /// proof. This is the same as the sum of the arity of each gadget in the validity circuit. + fn prove_rand_len(&self) -> usize; + + /// The length in field elements of the random input consumed by the verifier to make queries + /// against inputs and proofs. This is the same as the number of gadgets in the validity + /// circuit. + fn query_rand_len(&self) -> usize; + + /// Generate a proof of an input's validity. The return value is a sequence of + /// [`Self::proof_len`] field elements. + /// + /// # Parameters + /// + /// * `input` is the input. + /// * `prove_rand` is the prover' randomness. + /// * `joint_rand` is the randomness shared by the prover and verifier. + #[allow(clippy::needless_range_loop)] + fn prove( + &self, + input: &[Self::Field], + prove_rand: &[Self::Field], + joint_rand: &[Self::Field], + ) -> Result<Vec<Self::Field>, FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Prove(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + if prove_rand.len() != self.prove_rand_len() { + return Err(FlpError::Prove(format!( + "unexpected prove randomness length: got {}; want {}", + prove_rand.len(), + self.prove_rand_len() + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Prove(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + let mut prove_rand_len = 0; + let mut shim = self + .gadget() + .into_iter() + .map(|inner| { + let inner_arity = inner.arity(); + if prove_rand_len + inner_arity > prove_rand.len() { + return Err(FlpError::Prove(format!( + "short prove randomness: got {}; want {}", + prove_rand.len(), + self.prove_rand_len() + ))); + } + + let gadget = Box::new(ProveShimGadget::new( + inner, + &prove_rand[prove_rand_len..prove_rand_len + inner_arity], + )?) as Box<dyn Gadget<Self::Field>>; + prove_rand_len += inner_arity; + + Ok(gadget) + }) + .collect::<Result<Vec<_>, FlpError>>()?; + assert_eq!(prove_rand_len, self.prove_rand_len()); + + // Create a buffer for storing the proof. The buffer is longer than the proof itself; the extra + // length is to accommodate the computation of each gadget polynomial. + let data_len = (0..shim.len()) + .map(|idx| { + let gadget_poly_len = + gadget_poly_len(shim[idx].degree(), wire_poly_len(shim[idx].calls())); + + // Computing the gadget polynomial using FFT requires an amount of memory that is a + // power of 2. Thus we choose the smallest power of 2 that is at least as large as + // the gadget polynomial. The wire seeds are encoded in the proof, too, so we + // include the arity of the gadget to ensure there is always enough room at the end + // of the buffer to compute the next gadget polynomial. It's likely that the + // memory footprint here can be reduced, with a bit of care. + shim[idx].arity() + gadget_poly_len.next_power_of_two() + }) + .sum(); + let mut proof = vec![Self::Field::zero(); data_len]; + + // Run the validity circuit with a sequence of "shim" gadgets that record the value of each + // input wire of each gadget evaluation. These values are used to construct the wire + // polynomials for each gadget in the next step. + let _ = self.valid(&mut shim, input, joint_rand, 1)?; + + // Construct the proof. + let mut proof_len = 0; + for idx in 0..shim.len() { + let gadget = shim[idx] + .as_any() + .downcast_mut::<ProveShimGadget<Self::Field>>() + .unwrap(); + + // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of each + // evaluation of the gadget. + let m = wire_poly_len(gadget.calls()); + let m_inv = + Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap()) + .inv(); + let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()]; + for wire in 0..gadget.arity() { + discrete_fourier_transform(&mut f[wire], &gadget.f_vals[wire], m)?; + discrete_fourier_transform_inv_finish(&mut f[wire], m, m_inv); + + // The first point on each wire polynomial is a random value chosen by the prover. This + // point is stored in the proof so that the verifier can reconstruct the wire + // polynomials. + proof[proof_len + wire] = gadget.f_vals[wire][0]; + } + + // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to `proof`. + let gadget_poly_len = gadget_poly_len(gadget.degree(), m); + let start = proof_len + gadget.arity(); + let end = start + gadget_poly_len.next_power_of_two(); + gadget.call_poly(&mut proof[start..end], &f)?; + proof_len += gadget.arity() + gadget_poly_len; + } + + // Truncate the buffer to the size of the proof. + assert_eq!(proof_len, self.proof_len()); + proof.truncate(proof_len); + Ok(proof) + } + + /// Query an input and proof and return the verifier message. The return value has length + /// [`Self::verifier_len`]. + /// + /// # Parameters + /// + /// * `input` is the input or input share. + /// * `proof` is the proof or proof share. + /// * `query_rand` is the verifier's randomness. + /// * `joint_rand` is the randomness shared by the prover and verifier. + /// * `num_shares` is the total number of input shares. + fn query( + &self, + input: &[Self::Field], + proof: &[Self::Field], + query_rand: &[Self::Field], + joint_rand: &[Self::Field], + num_shares: usize, + ) -> Result<Vec<Self::Field>, FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Query(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + if proof.len() != self.proof_len() { + return Err(FlpError::Query(format!( + "unexpected proof length: got {}; want {}", + proof.len(), + self.proof_len() + ))); + } + + if query_rand.len() != self.query_rand_len() { + return Err(FlpError::Query(format!( + "unexpected query randomness length: got {}; want {}", + query_rand.len(), + self.query_rand_len() + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Query(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + let mut proof_len = 0; + let mut shim = self + .gadget() + .into_iter() + .enumerate() + .map(|(idx, gadget)| { + let gadget_degree = gadget.degree(); + let gadget_arity = gadget.arity(); + let m = (1 + gadget.calls()).next_power_of_two(); + let r = query_rand[idx]; + + // Make sure the query randomness isn't a root of unity. Evaluating the gadget + // polynomial at any of these points would be a privacy violation, since these points + // were used by the prover to construct the wire polynomials. + if r.pow(<Self::Field as FieldElement>::Integer::try_from(m).unwrap()) + == Self::Field::one() + { + return Err(FlpError::Query(format!( + "invalid query randomness: encountered 2^{}-th root of unity", + m + ))); + } + + // Compute the length of the sub-proof corresponding to the `idx`-th gadget. + let next_len = gadget_arity + gadget_degree * (m - 1) + 1; + let proof_data = &proof[proof_len..proof_len + next_len]; + proof_len += next_len; + + Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?) + as Box<dyn Gadget<Self::Field>>) + }) + .collect::<Result<Vec<_>, _>>()?; + + // Create a buffer for the verifier data. This includes the output of the validity circuit and, + // for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query randomness + // `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`. + let data_len = 1 + + (0..shim.len()) + .map(|idx| shim[idx].arity() + 1) + .sum::<usize>(); + let mut verifier = Vec::with_capacity(data_len); + + // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to each + // wire for each gadget call. Record the output of the circuit and append it to the verifier + // message. + // + // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit is + // equal to the output of the last gadget evaluation. Here we relax this assumption. This + // should be OK, since it's possible to transform any circuit into one for which this is true. + // (Needs security analysis.) + let validity = self.valid(&mut shim, input, joint_rand, num_shares)?; + verifier.push(validity); + + // Fill the buffer with the verifier message. + for idx in 0..shim.len() { + let r = query_rand[idx]; + let gadget = shim[idx] + .as_any() + .downcast_ref::<QueryShimGadget<Self::Field>>() + .unwrap(); + + // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire + // polynomial at query randomness `r`. + let m = (1 + gadget.calls()).next_power_of_two(); + let m_inv = + Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap()) + .inv(); + let mut f = vec![Self::Field::zero(); m]; + for wire in 0..gadget.arity() { + discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?; + discrete_fourier_transform_inv_finish(&mut f, m, m_inv); + verifier.push(poly_eval(&f, r)); + } + + // Add the value of the gadget polynomial evaluated at `r`. + verifier.push(gadget.p_at_r); + } + + assert_eq!(verifier.len(), self.verifier_len()); + Ok(verifier) + } + + /// Returns true if the verifier message indicates that the input from which it was generated is valid. + #[allow(clippy::needless_range_loop)] + fn decide(&self, verifier: &[Self::Field]) -> Result<bool, FlpError> { + if verifier.len() != self.verifier_len() { + return Err(FlpError::Decide(format!( + "unexpected verifier length: got {}; want {}", + verifier.len(), + self.verifier_len() + ))); + } + + // Check if the output of the circuit is 0. + if verifier[0] != Self::Field::zero() { + return Ok(false); + } + + // Check that each of the proof polynomials are well-formed. + let mut gadgets = self.gadget(); + let mut verifier_len = 1; + for idx in 0..gadgets.len() { + let next_len = 1 + gadgets[idx].arity(); + + let e = gadgets[idx].call(&verifier[verifier_len..verifier_len + next_len - 1])?; + if e != verifier[verifier_len + next_len - 1] { + return Ok(false); + } + + verifier_len += next_len; + } + + Ok(true) + } + + /// Check whether `input` and `joint_rand` have the length expected by `self`, + /// return [`FlpError::Valid`] otherwise. + fn valid_call_check( + &self, + input: &[Self::Field], + joint_rand: &[Self::Field], + ) -> Result<(), FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Valid(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len(), + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Valid(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + Ok(()) + } + + /// Check if the length of `input` matches `self`'s `input_len()`, + /// return [`FlpError::Truncate`] otherwise. + fn truncate_call_check(&self, input: &[Self::Field]) -> Result<(), FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Truncate(format!( + "Unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + Ok(()) + } +} + +/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit. +pub trait Gadget<F: FieldElement>: Debug { + /// Evaluates the gadget on input `inp` and returns the output. + fn call(&mut self, inp: &[F]) -> Result<F, FlpError>; + + /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`. + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>; + + /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or + /// `call_poly`. + fn arity(&self) -> usize; + + /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp` + /// buffer passed to `call_poly`. + fn degree(&self) -> usize; + + /// Returns the number of times the gadget is expected to be called. + fn calls(&self) -> usize; + + /// This call is used to downcast a `Box<dyn Gadget<F>>` to a concrete type. + fn as_any(&mut self) -> &mut dyn Any; +} + +// A "shim" gadget used during proof generation to record the input wires each time a gadget is +// evaluated. +#[derive(Debug)] +struct ProveShimGadget<F: FieldElement> { + inner: Box<dyn Gadget<F>>, + + /// Points at which the wire polynomials are interpolated. + f_vals: Vec<Vec<F>>, + + /// The number of times the gadget has been called so far. + ct: usize, +} + +impl<F: FieldElement> ProveShimGadget<F> { + fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> { + let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()]; + + #[allow(clippy::needless_range_loop)] + for wire in 0..f_vals.len() { + // Choose a random field element as the first point on the wire polynomial. + f_vals[wire][0] = prove_rand[wire]; + } + + Ok(Self { + inner, + f_vals, + ct: 1, + }) + } +} + +impl<F: FieldElement> Gadget<F> for ProveShimGadget<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + #[allow(clippy::needless_range_loop)] + for wire in 0..inp.len() { + self.f_vals[wire][self.ct] = inp[wire]; + } + self.ct += 1; + self.inner.call(inp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + self.inner.call_poly(outp, inp) + } + + fn arity(&self) -> usize { + self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +// A "shim" gadget used during proof verification to record the points at which the intermediate +// proof polynomials are evaluated. +#[derive(Debug)] +struct QueryShimGadget<F: FieldElement> { + inner: Box<dyn Gadget<F>>, + + /// Points at which intermediate proof polynomials are interpolated. + f_vals: Vec<Vec<F>>, + + /// Points at which the gadget polynomial is interpolated. + p_vals: Vec<F>, + + /// The gadget polynomial evaluated on a random input `r`. + p_at_r: F, + + /// Used to compute an index into `p_val`. + step: usize, + + /// The number of times the gadget has been called so far. + ct: usize, +} + +impl<F: FieldElement> QueryShimGadget<F> { + fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> { + let gadget_degree = inner.degree(); + let gadget_arity = inner.arity(); + let m = (1 + inner.calls()).next_power_of_two(); + let p = m * gadget_degree; + + // Each call to this gadget records the values at which intermediate proof polynomials were + // interpolated. The first point was a random value chosen by the prover and transmitted in + // the proof. + let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; gadget_arity]; + for wire in 0..gadget_arity { + f_vals[wire][0] = proof_data[wire]; + } + + // Evaluate the gadget polynomial at roots of unity. + let size = p.next_power_of_two(); + let mut p_vals = vec![F::zero(); size]; + discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?; + + // The step is used to compute the element of `p_val` that will be returned by a call to + // the gadget. + let step = (1 << (log2(p as u128) - log2(m as u128))) as usize; + + // Evaluate the gadget polynomial `p` at query randomness `r`. + let p_at_r = poly_eval(&proof_data[gadget_arity..], r); + + Ok(Self { + inner, + f_vals, + p_vals, + p_at_r, + step, + ct: 1, + }) + } +} + +impl<F: FieldElement> Gadget<F> for QueryShimGadget<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + #[allow(clippy::needless_range_loop)] + for wire in 0..inp.len() { + self.f_vals[wire][self.ct] = inp[wire]; + } + let outp = self.p_vals[self.ct * self.step]; + self.ct += 1; + Ok(outp) + } + + fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> { + panic!("no-op"); + } + + fn arity(&self) -> usize { + self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// Compute the length of the wire polynomial constructed from the given number of gadget calls. +#[inline] +pub(crate) fn wire_poly_len(num_calls: usize) -> usize { + (1 + num_calls).next_power_of_two() +} + +/// Compute the length of the gadget polynomial for a gadget with the given degree and from wire +/// polynomials of the given length. +#[inline] +pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usize { + gadget_degree * (wire_poly_len - 1) + 1 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, split_vector, Field128}; + use crate::flp::gadgets::{Mul, PolyEval}; + use crate::polynomial::poly_range_check; + + use std::marker::PhantomData; + + // Simple integration test for the core FLP logic. You'll find more extensive unit tests for + // each implemented data type in src/types.rs. + #[test] + fn test_flp() { + const NUM_SHARES: usize = 2; + + let typ: TestType<Field128> = TestType::new(); + let input = typ.encode_measurement(&3).unwrap(); + assert_eq!(input.len(), typ.input_len()); + + let input_shares: Vec<Vec<Field128>> = split_vector(input.as_slice(), NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + assert_eq!(proof.len(), typ.proof_len()); + + let proof_shares: Vec<Vec<Field128>> = split_vector(&proof, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let verifier: Vec<Field128> = (0..NUM_SHARES) + .map(|i| { + typ.query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + NUM_SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + + assert!(typ.decide(&verifier).unwrap()); + } + + /// A toy type used for testing multiple gadgets. Valid inputs of this type consist of a pair + /// of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`. + #[derive(Clone, Debug, PartialEq, Eq)] + struct TestType<F>(PhantomData<F>); + + impl<F> TestType<F> { + fn new() -> Self { + Self(PhantomData) + } + } + + impl<F: FieldElement> Type for TestType<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + let r = joint_rand[0]; + let mut res = F::zero(); + + // Check that `data[0]^3 == data[1]`. + let mut inp = [input[0], input[0]]; + inp[0] = g[0].call(&inp)?; + inp[0] = g[0].call(&inp)?; + let x3_diff = inp[0] - input[1]; + res += r * x3_diff; + + // Check that `data[0]` is in the correct range. + let x_checked = g[1].call(&[input[0]])?; + res += (r * r) * x_checked; + + Ok(res) + } + + fn input_len(&self) -> usize { + 2 + } + + fn proof_len(&self) -> usize { + // First chunk + let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + 2_usize /* gadget calls */).next_power_of_two() - 1) + 1; + + // Second chunk + let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * ( + (1 + 1_usize /* gadget calls */).next_power_of_two() - 1) + 1; + + mul + poly + } + + fn verifier_len(&self) -> usize { + // First chunk + let mul = 1 + 2 /* gadget arity */; + + // Second chunk + let poly = 1 + 1 /* gadget arity */; + + 1 + mul + poly + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 3 + } + + fn query_rand_len(&self) -> usize { + 2 + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![ + Box::new(Mul::new(2)), + Box::new(PolyEval::new(poly_range_check(2, 5), 1)), + ] + } + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + Ok(vec![ + F::from(*measurement), + F::from(*measurement).pow(F::Integer::try_from(3).unwrap()), + ]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + Ok(input) + } + + fn decode_result( + &self, + _data: &[F], + _num_measurements: usize, + ) -> Result<F::Integer, FlpError> { + panic!("not implemented"); + } + } + + // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that + // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than + // needed for computing the gadget polynomial. + #[test] + fn issue254() { + let typ: Issue254Type<Field128> = Issue254Type::new(); + let input = typ.encode_measurement(&0).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct Issue254Type<F> { + num_gadget_calls: [usize; 2], + phantom: PhantomData<F>, + } + + impl<F> Issue254Type<F> { + fn new() -> Self { + Self { + // The bug is triggered when there are two gadgets, but it doesn't matter how many + // times the second gadget is called. + num_gadget_calls: [100, 0], + phantom: PhantomData, + } + } + } + + impl<F: FieldElement> Type for Issue254Type<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + _joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the + // use of multiple gadgets, each of which is called an arbitrary number of times. + let mut res = F::zero(); + for _ in 0..self.num_gadget_calls[0] { + res += g[0].call(&[input[0]])?; + } + for _ in 0..self.num_gadget_calls[1] { + res += g[1].call(&[input[0]])?; + } + Ok(res) + } + + fn input_len(&self) -> usize { + 1 + } + + fn proof_len(&self) -> usize { + // First chunk + let first = 1 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + self.num_gadget_calls[0]).next_power_of_two() - 1) + 1; + + // Second chunk + let second = 1 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + self.num_gadget_calls[1]).next_power_of_two() - 1) + 1; + + first + second + } + + fn verifier_len(&self) -> usize { + // First chunk + let first = 1 + 1 /* gadget arity */; + + // Second chunk + let second = 1 + 1 /* gadget arity */; + + 1 + first + second + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 0 + } + + fn prove_rand_len(&self) -> usize { + // First chunk + let first = 1; // gadget arity + + // Second chunk + let second = 1; // gadget arity + + first + second + } + + fn query_rand_len(&self) -> usize { + 2 // number of gadgets + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + let poly = poly_range_check(0, 2); // A polynomial with degree 2 + vec![ + Box::new(PolyEval::new(poly.clone(), self.num_gadget_calls[0])), + Box::new(PolyEval::new(poly, self.num_gadget_calls[1])), + ] + } + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + Ok(vec![F::from(*measurement)]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + Ok(input) + } + + fn decode_result( + &self, + _data: &[F], + _num_measurements: usize, + ) -> Result<F::Integer, FlpError> { + panic!("not implemented"); + } + } +} diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs new file mode 100644 index 0000000000..fd2be84eaa --- /dev/null +++ b/third_party/rust/prio/src/flp/gadgets.rs @@ -0,0 +1,715 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of gadgets. + +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; +use crate::field::FieldElement; +use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; +use crate::polynomial::{poly_deg, poly_eval, poly_mul}; + +#[cfg(feature = "multithreaded")] +use rayon::prelude::*; + +use std::any::Any; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::marker::PhantomData; + +/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for +/// polynomial multiplication. Otherwise, the gadget uses direct multiplication. +const FFT_THRESHOLD: usize = 60; + +/// An arity-2 gadget that multiples its inputs. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Mul<F: FieldElement> { + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> Mul<F> { + /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be + /// called by the validity circuit. + pub fn new(num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(2, num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + n, + n_inv, + num_calls, + } + } + + // Multiply input polynomials directly. + pub(crate) fn call_poly_direct( + &mut self, + outp: &mut [F], + inp: &[Vec<F>], + ) -> Result<(), FlpError> { + let v = poly_mul(&inp[0], &inp[1]); + outp[..v.len()].clone_from_slice(&v); + Ok(()) + } + + // Multiply input polynomials using FFT. + pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let mut buf = vec![F::zero(); n]; + + discrete_fourier_transform(&mut buf, &inp[0], n)?; + discrete_fourier_transform(outp, &inp[1], n)?; + + for i in 0..n { + buf[i] *= outp[i]; + } + + discrete_fourier_transform(outp, &buf, n)?; + discrete_fourier_transform_inv_finish(outp, n, self.n_inv); + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for Mul<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(inp[0] * inp[1]) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 2 + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// An arity-1 gadget that evaluates its input on some polynomial. +// +// TODO Make `poly` an array of length determined by a const generic. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PolyEval<F: FieldElement> { + poly: Vec<F>, + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> PolyEval<F> { + /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times + /// this gadget is called by the validity circuit. + pub fn new(poly: Vec<F>, num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + poly, + n, + n_inv, + num_calls, + } + } +} + +impl<F: FieldElement> PolyEval<F> { + // Multiply input polynomials directly. + fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + outp[0] = self.poly[0]; + let mut x = inp[0].to_vec(); + for i in 1..self.poly.len() { + for j in 0..x.len() { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + x = poly_mul(&x, &inp[0]); + } + } + Ok(()) + } + + // Multiply input polynomials using FFT. + fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let inp = &inp[0]; + + let mut inp_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut inp_vals, inp, n)?; + + let mut x_vals = inp_vals.clone(); + let mut x = vec![F::zero(); n]; + x[..inp.len()].clone_from_slice(inp); + + outp[0] = self.poly[0]; + for i in 1..self.poly.len() { + for j in 0..n { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + for j in 0..n { + x_vals[j] *= inp_vals[j]; + } + + discrete_fourier_transform(&mut x, &x_vals, n)?; + discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv); + } + } + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for PolyEval<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(poly_eval(&self.poly, inp[0])) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for item in outp.iter_mut() { + *item = F::zero(); + } + + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + poly_deg(&self.poly) + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct BlindPolyEval<F: FieldElement> { + poly: Vec<F>, + /// Size of buffer for the outer FFT multiplication. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FieldElement> BlindPolyEval<F> { + /// Returns a `BlindPolyEval` gadget for polynomial `poly`. + pub fn new(poly: Vec<F>, num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + poly, + n, + n_inv, + num_calls, + } + } + + fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let x = &inp[0]; + let y = &inp[1]; + + let mut z = y.to_vec(); + for i in 0..self.poly.len() { + for j in 0..z.len() { + outp[j] += self.poly[i] * z[j]; + } + + if i < self.poly.len() - 1 { + z = poly_mul(&z, x); + } + } + Ok(()) + } + + fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let x = &inp[0]; + let y = &inp[1]; + + let mut x_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut x_vals, x, n)?; + + let mut z_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut z_vals, y, n)?; + + let mut z = vec![F::zero(); n]; + let mut z_len = y.len(); + z[..y.len()].clone_from_slice(y); + + for i in 0..self.poly.len() { + for j in 0..z_len { + outp[j] += self.poly[i] * z[j]; + } + + if i < self.poly.len() - 1 { + for j in 0..n { + z_vals[j] *= x_vals[j]; + } + + discrete_fourier_transform(&mut z, &z_vals, n)?; + discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv); + z_len += x.len(); + } + } + Ok(()) + } +} + +impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(inp[1] * poly_eval(&self.poly, inp[0])) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for x in outp.iter_mut() { + *x = F::zero(); + } + + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + poly_deg(&self.poly) + 1 + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// Marker trait for abstracting over [`ParallelSum`]. +pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug { + /// Wraps `inner` into a sum gadget with `chunks` chunks + fn new(inner: G, chunks: usize) -> Self; +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSum<F: FieldElement, G: Gadget<F>> { + inner: G, + chunks: usize, + phantom: PhantomData<F>, +} + +impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> { + fn new(inner: G, chunks: usize) -> Self { + Self { + inner, + chunks, + phantom: PhantomData, + } + } +} + +impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + let mut outp = F::zero(); + for chunk in inp.chunks(self.inner.arity()) { + outp += self.inner.call(chunk)?; + } + Ok(outp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for x in outp.iter_mut() { + *x = F::zero(); + } + + let mut partial_outp = vec![F::zero(); outp.len()]; + + for chunk in inp.chunks(self.inner.arity()) { + self.inner.call_poly(&mut partial_outp, chunk)?; + for i in 0..outp.len() { + outp[i] += partial_outp[i] + } + } + + Ok(()) + } + + fn arity(&self) -> usize { + self.chunks * self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum +/// evaluation is multithreaded. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> { + serial_sum: ParallelSum<F, G>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G> +where + F: FieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn new(inner: G, chunks: usize) -> Self { + Self { + serial_sum: ParallelSum::new(inner, chunks), + } + } +} + +/// Data structures passed between fold operations in [`ParallelSumMultithreaded`]. +#[cfg(feature = "multithreaded")] +struct ParallelSumFoldState<F, G> { + /// Inner gadget. + inner: G, + /// Output buffer for `call_poly()`. + partial_output: Vec<F>, + /// Sum accumulator. + partial_sum: Vec<F>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumFoldState<F, G> { + fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G> + where + G: Clone, + F: FieldElement, + { + ParallelSumFoldState { + inner: gadget.clone(), + partial_output: vec![F::zero(); length], + partial_sum: vec![F::zero(); length], + } + } +} + +#[cfg(feature = "multithreaded")] +impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G> +where + F: FieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + self.serial_sum.call(inp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the + // gadget on each input polynomial, using the first temporary buffer as an output buffer. + // Then accumulate that result into the second temporary buffer, which acts as a running + // sum. Then, discard everything but the partial sums, add them, and finally copy the sum + // to the output parameter. This is equivalent to the single threaded calculation in + // ParallelSum, since we only rearrange additions, and field addition is associative. + let res = inp + .par_chunks(self.serial_sum.inner.arity()) + .fold( + || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()), + |mut state, chunk| { + state + .inner + .call_poly(&mut state.partial_output, chunk) + .unwrap(); + for (sum_elem, output_elem) in state + .partial_sum + .iter_mut() + .zip(state.partial_output.iter()) + { + *sum_elem += *output_elem; + } + state + }, + ) + .map(|state| state.partial_sum) + .reduce( + || vec![F::zero(); outp.len()], + |mut x, y| { + for (xi, yi) in x.iter_mut().zip(y.iter()) { + *xi += *yi; + } + x + }, + ); + + outp.copy_from_slice(&res[..]); + Ok(()) + } + + fn arity(&self) -> usize { + self.serial_sum.arity() + } + + fn degree(&self) -> usize { + self.serial_sum.degree() + } + + fn calls(&self) -> usize { + self.serial_sum.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +// Check that the input parameters of g.call() are well-formed. +fn gadget_call_check<F: FieldElement, G: Gadget<F>>( + gadget: &G, + in_len: usize, +) -> Result<(), FlpError> { + if in_len != gadget.arity() { + return Err(FlpError::Gadget(format!( + "unexpected number of inputs: got {}; want {}", + in_len, + gadget.arity() + ))); + } + + if in_len == 0 { + return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string())); + } + + Ok(()) +} + +// Check that the input parameters of g.call_poly() are well-formed. +fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>( + gadget: &G, + outp: &[F], + inp: &[Vec<F>], +) -> Result<(), FlpError> +where + G: Gadget<F>, +{ + gadget_call_check(gadget, inp.len())?; + + for i in 1..inp.len() { + if inp[i].len() != inp[0].len() { + return Err(FlpError::Gadget( + "gadget called on wire polynomials with different lengths".to_string(), + )); + } + } + + let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two(); + if outp.len() != expected { + return Err(FlpError::Gadget(format!( + "incorrect output length: got {}; want {}", + outp.len(), + expected + ))); + } + + Ok(()) +} + +#[inline] +fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize { + gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two() +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::field::{random_vector, Field96 as TestField}; + use crate::prng::Prng; + + #[test] + fn test_mul() { + // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the + // naive multiplication code path. + let num_calls = FFT_THRESHOLD / 2; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + + // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises + // FFT-based polynomial multiplication. + let num_calls = FFT_THRESHOLD; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_poly_eval() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + + let num_calls = FFT_THRESHOLD / 2; + let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls); + gadget_test(&mut g, num_calls); + + let num_calls = FFT_THRESHOLD; + let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_blind_poly_eval() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + + let num_calls = FFT_THRESHOLD / 2; + let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls); + gadget_test(&mut g, num_calls); + + let num_calls = FFT_THRESHOLD; + let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_parallel_sum() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + let num_calls = 10; + let chunks = 23; + + let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks); + gadget_test(&mut g, num_calls); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_parallel_sum_multithreaded() { + use std::iter; + + for num_calls in [1, 10, 100] { + let poly: Vec<TestField> = random_vector(10).unwrap(); + let chunks = 23; + + let mut g = + ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks); + gadget_test(&mut g, num_calls); + + // Test that the multithreaded version has the same output as the normal version. + let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks); + assert_eq!(g.arity(), g_serial.arity()); + assert_eq!(g.degree(), g_serial.degree()); + assert_eq!(g.calls(), g_serial.calls()); + + let arity = g.arity(); + let degree = g.degree(); + + // Test that both gadgets evaluate to the same value when run on scalar inputs. + let inp: Vec<TestField> = random_vector(arity).unwrap(); + let result = g.call(&inp).unwrap(); + let result_serial = g_serial.call(&inp).unwrap(); + assert_eq!(result, result_serial); + + // Test that both gadgets evaluate to the same value when run on polynomial inputs. + let mut poly_outp = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut poly_outp_serial = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut prng: Prng<TestField, _> = Prng::new().unwrap(); + let poly_inp: Vec<_> = iter::repeat_with(|| { + iter::repeat_with(|| prng.get()) + .take(1 + num_calls) + .collect::<Vec<_>>() + }) + .take(arity) + .collect(); + + g.call_poly(&mut poly_outp, &poly_inp).unwrap(); + g_serial + .call_poly(&mut poly_outp_serial, &poly_inp) + .unwrap(); + assert_eq!(poly_outp, poly_outp_serial); + } + } + + // Test that calling g.call_poly() and evaluating the output at a given point is equivalent + // to evaluating each of the inputs at the same point and applying g.call() on the results. + fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) { + let wire_poly_len = (1 + num_calls).next_power_of_two(); + let mut prng = Prng::new().unwrap(); + let mut inp = vec![F::zero(); g.arity()]; + let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)]; + let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()]; + + let r = prng.get(); + for i in 0..g.arity() { + for j in 0..wire_poly_len { + wire_polys[i][j] = prng.get(); + } + inp[i] = poly_eval(&wire_polys[i], r); + } + + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + let want = g.call(&inp).unwrap(); + assert_eq!(got, want); + + // Repeat the call to make sure that the gadget's memory is reset properly between calls. + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + assert_eq!(got, want); + } +} diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs new file mode 100644 index 0000000000..83b9752f69 --- /dev/null +++ b/third_party/rust/prio/src/flp/types.rs @@ -0,0 +1,1199 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of [`Type`](crate::flp::Type) implementations. + +use crate::field::{FieldElement, FieldElementExt}; +use crate::flp::gadgets::{BlindPolyEval, Mul, ParallelSumGadget, PolyEval}; +use crate::flp::{FlpError, Gadget, Type}; +use crate::polynomial::poly_range_check; +use std::convert::TryInto; +use std::marker::PhantomData; + +/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of +/// the measurements (i.e., the total number of `1s`). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Count<F> { + range_checker: Vec<F>, +} + +impl<F: FieldElement> Count<F> { + /// Return a new [`Count`] type instance. + pub fn new() -> Self { + Self { + range_checker: poly_range_check(0, 2), + } + } +} + +impl<F: FieldElement> Default for Count<F> { + fn default() -> Self { + Self::new() + } +} + +impl<F: FieldElement> Type for Count<F> { + const ID: u32 = 0x00000000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> { + let max = F::valid_integer_try_from(1)?; + if *value > max { + return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); + } + + Ok(vec![F::from(*value)]) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(Mul::new(1))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + Ok(g[0].call(&[input[0], input[0]])? - input[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + 1 + } + + fn proof_len(&self) -> usize { + 5 + } + + fn verifier_len(&self) -> usize { + 4 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 0 + } + + fn prove_rand_len(&self) -> usize { + 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of +/// the measurements. +/// +/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3]. +/// +/// [BBCG+19]: https://ia.cr/2019/188 +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Sum<F: FieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Sum<F> { + /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Sum<F> { + const ID: u32 = 0x00000001; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the arithmetic average. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Average<F: FieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Average<F> { + /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Average<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = f64; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> { + // Compute the average from the aggregated sum. + let data = decode_result(data)?; + let data: u64 = data.try_into().map_err(|err| { + FlpError::Decode(format!("failed to convert {:?} to u64: {}", data, err,)) + })?; + let result = (data as f64) / (num_measurements as f64); + Ok(result) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The histogram type. Each measurement is a non-negative integer and the aggregate is a histogram +/// approximating the distribution of the measurements. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Histogram<F: FieldElement> { + buckets: Vec<F::Integer>, + range_checker: Vec<F>, +} + +impl<F: FieldElement> Histogram<F> { + /// Return a new [`Histogram`] type with the given buckets. + pub fn new(buckets: Vec<F::Integer>) -> Result<Self, FlpError> { + if buckets.len() >= u32::MAX as usize { + return Err(FlpError::Encode( + "invalid buckets: number of buckets exceeds maximum permitted".to_string(), + )); + } + + if !buckets.is_empty() { + for i in 0..buckets.len() - 1 { + if buckets[i + 1] <= buckets[i] { + return Err(FlpError::Encode( + "invalid buckets: out-of-order boundary".to_string(), + )); + } + } + } + + Ok(Self { + buckets, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FieldElement> Type for Histogram<F> { + const ID: u32 = 0x00000002; + type Measurement = F::Integer; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + let mut data = vec![F::zero(); self.buckets.len() + 1]; + + let bucket = match self.buckets.binary_search(measurement) { + Ok(i) => i, // on a bucket boundary + Err(i) => i, // smaller than the i-th bucket boundary + }; + + data[bucket] = F::one(); + Ok(data) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.buckets.len() + 1) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.to_vec(), + self.input_len(), + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + // Check that each element of `input` is a 0 or 1. + let range_check = call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])?; + + // Check that the elements of `input` sum to 1. + let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?)); + for val in input.iter() { + sum_check += *val; + } + + // Take a random linear combination of both checks. + let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check; + Ok(out) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + self.buckets.len() + 1 + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.input_len()).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 2 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// A sequence of counters. This type uses a neat trick from [[BBCG+19], Corollary 4.9] to reduce +/// the proof size to roughly the square root of the input size. +/// +/// [BBCG+19]: https://eprint.iacr.org/2019/188 +#[derive(Debug, PartialEq, Eq)] +pub struct CountVec<F, S> { + range_checker: Vec<F>, + len: usize, + chunk_len: usize, + gadget_calls: usize, + phantom: PhantomData<S>, +} + +impl<F: FieldElement, S: ParallelSumGadget<F, BlindPolyEval<F>>> CountVec<F, S> { + /// Returns a new [`CountVec`] with the given length. + pub fn new(len: usize) -> Self { + // The optimal chunk length is the square root of the input length. If the input length is + // not a perfect square, then round down. If the result is 0, then let the chunk length be + // 1 so that the underlying gadget can still be called. + let chunk_len = std::cmp::max(1, (len as f64).sqrt() as usize); + + let mut gadget_calls = len / chunk_len; + if len % chunk_len != 0 { + gadget_calls += 1; + } + + Self { + range_checker: poly_range_check(0, 2), + len, + chunk_len, + gadget_calls, + phantom: PhantomData, + } + } +} + +impl<F: FieldElement, S> Clone for CountVec<F, S> { + fn clone(&self) -> Self { + Self { + range_checker: self.range_checker.clone(), + len: self.len, + chunk_len: self.chunk_len, + gadget_calls: self.gadget_calls, + phantom: PhantomData, + } + } +} + +impl<F, S> Type for CountVec<F, S> +where + F: FieldElement, + S: ParallelSumGadget<F, BlindPolyEval<F>> + Eq + 'static, +{ + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<F::Integer>; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> { + if measurement.len() != self.len { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.len + ))); + } + + let max = F::Integer::from(F::one()); + for value in measurement { + if *value > max { + return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); + } + } + + Ok(measurement.iter().map(|value| F::from(*value)).collect()) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.len) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(S::new( + BlindPolyEval::new(self.range_checker.clone(), self.gadget_calls), + self.chunk_len, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + let s = F::from(F::valid_integer_try_from(num_shares)?).inv(); + let mut r = joint_rand[0]; + let mut outp = F::zero(); + let mut padded_chunk = vec![F::zero(); 2 * self.chunk_len]; + for chunk in input.chunks(self.chunk_len) { + let d = chunk.len(); + for i in 0..self.chunk_len { + if i < d { + padded_chunk[2 * i] = chunk[i]; + } else { + // If the chunk is smaller than the chunk length, then copy the last element of + // the chunk into the remaining slots. + padded_chunk[2 * i] = chunk[d - 1]; + } + padded_chunk[2 * i + 1] = r * s; + r *= joint_rand[0]; + } + + outp += g[0].call(&padded_chunk)?; + } + + Ok(outp) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + self.len + } + + fn proof_len(&self) -> usize { + (self.chunk_len * 2) + 3 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1 + } + + fn verifier_len(&self) -> usize { + 2 + self.chunk_len * 2 + } + + fn output_len(&self) -> usize { + self.len + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + self.chunk_len * 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// Compute a random linear combination of the result of calls of `g` on each element of `input`. +/// +/// # Arguments +/// +/// * `g` - The gadget to be applied elementwise +/// * `input` - The vector on whose elements to apply `g` +/// * `rnd` - The randomness used for the linear combination +pub(crate) fn call_gadget_on_vec_entries<F: FieldElement>( + g: &mut Box<dyn Gadget<F>>, + input: &[F], + rnd: F, +) -> Result<F, FlpError> { + let mut range_check = F::zero(); + let mut r = rnd; + for chunk in input.chunks(1) { + range_check += r * g.call(chunk)?; + r *= rnd; + } + Ok(range_check) +} + +/// Given a vector `data` of field elements which should contain exactly one entry, return the +/// integer representation of that entry. +pub(crate) fn decode_result<F: FieldElement>(data: &[F]) -> Result<F::Integer, FlpError> { + if data.len() != 1 { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(F::Integer::from(data[0])) +} + +/// Given a vector `data` of field elements, return a vector containing the corresponding integer +/// representations, if the number of entries matches `expected_len`. +pub(crate) fn decode_result_vec<F: FieldElement>( + data: &[F], + expected_len: usize, +) -> Result<Vec<F::Integer>, FlpError> { + if data.len() != expected_len { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, split_vector, Field64 as TestField}; + use crate::flp::gadgets::ParallelSum; + #[cfg(feature = "multithreaded")] + use crate::flp::gadgets::ParallelSumMultithreaded; + + // Number of shares to split input and proofs into in `flp_test`. + const NUM_SHARES: usize = 3; + + struct ValidityTestCase<F> { + expect_valid: bool, + expected_output: Option<Vec<F>>, + } + + #[test] + fn test_count() { + let count: Count<TestField> = Count::new(); + let zero = TestField::zero(); + let one = TestField::one(); + + // Round trip + assert_eq!( + count + .decode_result( + &count + .truncate(count.encode_measurement(&1).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 1, + ); + + // Test FLP on valid input. + flp_validity_test( + &count, + &count.encode_measurement(&1).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one]), + }, + ) + .unwrap(); + + flp_validity_test( + &count, + &count.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &count, + &[TestField::from(1337)], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + // Try running the validity circuit on an input that's too short. + count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); + count + .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1) + .unwrap_err(); + } + + #[test] + fn test_sum() { + let sum = Sum::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + // Round trip + assert_eq!( + sum.decode_result( + &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(), + 1 + ) + .unwrap(), + 27, + ); + + // Test FLP on valid input. + flp_validity_test( + &sum, + &sum.encode_measurement(&1337).unwrap(), + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![TestField::from(1337)]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(0).unwrap(), + &[], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(2).unwrap(), + &[one, zero], + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![one]), + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(9).unwrap(), + &[one, zero, one, one, zero, one, one, one, zero], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![TestField::from(237)]), + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &Sum::new(3).unwrap(), + &[one, nine, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(5).unwrap(), + &[zero, zero, zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + #[test] + fn test_average() { + let average = Average::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let ten = TestField::from(10); + + // Testing that average correctly quotients the sum of the measurements + // by the number of measurements. + assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0); + assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0); + assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5); + assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25); + assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25); + + // round trip of 12 with `num_measurements`=1 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 12.0 + ); + + // round trip of 12 with `num_measurements`=24 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 24 + ) + .unwrap(), + 0.5 + ); + } + + #[test] + fn test_histogram() { + let hist = Histogram::new(vec![10, 20]).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + assert_eq!(&hist.encode_measurement(&7).unwrap(), &[one, zero, zero]); + assert_eq!(&hist.encode_measurement(&10).unwrap(), &[one, zero, zero]); + assert_eq!(&hist.encode_measurement(&17).unwrap(), &[zero, one, zero]); + assert_eq!(&hist.encode_measurement(&20).unwrap(), &[zero, one, zero]); + assert_eq!(&hist.encode_measurement(&27).unwrap(), &[zero, zero, one]); + + // Round trip + assert_eq!( + hist.decode_result( + &hist + .truncate(hist.encode_measurement(&27).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + [0, 0, 1] + ); + + // Invalid bucket boundaries. + Histogram::<TestField>::new(vec![10, 0]).unwrap_err(); + Histogram::<TestField>::new(vec![10, 10]).unwrap_err(); + + // Test valid inputs. + flp_validity_test( + &hist, + &hist.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one, zero, zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&17).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, one, zero]), + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&1337).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, zero, one]), + }, + ) + .unwrap(); + + // Test invalid inputs. + flp_validity_test( + &hist, + &[zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[one, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, zero, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + fn test_count_vec<F, S>(f: F) + where + F: Fn(usize) -> CountVec<TestField, S>, + S: 'static + ParallelSumGadget<TestField, BlindPolyEval<TestField>> + Eq, + { + let one = TestField::one(); + let nine = TestField::from(9); + + // Test on valid inputs. + for len in 0..10 { + let count_vec = f(len); + flp_validity_test( + &count_vec, + &count_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + }, + ) + .unwrap(); + } + + let len = 100; + let count_vec = f(len); + flp_validity_test( + &count_vec, + &count_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + }, + ) + .unwrap(); + + // Test on invalid inputs. + for len in 1..10 { + let count_vec = f(len); + flp_validity_test( + &count_vec, + &vec![nine; len], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + }, + ) + .unwrap(); + } + + // Round trip + let want = vec![1; len]; + assert_eq!( + count_vec + .decode_result( + &count_vec + .truncate(count_vec.encode_measurement(&want).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + want + ); + } + + #[test] + fn test_count_vec_serial() { + test_count_vec(CountVec::<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>>::new) + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_count_vec_parallel() { + test_count_vec(CountVec::<TestField, ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>>::new) + } + + #[test] + fn count_vec_serial_long() { + let typ: CountVec<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>> = + CountVec::new(1000); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn count_vec_parallel_long() { + let typ: CountVec< + TestField, + ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>, + > = CountVec::new(1000); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + fn flp_validity_test<T: Type>( + typ: &T, + input: &[T::Field], + t: &ValidityTestCase<T::Field>, + ) -> Result<(), FlpError> { + let mut gadgets = typ.gadget(); + + if input.len() != typ.input_len() { + return Err(FlpError::Test(format!( + "unexpected input length: got {}; want {}", + input.len(), + typ.input_len() + ))); + } + + if typ.query_rand_len() != gadgets.len() { + return Err(FlpError::Test(format!( + "query rand length: got {}; want {}", + typ.query_rand_len(), + gadgets.len() + ))); + } + + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + + // Run the validity circuit. + let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; + if v != T::Field::zero() && t.expect_valid { + return Err(FlpError::Test(format!( + "expected valid input: valid() returned {}", + v + ))); + } + if v == T::Field::zero() && !t.expect_valid { + return Err(FlpError::Test(format!( + "expected invalid input: valid() returned {}", + v + ))); + } + + // Generate the proof. + let proof = typ.prove(input, &prove_rand, &joint_rand)?; + if proof.len() != typ.proof_len() { + return Err(FlpError::Test(format!( + "unexpected proof length: got {}; want {}", + proof.len(), + typ.proof_len() + ))); + } + + // Query the proof. + let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; + if verifier.len() != typ.verifier_len() { + return Err(FlpError::Test(format!( + "unexpected verifier length: got {}; want {}", + verifier.len(), + typ.verifier_len() + ))); + } + + // Decide if the input is valid. + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Run distributed FLP. + let input_shares: Vec<Vec<T::Field>> = split_vector(input, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let verifier: Vec<T::Field> = (0..NUM_SHARES) + .map(|i| { + typ.query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + NUM_SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "distributed decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Try verifying various proof mutants. + for i in 0..proof.len() { + let mut mutated_proof = proof.clone(); + mutated_proof[i] += T::Field::one(); + let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; + if typ.decide(&verifier)? { + return Err(FlpError::Test(format!( + "decision for proof mutant {} is {}; want {}", + i, true, false, + ))); + } + } + + // Try verifying a proof that is too short. + let mut mutated_proof = proof.clone(); + mutated_proof.truncate(gadgets[0].arity() - 1); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on short proof succeeded; want failure".to_string(), + )); + } + + // Try verifying a proof that is too long. + let mut mutated_proof = proof; + mutated_proof.extend_from_slice(&[T::Field::one(); 17]); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on long proof succeeded; want failure".to_string(), + )); + } + + if let Some(ref want) = t.expected_output { + let got = typ.truncate(input.to_vec())?; + + if got.len() != typ.output_len() { + return Err(FlpError::Test(format!( + "unexpected output length: got {}; want {}", + got.len(), + typ.output_len() + ))); + } + + if &got != want { + return Err(FlpError::Test(format!( + "unexpected output: got {:?}; want {:?}", + got, want + ))); + } + } + + Ok(()) + } +} diff --git a/third_party/rust/prio/src/fp.rs b/third_party/rust/prio/src/fp.rs new file mode 100644 index 0000000000..d828fb7daf --- /dev/null +++ b/third_party/rust/prio/src/fp.rs @@ -0,0 +1,561 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic for any field GF(p) for which p < 2^128. + +#[cfg(test)] +use rand::{prelude::*, Rng}; + +/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots +/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This +/// is the largest input size we would ever need for the cryptographic applications in this crate. +pub(crate) const MAX_ROOTS: usize = 20; + +/// This structure represents the parameters of a finite field GF(p) for which p < 2^128. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct FieldParameters { + /// The prime modulus `p`. + pub p: u128, + /// `mu = -p^(-1) mod 2^64`. + pub mu: u64, + /// `r2 = (2^128)^2 mod p`. + pub r2: u128, + /// The `2^num_roots`-th -principal root of unity. This element is used to generate the + /// elements of `roots`. + pub g: u128, + /// The number of principal roots of unity in `roots`. + pub num_roots: usize, + /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. + pub bit_mask: u128, + /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the + /// multiplicative group. `roots[0]` is equal to one by definition. + pub roots: [u128; MAX_ROOTS + 1], +} + +impl FieldParameters { + /// Addition. The result will be in [0, p), so long as both x and y are as well. + pub fn add(&self, x: u128, y: u128) -> u128 { + // 0,x + // + 0,y + // ===== + // c,z + let (z, carry) = x.overflowing_add(y); + // c, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(self.p); + let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128); + // if b1 == 1: return z + // else: return s0 + let m = 0u128.wrapping_sub(b1 as u128); + (z & m) | (s0 & !m) + } + + /// Subtraction. The result will be in [0, p), so long as both x and y are as well. + pub fn sub(&self, x: u128, y: u128) -> u128 { + // 0, x + // - 0, y + // ======== + // b1,z1,z0 + let (z0, b0) = x.overflowing_sub(y); + let (_z1, b1) = 0u128.overflowing_sub(b0 as u128); + let m = 0u128.wrapping_sub(b1 as u128); + // z1,z0 + // + 0, p + // ======== + // s1,s0 + z0.wrapping_add(m & self.p) + // if b1 == 1: return s0 + // else: return z0 + } + + /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm + /// described + /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf). + /// The result will be in [0, p). + /// + /// # Example usage + /// ```text + /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); + /// ``` + pub fn mul(&self, x: u128, y: u128) -> u128 { + let x = [lo64(x), hi64(x)]; + let y = [lo64(y), hi64(y)]; + let p = [lo64(self.p), hi64(self.p)]; + let mut zz = [0; 4]; + + // Integer multiplication + // z = x * y + + // x1,x0 + // * y1,y0 + // =========== + // z3,z2,z1,z0 + let mut result = x[0] * y[0]; + let mut carry = hi64(result); + zz[0] = lo64(result); + result = x[0] * y[1]; + let mut hi = hi64(result); + let mut lo = lo64(result); + result = lo + carry; + zz[1] = lo64(result); + let mut cc = hi64(result); + result = hi + cc; + zz[2] = lo64(result); + + result = x[1] * y[0]; + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = x[1] * y[1]; + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = hi + cc; + zz[3] = lo64(result); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. + + // z3,z2,z1,z0 + // + p1,p0 + // * w = mu*z0 + // =========== + // z3,z2,z1, 0 + let w = self.mu.wrapping_mul(zz[0] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[0] + lo; + zz[0] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = zz[2] + hi + cc; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + cc; + zz[3] = lo64(result); + + // z3,z2,z1 + // + p1,p0 + // * w = mu*z1 + // =========== + // z3,z2, 0 + let w = self.mu.wrapping_mul(zz[1] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + hi + cc; + zz[3] = lo64(result); + cc = hi64(result); + + // z = (z3,z2) + let prod = zz[2] | (zz[3] << 64); + + // Final subtraction + // If z >= p, then z = z - p + + // 0, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = prod.overflowing_sub(self.p); + let (_s1, b1) = (cc as u128).overflowing_sub(b0 as u128); + // if b1 == 1: return z + // else: return s0 + let mask = 0u128.wrapping_sub(b1 as u128); + (prod & mask) | (s0 & !mask) + } + + /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the + /// runtime of this algorithm is linear in the bit length of `exp`. + pub fn pow(&self, x: u128, exp: u128) -> u128 { + let mut t = self.montgomery(1); + for i in (0..128 - exp.leading_zeros()).rev() { + t = self.mul(t, t); + if (exp >> i) & 1 != 0 { + t = self.mul(t, x); + } + } + t + } + + /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of + /// this algorithm is linear in the bit length of `p`. + pub fn inv(&self, x: u128) -> u128 { + self.pow(x, self.p - 2) + } + + /// Negation, i.e., `-x (mod p)` where `p` is the modulus. + pub fn neg(&self, x: u128) -> u128 { + self.sub(0, x) + } + + /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery + /// domain in order to carry out field arithmetic. The result will be in [0, p). + /// + /// # Example usage + /// ```text + /// let integer = 1; // Standard integer representation + /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain + /// assert_eq!(elem, 2564090464); + /// ``` + pub fn montgomery(&self, x: u128) -> u128 { + modp(self.mul(x, self.r2), self.p) + } + + /// Returns a random field element mapped. + #[cfg(test)] + pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 { + let uniform = rand::distributions::Uniform::from(0..self.p); + self.montgomery(uniform.sample(rng)) + } + + /// Maps a field element to its representation as an integer. The result will be in [0, p). + /// + /// #Example usage + /// ```text + /// let elem = 2564090464; // Internal representation in the Montgomery domain + /// let integer = fp.residue(elem); // Standard integer representation + /// assert_eq!(integer, 1); + /// ``` + pub fn residue(&self, x: u128) -> u128 { + modp(self.mul(x, 1), self.p) + } + + #[cfg(test)] + pub fn check(&self, p: u128, g: u128, order: u128) { + use modinverse::modinverse; + use num_bigint::{BigInt, ToBigInt}; + use std::cmp::max; + + assert_eq!(self.p, p, "p mismatch"); + + let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) { + Some(mu) => mu as u64, + None => panic!("inverse of -p (mod 2^64) is undefined"), + }; + assert_eq!(self.mu, mu, "mu mismatch"); + + let big_p = &p.to_bigint().unwrap(); + let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p); + let big_r2: &BigInt = &(&(big_r * big_r) % big_p); + let mut it = big_r2.iter_u64_digits(); + let mut r2 = 0; + r2 |= it.next().unwrap() as u128; + if let Some(x) = it.next() { + r2 |= (x as u128) << 64; + } + assert_eq!(self.r2, r2, "r2 mismatch"); + + assert_eq!(self.g, self.montgomery(g), "g mismatch"); + assert_eq!( + self.residue(self.pow(self.g, order)), + 1, + "g order incorrect" + ); + + let num_roots = log2(order) as usize; + assert_eq!(order, 1 << num_roots, "order not a power of 2"); + assert_eq!(self.num_roots, num_roots, "num_roots mismatch"); + + let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1]; + roots[num_roots] = self.montgomery(g); + for i in (0..num_roots).rev() { + roots[i] = self.mul(roots[i + 1], roots[i + 1]); + } + assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch"); + assert_eq!(self.residue(self.roots[0]), 1, "first root is not one"); + + let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); + assert_eq!( + self.bit_mask.to_bigint().unwrap(), + bit_mask, + "bit_mask mismatch" + ); + } +} + +fn lo64(x: u128) -> u128 { + x & ((1 << 64) - 1) +} + +fn hi64(x: u128) -> u128 { + x >> 64 +} + +fn modp(x: u128, p: u128) -> u128 { + let (z, carry) = x.overflowing_sub(p); + let m = 0u128.wrapping_sub(carry as u128); + z.wrapping_add(m & p) +} + +pub(crate) const FP32: FieldParameters = FieldParameters { + p: 4293918721, // 32-bit prime + mu: 17302828673139736575, + r2: 1676699750, + g: 1074114499, + num_roots: 20, + bit_mask: 4294967295, + roots: [ + 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825, + 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415, + 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499, + ], +}; + +pub(crate) const FP64: FieldParameters = FieldParameters { + p: 18446744069414584321, // 64-bit prime + mu: 18446744069414584319, + r2: 4294967295, + g: 959634606461954525, + num_roots: 32, + bit_mask: 18446744073709551615, + roots: [ + 18446744065119617025, + 4294967296, + 18446462594437939201, + 72057594037927936, + 1152921504338411520, + 16384, + 18446743519658770561, + 18446735273187346433, + 6519596376689022014, + 9996039020351967275, + 15452408553935940313, + 15855629130643256449, + 8619522106083987867, + 13036116919365988132, + 1033106119984023956, + 16593078884869787648, + 16980581328500004402, + 12245796497946355434, + 8709441440702798460, + 8611358103550827629, + 8120528636261052110, + ], +}; + +pub(crate) const FP96: FieldParameters = FieldParameters { + p: 79228148845226978974766202881, // 96-bit prime + mu: 18446744073709551615, + r2: 69162923446439011319006025217, + g: 11329412859948499305522312170, + num_roots: 64, + bit_mask: 79228162514264337593543950335, + roots: [ + 10128756682736510015896859, + 79218020088544242464750306022, + 9188608122889034248261485869, + 10170869429050723924726258983, + 36379376833245035199462139324, + 20898601228930800484072244511, + 2845758484723985721473442509, + 71302585629145191158180162028, + 76552499132904394167108068662, + 48651998692455360626769616967, + 36570983454832589044179852640, + 72716740645782532591407744342, + 73296872548531908678227377531, + 14831293153408122430659535205, + 61540280632476003580389854060, + 42256269782069635955059793151, + 51673352890110285959979141934, + 43102967204983216507957944322, + 3990455111079735553382399289, + 68042997008257313116433801954, + 44344622755749285146379045633, + ], +}; + +pub(crate) const FP128: FieldParameters = FieldParameters { + p: 340282366920938462946865773367900766209, // 128-bit prime + mu: 18446744073709551615, + r2: 403909908237944342183153, + g: 107630958476043550189608038630704257141, + num_roots: 66, + bit_mask: 340282366920938463463374607431768211455, + roots: [ + 516508834063867445247, + 340282366920938462430356939304033320962, + 129526470195413442198896969089616959958, + 169031622068548287099117778531474117974, + 81612939378432101163303892927894236156, + 122401220764524715189382260548353967708, + 199453575871863981432000940507837456190, + 272368408887745135168960576051472383806, + 24863773656265022616993900367764287617, + 257882853788779266319541142124730662203, + 323732363244658673145040701829006542956, + 57532865270871759635014308631881743007, + 149571414409418047452773959687184934208, + 177018931070866797456844925926211239962, + 268896136799800963964749917185333891349, + 244556960591856046954834420512544511831, + 118945432085812380213390062516065622346, + 202007153998709986841225284843501908420, + 332677126194796691532164818746739771387, + 258279638927684931537542082169183965856, + 148221243758794364405224645520862378432, + ], +}; + +// Compute the ceiling of the base-2 logarithm of `x`. +pub(crate) fn log2(x: u128) -> u128 { + let y = (127 - x.leading_zeros()) as u128; + y + ((x > 1 << y) as u128) +} + +#[cfg(test)] +mod tests { + use super::*; + use num_bigint::ToBigInt; + + #[test] + fn test_log2() { + assert_eq!(log2(1), 0); + assert_eq!(log2(2), 1); + assert_eq!(log2(3), 2); + assert_eq!(log2(4), 2); + assert_eq!(log2(15), 4); + assert_eq!(log2(16), 4); + assert_eq!(log2(30), 5); + assert_eq!(log2(32), 5); + assert_eq!(log2(1 << 127), 127); + assert_eq!(log2((1 << 127) + 13), 128); + } + + struct TestFieldParametersData { + fp: FieldParameters, // The paramters being tested + expected_p: u128, // Expected fp.p + expected_g: u128, // Expected fp.residue(fp.g) + expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1 + } + + #[test] + fn test_fp() { + let test_fps = vec![ + TestFieldParametersData { + fp: FP32, + expected_p: 4293918721, + expected_g: 3925978153, + expected_order: 1 << 20, + }, + TestFieldParametersData { + fp: FP64, + expected_p: 18446744069414584321, + expected_g: 1753635133440165772, + expected_order: 1 << 32, + }, + TestFieldParametersData { + fp: FP96, + expected_p: 79228148845226978974766202881, + expected_g: 34233996298771126927060021012, + expected_order: 1 << 64, + }, + TestFieldParametersData { + fp: FP128, + expected_p: 340282366920938462946865773367900766209, + expected_g: 145091266659756586618791329697897684742, + expected_order: 1 << 66, + }, + ]; + + for t in test_fps.into_iter() { + // Check that the field parameters have been constructed properly. + t.fp.check(t.expected_p, t.expected_g, t.expected_order); + + // Check that the generator has the correct order. + assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1); + assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1); + + // Test arithmetic using the field parameters. + arithmetic_test(&t.fp); + } + } + + fn arithmetic_test(fp: &FieldParameters) { + let mut rng = rand::thread_rng(); + let big_p = &fp.p.to_bigint().unwrap(); + + for _ in 0..100 { + let x = fp.rand_elem(&mut rng); + let y = fp.rand_elem(&mut rng); + let big_x = &fp.residue(x).to_bigint().unwrap(); + let big_y = &fp.residue(y).to_bigint().unwrap(); + + // Test addition. + let got = fp.add(x, y); + let want = (big_x + big_y) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test subtraction. + let got = fp.sub(x, y); + let want = if big_x >= big_y { + big_x - big_y + } else { + big_p - big_y + big_x + }; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test multiplication. + let got = fp.mul(x, y); + let want = (big_x * big_y) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test inversion. + let got = fp.inv(x); + let want = big_x.modpow(&(big_p - 2u128), big_p); + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + assert_eq!(fp.residue(fp.mul(got, x)), 1); + + // Test negation. + let got = fp.neg(x); + let want = (big_p - big_x) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + assert_eq!(fp.residue(fp.add(got, x)), 0); + } + } +} diff --git a/third_party/rust/prio/src/lib.rs b/third_party/rust/prio/src/lib.rs new file mode 100644 index 0000000000..f0e00de3c0 --- /dev/null +++ b/third_party/rust/prio/src/lib.rs @@ -0,0 +1,33 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Libprio-rs +//! +//! Implementation of the [Prio](https://crypto.stanford.edu/prio/) private data aggregation +//! protocol. For now we only support 0 / 1 vectors. + +pub mod benchmarked; +#[cfg(feature = "prio2")] +pub mod client; +#[cfg(feature = "prio2")] +pub mod encrypt; +#[cfg(feature = "prio2")] +pub mod server; + +pub mod codec; +mod fft; +pub mod field; +pub mod flp; +mod fp; +mod polynomial; +mod prng; +// Module test_vector depends on crate `rand` so we make it an optional feature +// to spare most clients the extra dependency. +#[cfg(all(any(feature = "test-util", test), feature = "prio2"))] +pub mod test_vector; +#[cfg(feature = "prio2")] +pub mod util; +pub mod vdaf; diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs new file mode 100644 index 0000000000..7c38341e36 --- /dev/null +++ b/third_party/rust/prio/src/polynomial.rs @@ -0,0 +1,384 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Functions for polynomial interpolation and evaluation + +use crate::field::FieldElement; + +use std::convert::TryFrom; + +/// Temporary memory used for FFT +#[derive(Clone, Debug)] +pub struct PolyFFTTempMemory<F> { + fft_tmp: Vec<F>, + fft_y_sub: Vec<F>, + fft_roots_sub: Vec<F>, +} + +impl<F: FieldElement> PolyFFTTempMemory<F> { + fn new(length: usize) -> Self { + PolyFFTTempMemory { + fft_tmp: vec![F::zero(); length], + fft_y_sub: vec![F::zero(); length], + fft_roots_sub: vec![F::zero(); length], + } + } +} + +/// Auxiliary memory for polynomial interpolation and evaluation +#[derive(Clone, Debug)] +pub struct PolyAuxMemory<F> { + pub roots_2n: Vec<F>, + pub roots_2n_inverted: Vec<F>, + pub roots_n: Vec<F>, + pub roots_n_inverted: Vec<F>, + pub coeffs: Vec<F>, + pub fft_memory: PolyFFTTempMemory<F>, +} + +impl<F: FieldElement> PolyAuxMemory<F> { + pub fn new(n: usize) -> Self { + PolyAuxMemory { + roots_2n: fft_get_roots(2 * n, false), + roots_2n_inverted: fft_get_roots(2 * n, true), + roots_n: fft_get_roots(n, false), + roots_n_inverted: fft_get_roots(n, true), + coeffs: vec![F::zero(); 2 * n], + fft_memory: PolyFFTTempMemory::new(2 * n), + } + } +} + +fn fft_recurse<F: FieldElement>( + out: &mut [F], + n: usize, + roots: &[F], + ys: &[F], + tmp: &mut [F], + y_sub: &mut [F], + roots_sub: &mut [F], +) { + if n == 1 { + out[0] = ys[0]; + return; + } + + let half_n = n / 2; + + let (tmp_first, tmp_second) = tmp.split_at_mut(half_n); + let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n); + let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n); + + // Recurse on the first half + for i in 0..half_n { + y_sub_first[i] = ys[i] + ys[i + half_n]; + roots_sub_first[i] = roots[2 * i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i] = tmp_first[i]; + } + + // Recurse on the second half + for i in 0..half_n { + y_sub_first[i] = ys[i] - ys[i + half_n]; + y_sub_first[i] *= roots[i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i + 1] = tmp[i]; + } +} + +/// Calculate `count` number of roots of unity of order `count` +fn fft_get_roots<F: FieldElement>(count: usize, invert: bool) -> Vec<F> { + let mut roots = vec![F::zero(); count]; + let mut gen = F::generator(); + if invert { + gen = gen.inv(); + } + + roots[0] = F::one(); + let step_size = F::generator_order() / F::Integer::try_from(count).unwrap(); + // generator for subgroup of order count + gen = gen.pow(step_size); + + roots[1] = gen; + + for i in 2..count { + roots[i] = gen * roots[i - 1]; + } + + roots +} + +fn fft_interpolate_raw<F: FieldElement>( + out: &mut [F], + ys: &[F], + n_points: usize, + roots: &[F], + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_recurse( + out, + n_points, + roots, + ys, + &mut mem.fft_tmp, + &mut mem.fft_y_sub, + &mut mem.fft_roots_sub, + ); + if invert { + let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); + #[allow(clippy::needless_range_loop)] + for i in 0..n_points { + out[i] *= n_inverse; + } + } +} + +pub fn poly_fft<F: FieldElement>( + points_out: &mut [F], + points_in: &[F], + scaled_roots: &[F], + n_points: usize, + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) +} + +// Evaluate a polynomial using Horner's method. +pub fn poly_eval<F: FieldElement>(poly: &[F], eval_at: F) -> F { + if poly.is_empty() { + return F::zero(); + } + + let mut result = poly[poly.len() - 1]; + for i in (0..poly.len() - 1).rev() { + result *= eval_at; + result += poly[i]; + } + + result +} + +// Returns the degree of polynomial `p`. +pub fn poly_deg<F: FieldElement>(p: &[F]) -> usize { + let mut d = p.len(); + while d > 0 && p[d - 1] == F::zero() { + d -= 1; + } + d.saturating_sub(1) +} + +// Multiplies polynomials `p` and `q` and returns the result. +pub fn poly_mul<F: FieldElement>(p: &[F], q: &[F]) -> Vec<F> { + let p_size = poly_deg(p) + 1; + let q_size = poly_deg(q) + 1; + let mut out = vec![F::zero(); p_size + q_size]; + for i in 0..p_size { + for j in 0..q_size { + out[i + j] += p[i] * q[j]; + } + } + out.truncate(poly_deg(&out) + 1); + out +} + +#[cfg(feature = "prio2")] +pub fn poly_interpret_eval<F: FieldElement>( + points: &[F], + roots: &[F], + eval_at: F, + tmp_coeffs: &mut [F], + fft_memory: &mut PolyFFTTempMemory<F>, +) -> F { + poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory); + poly_eval(&tmp_coeffs[..points.len()], eval_at) +} + +// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise, +// the output is not `0`. +pub(crate) fn poly_range_check<F: FieldElement>(start: usize, end: usize) -> Vec<F> { + let mut p = vec![F::one()]; + let mut q = [F::zero(), F::one()]; + for i in start..end { + q[0] = -F::from(F::Integer::try_from(i).unwrap()); + p = poly_mul(&p, &q); + } + p +} + +#[test] +fn test_roots() { + use crate::field::Field32; + + let count = 128; + let roots = fft_get_roots::<Field32>(count, false); + let roots_inv = fft_get_roots::<Field32>(count, true); + + for i in 0..count { + assert_eq!(roots[i] * roots_inv[i], 1); + assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1); + assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1); + } +} + +#[test] +fn test_eval() { + use crate::field::Field32; + + let mut poly = vec![Field32::from(0); 4]; + poly[0] = 2.into(); + poly[1] = 1.into(); + poly[2] = 5.into(); + // 5*3^2 + 3 + 2 = 50 + assert_eq!(poly_eval(&poly[..3], 3.into()), 50); + poly[3] = 4.into(); + // 4*3^3 + 5*3^2 + 3 + 2 = 158 + assert_eq!(poly_eval(&poly[..4], 3.into()), 158); +} + +#[test] +fn test_poly_deg() { + use crate::field::Field32; + + let zero = Field32::zero(); + let one = Field32::root(0).unwrap(); + assert_eq!(poly_deg(&[zero]), 0); + assert_eq!(poly_deg(&[one]), 0); + assert_eq!(poly_deg(&[zero, one]), 1); + assert_eq!(poly_deg(&[zero, zero, one]), 2); + assert_eq!(poly_deg(&[zero, one, one]), 2); + assert_eq!(poly_deg(&[zero, one, one, one]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3); +} + +#[test] +fn test_poly_mul() { + use crate::field::Field64; + + let p = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + ]; + + let q = [ + Field64::one(), + Field64::zero(), + Field64::from(u64::try_from(5).unwrap()), + ]; + + let want = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + Field64::from(u64::try_from(10).unwrap()), + Field64::from(u64::try_from(15).unwrap()), + ]; + + let got = poly_mul(&p, &q); + assert_eq!(&got, &want); +} + +#[test] +fn test_poly_range_check() { + use crate::field::Field64; + + let start = 74; + let end = 112; + let p = poly_range_check(start, end); + + // Check each number in the range. + for i in start..end { + let x = Field64::from(i as u64); + let y = poly_eval(&p, x); + assert_eq!(y, Field64::zero(), "range check failed for {}", i); + } + + // Check the number below the range. + let x = Field64::from((start - 1) as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); + + // Check a number above the range. + let x = Field64::from(end as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); +} + +#[test] +fn test_fft() { + use crate::field::Field32; + + use rand::prelude::*; + use std::convert::TryFrom; + + let count = 128; + let mut mem = PolyAuxMemory::new(count / 2); + + let mut poly = vec![Field32::from(0); count]; + let mut points2 = vec![Field32::from(0); count]; + + let points = (0..count) + .into_iter() + .map(|_| Field32::from(random::<u32>())) + .collect::<Vec<Field32>>(); + + // From points to coeffs and back + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + poly_fft( + &mut points2, + &poly, + &mem.roots_2n_inverted, + count, + true, + &mut mem.fft_memory, + ); + + assert_eq!(points, points2); + + // interpolation + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + + #[allow(clippy::needless_range_loop)] + for i in 0..count { + let mut should_be = Field32::from(0); + for j in 0..count { + should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be; + } + assert_eq!(should_be, poly[i]); + } +} diff --git a/third_party/rust/prio/src/prng.rs b/third_party/rust/prio/src/prng.rs new file mode 100644 index 0000000000..764cd7b025 --- /dev/null +++ b/third_party/rust/prio/src/prng.rs @@ -0,0 +1,208 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Tool for generating pseudorandom field elements. +//! +//! NOTE: The public API for this module is a work in progress. + +use crate::field::{FieldElement, FieldError}; +use crate::vdaf::prg::SeedStream; +#[cfg(feature = "crypto-dependencies")] +use crate::vdaf::prg::SeedStreamAes128; +#[cfg(feature = "crypto-dependencies")] +use getrandom::getrandom; + +use std::marker::PhantomData; + +const BUFFER_SIZE_IN_ELEMENTS: usize = 128; + +/// Errors propagated by methods in this module. +#[derive(Debug, thiserror::Error)] +pub enum PrngError { + /// Failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), +} + +/// This type implements an iterator that generates a pseudorandom sequence of field elements. The +/// sequence is derived from the key stream of AES-128 in CTR mode with a random IV. +#[derive(Debug)] +pub(crate) struct Prng<F, S> { + phantom: PhantomData<F>, + seed_stream: S, + buffer: Vec<u8>, + buffer_index: usize, + output_written: usize, +} + +#[cfg(feature = "crypto-dependencies")] +impl<F: FieldElement> Prng<F, SeedStreamAes128> { + /// Create a [`Prng`] from a seed for Prio 2. The first 16 bytes of the seed and the last 16 + /// bytes of the seed are used, respectively, for the key and initialization vector for AES128 + /// in CTR mode. + pub(crate) fn from_prio2_seed(seed: &[u8; 32]) -> Self { + let seed_stream = SeedStreamAes128::new(&seed[..16], &seed[16..]); + Self::from_seed_stream(seed_stream) + } + + /// Create a [`Prng`] from a randomly generated seed. + pub(crate) fn new() -> Result<Self, PrngError> { + let mut seed = [0; 32]; + getrandom(&mut seed)?; + Ok(Self::from_prio2_seed(&seed)) + } +} + +impl<F, S> Prng<F, S> +where + F: FieldElement, + S: SeedStream, +{ + pub(crate) fn from_seed_stream(mut seed_stream: S) -> Self { + let mut buffer = vec![0; BUFFER_SIZE_IN_ELEMENTS * F::ENCODED_SIZE]; + seed_stream.fill(&mut buffer); + + Self { + phantom: PhantomData::<F>, + seed_stream, + buffer, + buffer_index: 0, + output_written: 0, + } + } + + pub(crate) fn get(&mut self) -> F { + loop { + // Seek to the next chunk of the buffer that encodes an element of F. + for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) { + let j = i + F::ENCODED_SIZE; + if let Some(x) = match F::try_from_random(&self.buffer[i..j]) { + Ok(x) => Some(x), + Err(FieldError::ModulusOverflow) => None, // reject this sample + Err(err) => panic!("unexpected error: {}", err), + } { + // Set the buffer index to the next chunk. + self.buffer_index = j; + self.output_written += 1; + return x; + } + } + + // Refresh buffer with the next chunk of PRG output. + self.seed_stream.fill(&mut self.buffer); + self.buffer_index = 0; + } + } +} + +impl<F, S> Iterator for Prng<F, S> +where + F: FieldElement, + S: SeedStream, +{ + type Item = F; + + fn next(&mut self) -> Option<F> { + Some(self.get()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + codec::Decode, + field::{Field96, FieldPrio2}, + vdaf::prg::{Prg, PrgAes128, Seed}, + }; + use std::convert::TryInto; + + #[test] + fn secret_sharing_interop() { + let seed = [ + 0xcd, 0x85, 0x5b, 0xd4, 0x86, 0x48, 0xa4, 0xce, 0x52, 0x5c, 0x36, 0xee, 0x5a, 0x71, + 0xf3, 0x0f, 0x66, 0x80, 0xd3, 0x67, 0x53, 0x9a, 0x39, 0x6f, 0x12, 0x2f, 0xad, 0x94, + 0x4d, 0x34, 0xcb, 0x58, + ]; + + let reference = [ + 0xd0056ec5, 0xe23f9c52, 0x47e4ddb4, 0xbe5dacf6, 0x4b130aba, 0x530c7a90, 0xe8fc4ee5, + 0xb0569cb7, 0x7774cd3c, 0x7f24e6a5, 0xcc82355d, 0xc41f4f13, 0x67fe193c, 0xc94d63a4, + 0x5d7b474c, 0xcc5c9f5f, 0xe368e1d5, 0x020fa0cf, 0x9e96aa2a, 0xe924137d, 0xfa026ab9, + 0x8ebca0cc, 0x26fc58a5, 0x10a7b173, 0xb9c97291, 0x53ef0e28, 0x069cfb8e, 0xe9383cae, + 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58, + ]; + + let share2 = extract_share_from_seed::<FieldPrio2>(reference.len(), &seed); + + assert_eq!(share2, reference); + } + + /// takes a seed and hash as base64 encoded strings + #[cfg(feature = "prio2")] + fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) { + let seed = base64::decode(seed_base64).unwrap(); + let random_data = extract_share_from_seed::<FieldPrio2>(len, &seed); + + let random_bytes = FieldPrio2::slice_into_byte_vec(&random_data); + + let digest = ring::digest::digest(&ring::digest::SHA256, &random_bytes); + assert_eq!(base64::encode(digest), hash_base64); + } + + #[test] + #[cfg(feature = "prio2")] + fn test_hash_interop() { + random_data_interop( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", + "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=", + 100_000, + ); + + // zero seed + random_data_interop( + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "3wHQbSwAn9GPfoNkKe1qSzWdKnu/R+hPPyRwwz6Di+w=", + 100_000, + ); + // 0, 1, 2 ... seed + random_data_interop( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", + "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=", + 100_000, + ); + // one arbirtary fixed seed + random_data_interop( + "rkLrnVcU8ULaiuXTvR3OKrfpMX0kQidqVzta1pleKKg=", + "b1fMXYrGUNR3wOZ/7vmUMmY51QHoPDBzwok0fz6xC0I=", + 100_000, + ); + // all bits set seed + random_data_interop( + "//////////////////////////////////////////8=", + "iBiDaqLrv7/rX/+vs6akPiprGgYfULdh/XhoD61HQXA=", + 100_000, + ); + } + + fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> { + assert_eq!(seed.len(), 32); + Prng::from_prio2_seed(seed.try_into().unwrap()) + .take(length) + .collect() + } + + #[test] + fn rejection_sampling_test_vector() { + // These constants were found in a brute-force search, and they test that the PRG performs + // rejection sampling correctly when raw AES-CTR output exceeds the prime modulus. + let seed_stream = PrgAes128::seed_stream( + &Seed::get_decoded(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 95]).unwrap(), + b"", + ); + let mut prng = Prng::<Field96, _>::from_seed_stream(seed_stream); + let expected = Field96::from(39729620190871453347343769187); + let actual = prng.nth(145).unwrap(); + assert_eq!(actual, expected); + } +} diff --git a/third_party/rust/prio/src/server.rs b/third_party/rust/prio/src/server.rs new file mode 100644 index 0000000000..01f309e797 --- /dev/null +++ b/third_party/rust/prio/src/server.rs @@ -0,0 +1,469 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! The Prio v2 server. Only 0 / 1 vectors are supported for now. +use crate::{ + encrypt::{decrypt_share, EncryptError, PrivateKey}, + field::{merge_vector, FieldElement, FieldError}, + polynomial::{poly_interpret_eval, PolyAuxMemory}, + prng::{Prng, PrngError}, + util::{proof_length, unpack_proof, SerializeError}, + vdaf::prg::SeedStreamAes128, +}; +use serde::{Deserialize, Serialize}; +use std::convert::TryInto; + +/// Possible errors from server operations +#[derive(Debug, thiserror::Error)] +pub enum ServerError { + /// Unexpected Share Length + #[error("unexpected share length")] + ShareLength, + /// Encryption/decryption error + #[error("encryption/decryption error")] + Encrypt(#[from] EncryptError), + /// Finite field operation error + #[error("finite field operation error")] + Field(#[from] FieldError), + /// Serialization/deserialization error + #[error("serialization/deserialization error")] + Serialize(#[from] SerializeError), + /// Failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), + /// PRNG error. + #[error("prng error: {0}")] + Prng(#[from] PrngError), +} + +/// Auxiliary memory for constructing a +/// [`VerificationMessage`](struct.VerificationMessage.html) +#[derive(Debug)] +pub struct ValidationMemory<F> { + points_f: Vec<F>, + points_g: Vec<F>, + points_h: Vec<F>, + poly_mem: PolyAuxMemory<F>, +} + +impl<F: FieldElement> ValidationMemory<F> { + /// Construct a new ValidationMemory object for validating proof shares of + /// length `dimension`. + pub fn new(dimension: usize) -> Self { + let n: usize = (dimension + 1).next_power_of_two(); + ValidationMemory { + points_f: vec![F::zero(); n], + points_g: vec![F::zero(); n], + points_h: vec![F::zero(); 2 * n], + poly_mem: PolyAuxMemory::new(n), + } + } +} + +/// Main workhorse of the server. +#[derive(Debug)] +pub struct Server<F> { + prng: Prng<F, SeedStreamAes128>, + dimension: usize, + is_first_server: bool, + accumulator: Vec<F>, + validation_mem: ValidationMemory<F>, + private_key: PrivateKey, +} + +impl<F: FieldElement> Server<F> { + /// Construct a new server instance + /// + /// Params: + /// * `dimension`: the number of elements in the aggregation vector. + /// * `is_first_server`: only one of the servers should have this true. + /// * `private_key`: the private key for decrypting the share of the proof. + pub fn new( + dimension: usize, + is_first_server: bool, + private_key: PrivateKey, + ) -> Result<Server<F>, ServerError> { + Ok(Server { + prng: Prng::new()?, + dimension, + is_first_server, + accumulator: vec![F::zero(); dimension], + validation_mem: ValidationMemory::new(dimension), + private_key, + }) + } + + /// Decrypt and deserialize + fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> { + let len = proof_length(self.dimension); + let share = decrypt_share(encrypted_share, &self.private_key)?; + Ok(if self.is_first_server { + F::byte_slice_into_vec(&share)? + } else { + if share.len() != 32 { + return Err(ServerError::ShareLength); + } + + Prng::from_prio2_seed(&share.try_into().unwrap()) + .take(len) + .collect() + }) + } + + /// Generate verification message from an encrypted share + /// + /// This decrypts the share of the proof and constructs the + /// [`VerificationMessage`](struct.VerificationMessage.html). + /// The `eval_at` field should be generate by + /// [choose_eval_at](#method.choose_eval_at). + pub fn generate_verification_message( + &mut self, + eval_at: F, + share: &[u8], + ) -> Result<VerificationMessage<F>, ServerError> { + let share_field = self.deserialize_share(share)?; + generate_verification_message( + self.dimension, + eval_at, + &share_field, + self.is_first_server, + &mut self.validation_mem, + ) + } + + /// Add the content of the encrypted share into the accumulator + /// + /// This only changes the accumulator if the verification messages `v1` and + /// `v2` indicate that the share passed validation. + pub fn aggregate( + &mut self, + share: &[u8], + v1: &VerificationMessage<F>, + v2: &VerificationMessage<F>, + ) -> Result<bool, ServerError> { + let share_field = self.deserialize_share(share)?; + let is_valid = is_valid_share(v1, v2); + if is_valid { + // Add to the accumulator. share_field also includes the proof + // encoding, so we slice off the first dimension fields, which are + // the actual data share. + merge_vector(&mut self.accumulator, &share_field[..self.dimension])?; + } + + Ok(is_valid) + } + + /// Return the current accumulated shares. + /// + /// These can be merged together using + /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html). + pub fn total_shares(&self) -> &[F] { + &self.accumulator + } + + /// Merge shares from another server. + /// + /// This modifies the current accumulator. + /// + /// # Errors + /// + /// Returns an error if `other_total_shares.len()` is not equal to this + //// server's `dimension`. + pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> { + Ok(merge_vector(&mut self.accumulator, other_total_shares)?) + } + + /// Choose a random point for polynomial evaluation + /// + /// The point returned is not one of the roots used for polynomial + /// evaluation. + pub fn choose_eval_at(&mut self) -> F { + loop { + let eval_at = self.prng.get(); + if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) { + break eval_at; + } + } + } +} + +/// Verification message for proof validation +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VerificationMessage<F> { + /// f evaluated at random point + pub f_r: F, + /// g evaluated at random point + pub g_r: F, + /// h evaluated at random point + pub h_r: F, +} + +/// Given a proof and evaluation point, this constructs the verification +/// message. +pub fn generate_verification_message<F: FieldElement>( + dimension: usize, + eval_at: F, + proof: &[F], + is_first_server: bool, + mem: &mut ValidationMemory<F>, +) -> Result<VerificationMessage<F>, ServerError> { + let unpacked = unpack_proof(proof, dimension)?; + let proof_length = 2 * (dimension + 1).next_power_of_two(); + + // set zero terms + mem.points_f[0] = *unpacked.f0; + mem.points_g[0] = *unpacked.g0; + mem.points_h[0] = *unpacked.h0; + + // set points_f and points_g + for (i, x) in unpacked.data.iter().enumerate() { + mem.points_f[i + 1] = *x; + + if is_first_server { + // only one server needs to subtract one for point_g + mem.points_g[i + 1] = *x - F::one(); + } else { + mem.points_g[i + 1] = *x; + } + } + + // set points_h, skipping over elements that should be zero + let mut i = 1; + let mut j = 0; + while i < proof_length { + mem.points_h[i] = unpacked.points_h_packed[j]; + j += 1; + i += 2; + } + + // evaluate polynomials at random point + let f_r = poly_interpret_eval( + &mem.points_f, + &mem.poly_mem.roots_n_inverted, + eval_at, + &mut mem.poly_mem.coeffs, + &mut mem.poly_mem.fft_memory, + ); + let g_r = poly_interpret_eval( + &mem.points_g, + &mem.poly_mem.roots_n_inverted, + eval_at, + &mut mem.poly_mem.coeffs, + &mut mem.poly_mem.fft_memory, + ); + let h_r = poly_interpret_eval( + &mem.points_h, + &mem.poly_mem.roots_2n_inverted, + eval_at, + &mut mem.poly_mem.coeffs, + &mut mem.poly_mem.fft_memory, + ); + + Ok(VerificationMessage { f_r, g_r, h_r }) +} + +/// Decides if the distributed proof is valid +pub fn is_valid_share<F: FieldElement>( + v1: &VerificationMessage<F>, + v2: &VerificationMessage<F>, +) -> bool { + // reconstruct f_r, g_r, h_r + let f_r = v1.f_r + v2.f_r; + let g_r = v1.g_r + v2.g_r; + let h_r = v1.h_r + v2.h_r; + // validity check + f_r * g_r == h_r +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + encrypt::{encrypt_share, PublicKey}, + field::{Field32, FieldPrio2}, + test_vector::Priov2TestVector, + util::{self, unpack_proof_mut}, + }; + use serde_json; + + #[test] + fn test_validation() { + let dim = 8; + let proof_u32: Vec<u32> = vec![ + 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722, + 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680, + 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, + ]; + + let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect(); + let share2 = util::tests::secret_share(&mut proof); + let eval_at = Field32::from(12313); + + let mut validation_mem = ValidationMemory::new(dim); + + let v1 = + generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap(); + let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem) + .unwrap(); + assert!(is_valid_share(&v1, &v2)); + } + + #[test] + fn test_verification_message_serde() { + let dim = 8; + let proof_u32: Vec<u32> = vec![ + 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722, + 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680, + 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, + ]; + + let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect(); + let share2 = util::tests::secret_share(&mut proof); + let eval_at = Field32::from(12313); + + let mut validation_mem = ValidationMemory::new(dim); + + let v1 = + generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap(); + let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem) + .unwrap(); + + // serialize and deserialize the first verification message + let serialized = serde_json::to_string(&v1).unwrap(); + let deserialized: VerificationMessage<Field32> = serde_json::from_str(&serialized).unwrap(); + + assert!(is_valid_share(&deserialized, &v2)); + } + + #[derive(Debug, Clone, Copy, PartialEq)] + enum Tweak { + None, + WrongInput, + DataPartOfShare, + ZeroTermF, + ZeroTermG, + ZeroTermH, + PointsH, + VerificationF, + VerificationG, + VerificationH, + } + + fn tweaks(tweak: Tweak) { + let dim = 123; + + // We generate a test vector just to get a `Client` and `Server`s with + // encryption keys but construct and tweak inputs below. + let test_vector = Priov2TestVector::new(dim, 0).unwrap(); + let mut server1 = test_vector.server_1().unwrap(); + let mut server2 = test_vector.server_2().unwrap(); + let mut client = test_vector.client().unwrap(); + + // all zero data + let mut data = vec![FieldPrio2::zero(); dim]; + + if let Tweak::WrongInput = tweak { + data[0] = FieldPrio2::from(2); + } + + let (share1_original, share2) = client.encode_simple(&data).unwrap(); + + let decrypted_share1 = decrypt_share(&share1_original, &server1.private_key).unwrap(); + let mut share1_field = FieldPrio2::byte_slice_into_vec(&decrypted_share1).unwrap(); + let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap(); + + let one = FieldPrio2::from(1); + + match tweak { + Tweak::DataPartOfShare => unpacked_share1.data[0] += one, + Tweak::ZeroTermF => *unpacked_share1.f0 += one, + Tweak::ZeroTermG => *unpacked_share1.g0 += one, + Tweak::ZeroTermH => *unpacked_share1.h0 += one, + Tweak::PointsH => unpacked_share1.points_h_packed[0] += one, + _ => (), + }; + + // reserialize altered share1 + let share1_modified = encrypt_share( + &FieldPrio2::slice_into_byte_vec(&share1_field), + &PublicKey::from(&server1.private_key), + ) + .unwrap(); + + let eval_at = server1.choose_eval_at(); + + let mut v1 = server1 + .generate_verification_message(eval_at, &share1_modified) + .unwrap(); + let v2 = server2 + .generate_verification_message(eval_at, &share2) + .unwrap(); + + match tweak { + Tweak::VerificationF => v1.f_r += one, + Tweak::VerificationG => v1.g_r += one, + Tweak::VerificationH => v1.h_r += one, + _ => (), + } + + let should_be_valid = matches!(tweak, Tweak::None); + assert_eq!( + server1.aggregate(&share1_modified, &v1, &v2).unwrap(), + should_be_valid + ); + assert_eq!( + server2.aggregate(&share2, &v1, &v2).unwrap(), + should_be_valid + ); + } + + #[test] + fn tweak_none() { + tweaks(Tweak::None); + } + + #[test] + fn tweak_input() { + tweaks(Tweak::WrongInput); + } + + #[test] + fn tweak_data() { + tweaks(Tweak::DataPartOfShare); + } + + #[test] + fn tweak_f_zero() { + tweaks(Tweak::ZeroTermF); + } + + #[test] + fn tweak_g_zero() { + tweaks(Tweak::ZeroTermG); + } + + #[test] + fn tweak_h_zero() { + tweaks(Tweak::ZeroTermH); + } + + #[test] + fn tweak_h_points() { + tweaks(Tweak::PointsH); + } + + #[test] + fn tweak_f_verif() { + tweaks(Tweak::VerificationF); + } + + #[test] + fn tweak_g_verif() { + tweaks(Tweak::VerificationG); + } + + #[test] + fn tweak_h_verif() { + tweaks(Tweak::VerificationH); + } +} diff --git a/third_party/rust/prio/src/test_vector.rs b/third_party/rust/prio/src/test_vector.rs new file mode 100644 index 0000000000..1306bc26f9 --- /dev/null +++ b/third_party/rust/prio/src/test_vector.rs @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Module `test_vector` generates test vectors of serialized Prio inputs and +//! support for working with test vectors, enabling backward compatibility +//! testing. + +use crate::{ + client::{Client, ClientError}, + encrypt::{PrivateKey, PublicKey}, + field::{FieldElement, FieldPrio2}, + server::{Server, ServerError}, +}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// Errors propagated by functions in this module. +#[derive(Debug, thiserror::Error)] +pub enum TestVectorError { + /// Error from Prio client + #[error("Prio client error {0}")] + Client(#[from] ClientError), + /// Error from Prio server + #[error("Prio server error {0}")] + Server(#[from] ServerError), + /// Error while converting primitive to FieldElement associated integer type + #[error("Integer conversion error {0}")] + IntegerConversion(String), +} + +const SERVER_1_PRIVATE_KEY: &str = + "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBH\ + fNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw=="; +const SERVER_2_PRIVATE_KEY: &str = + "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rD\ + ULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w=="; + +/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate +/// test vectors and later to decrypt them. +fn server_1_private_key() -> PrivateKey { + PrivateKey::from_base64(SERVER_1_PRIVATE_KEY).unwrap() +} + +/// The public portion of [`server_1_private_key`]. +fn server_1_public_key() -> PublicKey { + PublicKey::from(&server_1_private_key()) +} + +/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate +/// test vectors and later to decrypt them. +fn server_2_private_key() -> PrivateKey { + PrivateKey::from_base64(SERVER_2_PRIVATE_KEY).unwrap() +} + +/// The public portion of [`server_2_private_key`]. +fn server_2_public_key() -> PublicKey { + PublicKey::from(&server_2_private_key()) +} + +/// A test vector of Prio inputs, serialized and encrypted in the Priov2 format, +/// along with a reference sum. The field is always [`FieldPrio2`]. +#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct Priov2TestVector { + /// Base64 encoded private key for the "first" a.k.a. "PHA" server, which + /// may be used to decrypt `server_1_shares`. + pub server_1_private_key: String, + /// Base64 encoded private key for the non-"first" a.k.a. "facilitator" + /// server, which may be used to decrypt `server_2_shares`. + pub server_2_private_key: String, + /// Dimension (number of buckets) of the inputs + pub dimension: usize, + /// Encrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA" + /// server. The inner `Vec`s are encrypted bytes. + #[serde( + serialize_with = "base64::serialize_bytes", + deserialize_with = "base64::deserialize_bytes" + )] + pub server_1_shares: Vec<Vec<u8>>, + /// Encrypted share of Priov2 format inputs for the non-"first" a.k.a. + /// "facilitator" server. + #[serde( + serialize_with = "base64::serialize_bytes", + deserialize_with = "base64::deserialize_bytes" + )] + pub server_2_shares: Vec<Vec<u8>>, + /// The sum over the inputs. + #[serde( + serialize_with = "base64::serialize_field", + deserialize_with = "base64::deserialize_field" + )] + pub reference_sum: Vec<FieldPrio2>, + /// The version of the crate that generated this test vector + pub prio_crate_version: String, +} + +impl Priov2TestVector { + /// Construct a test vector of `number_of_clients` inputs, each of which is a + /// `dimension`-dimension vector of random Boolean values encoded as + /// [`FieldPrio2`]. + pub fn new(dimension: usize, number_of_clients: usize) -> Result<Self, TestVectorError> { + let mut client: Client<FieldPrio2> = + Client::new(dimension, server_1_public_key(), server_2_public_key())?; + + let mut reference_sum = vec![FieldPrio2::zero(); dimension]; + let mut server_1_shares = Vec::with_capacity(number_of_clients); + let mut server_2_shares = Vec::with_capacity(number_of_clients); + + let mut rng = rand::thread_rng(); + + for _ in 0..number_of_clients { + // Generate a random vector of booleans + let data: Vec<FieldPrio2> = (0..dimension) + .map(|_| FieldPrio2::from(rng.gen_range(0..2))) + .collect(); + + // Update reference sum + for (r, d) in reference_sum.iter_mut().zip(&data) { + *r += *d; + } + + let (server_1_share, server_2_share) = client.encode_simple(&data)?; + + server_1_shares.push(server_1_share); + server_2_shares.push(server_2_share); + } + + Ok(Self { + server_1_private_key: SERVER_1_PRIVATE_KEY.to_owned(), + server_2_private_key: SERVER_2_PRIVATE_KEY.to_owned(), + dimension, + server_1_shares, + server_2_shares, + reference_sum, + prio_crate_version: env!("CARGO_PKG_VERSION").to_owned(), + }) + } + + /// Construct a [`Client`] that can encrypt input shares to this test + /// vector's servers. + pub fn client(&self) -> Result<Client<FieldPrio2>, TestVectorError> { + Ok(Client::new( + self.dimension, + PublicKey::from(&PrivateKey::from_base64(&self.server_1_private_key).unwrap()), + PublicKey::from(&PrivateKey::from_base64(&self.server_2_private_key).unwrap()), + )?) + } + + /// Construct a [`Server`] that can decrypt `server_1_shares`. + pub fn server_1(&self) -> Result<Server<FieldPrio2>, TestVectorError> { + Ok(Server::new( + self.dimension, + true, + PrivateKey::from_base64(&self.server_1_private_key).unwrap(), + )?) + } + + /// Construct a [`Server`] that can decrypt `server_2_shares`. + pub fn server_2(&self) -> Result<Server<FieldPrio2>, TestVectorError> { + Ok(Server::new( + self.dimension, + false, + PrivateKey::from_base64(&self.server_2_private_key).unwrap(), + )?) + } +} + +mod base64 { + //! Custom serialization module used for some members of struct + //! `Priov2TestVector` so that byte slices are serialized as base64 strings + //! instead of an array of an array of integers when serializing to JSON. + // + // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2 + use crate::field::{FieldElement, FieldPrio2}; + use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> { + let base64_vec = v.iter().map(base64::encode).collect(); + <Vec<String>>::serialize(&base64_vec, s) + } + + pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> { + <Vec<String>>::deserialize(d)? + .iter() + .map(|s| base64::decode(s.as_bytes()).map_err(Error::custom)) + .collect() + } + + pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> { + String::serialize(&base64::encode(FieldPrio2::slice_into_byte_vec(v)), s) + } + + pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> { + let bytes = base64::decode(String::deserialize(d)?.as_bytes()).map_err(Error::custom)?; + FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::util::reconstruct_shares; + + #[test] + fn roundtrip_test_vector_serialization() { + let test_vector = Priov2TestVector::new(123, 100).unwrap(); + let serialized = serde_json::to_vec(&test_vector).unwrap(); + let test_vector_again: Priov2TestVector = serde_json::from_slice(&serialized).unwrap(); + + assert_eq!(test_vector, test_vector_again); + } + + #[test] + fn accumulation_field_priov2() { + let dimension = 123; + let test_vector = Priov2TestVector::new(dimension, 100).unwrap(); + + let mut server1 = test_vector.server_1().unwrap(); + let mut server2 = test_vector.server_2().unwrap(); + + for (server_1_share, server_2_share) in test_vector + .server_1_shares + .iter() + .zip(&test_vector.server_2_shares) + { + let eval_at = server1.choose_eval_at(); + + let v1 = server1 + .generate_verification_message(eval_at, server_1_share) + .unwrap(); + let v2 = server2 + .generate_verification_message(eval_at, server_2_share) + .unwrap(); + + assert!(server1.aggregate(server_1_share, &v1, &v2).unwrap()); + assert!(server2.aggregate(server_2_share, &v1, &v2).unwrap()); + } + + let total1 = server1.total_shares(); + let total2 = server2.total_shares(); + + let reconstructed = reconstruct_shares(total1, total2).unwrap(); + assert_eq!(reconstructed, test_vector.reference_sum); + } +} diff --git a/third_party/rust/prio/src/util.rs b/third_party/rust/prio/src/util.rs new file mode 100644 index 0000000000..0518112d83 --- /dev/null +++ b/third_party/rust/prio/src/util.rs @@ -0,0 +1,201 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Utility functions for handling Prio stuff. + +use crate::field::{FieldElement, FieldError}; + +/// Serialization errors +#[derive(Debug, thiserror::Error)] +pub enum SerializeError { + /// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length + #[error("serialized input has wrong length")] + UnpackInputSizeMismatch, + /// Finite field operation error. + #[error("finite field operation error")] + Field(#[from] FieldError), +} + +/// Returns the number of field elements in the proof for given dimension of +/// data elements +/// +/// Proof is a vector, where the first `dimension` elements are the data +/// elements, the next 3 elements are the zero terms for polynomials f, g and h +/// and the remaining elements are non-zero points of h(x). +pub fn proof_length(dimension: usize) -> usize { + // number of data items + number of zero terms + N + dimension + 3 + (dimension + 1).next_power_of_two() +} + +/// Unpacked proof with subcomponents +#[derive(Debug)] +pub struct UnpackedProof<'a, F: FieldElement> { + /// Data + pub data: &'a [F], + /// Zeroth coefficient of polynomial f + pub f0: &'a F, + /// Zeroth coefficient of polynomial g + pub g0: &'a F, + /// Zeroth coefficient of polynomial h + pub h0: &'a F, + /// Non-zero points of polynomial h + pub points_h_packed: &'a [F], +} + +/// Unpacked proof with mutable subcomponents +#[derive(Debug)] +pub struct UnpackedProofMut<'a, F: FieldElement> { + /// Data + pub data: &'a mut [F], + /// Zeroth coefficient of polynomial f + pub f0: &'a mut F, + /// Zeroth coefficient of polynomial g + pub g0: &'a mut F, + /// Zeroth coefficient of polynomial h + pub h0: &'a mut F, + /// Non-zero points of polynomial h + pub points_h_packed: &'a mut [F], +} + +/// Unpacks the proof vector into subcomponents +pub(crate) fn unpack_proof<F: FieldElement>( + proof: &[F], + dimension: usize, +) -> Result<UnpackedProof<F>, SerializeError> { + // check the proof length + if proof.len() != proof_length(dimension) { + return Err(SerializeError::UnpackInputSizeMismatch); + } + // split share into components + let (data, rest) = proof.split_at(dimension); + if let ([f0, g0, h0], points_h_packed) = rest.split_at(3) { + Ok(UnpackedProof { + data, + f0, + g0, + h0, + points_h_packed, + }) + } else { + Err(SerializeError::UnpackInputSizeMismatch) + } +} + +/// Unpacks a mutable proof vector into mutable subcomponents +// TODO(timg): This is public because it is used by tests/tweaks.rs. We should +// refactor that test so it doesn't require the crate to expose this function or +// UnpackedProofMut. +pub fn unpack_proof_mut<F: FieldElement>( + proof: &mut [F], + dimension: usize, +) -> Result<UnpackedProofMut<F>, SerializeError> { + // check the share length + if proof.len() != proof_length(dimension) { + return Err(SerializeError::UnpackInputSizeMismatch); + } + // split share into components + let (data, rest) = proof.split_at_mut(dimension); + if let ([f0, g0, h0], points_h_packed) = rest.split_at_mut(3) { + Ok(UnpackedProofMut { + data, + f0, + g0, + h0, + points_h_packed, + }) + } else { + Err(SerializeError::UnpackInputSizeMismatch) + } +} + +/// Add two field element arrays together elementwise. +/// +/// Returns None, when array lengths are not equal. +pub fn reconstruct_shares<F: FieldElement>(share1: &[F], share2: &[F]) -> Option<Vec<F>> { + if share1.len() != share2.len() { + return None; + } + + let mut reconstructed: Vec<F> = vec![F::zero(); share1.len()]; + + for (r, (s1, s2)) in reconstructed + .iter_mut() + .zip(share1.iter().zip(share2.iter())) + { + *r = *s1 + *s2; + } + + Some(reconstructed) +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::field::{Field32, Field64}; + use assert_matches::assert_matches; + + pub fn secret_share(share: &mut [Field32]) -> Vec<Field32> { + use rand::Rng; + let mut rng = rand::thread_rng(); + let mut random = vec![0u32; share.len()]; + let mut share2 = vec![Field32::zero(); share.len()]; + + rng.fill(&mut random[..]); + + for (r, f) in random.iter().zip(share2.iter_mut()) { + *f = Field32::from(*r); + } + + for (f1, f2) in share.iter_mut().zip(share2.iter()) { + *f1 -= *f2; + } + + share2 + } + + #[test] + fn test_unpack_share_mut() { + let dim = 15; + let len = proof_length(dim); + + let mut share = vec![Field32::from(0); len]; + let unpacked = unpack_proof_mut(&mut share, dim).unwrap(); + *unpacked.f0 = Field32::from(12); + assert_eq!(share[dim], 12); + + let mut short_share = vec![Field32::from(0); len - 1]; + assert_matches!( + unpack_proof_mut(&mut short_share, dim), + Err(SerializeError::UnpackInputSizeMismatch) + ); + } + + #[test] + fn test_unpack_share() { + let dim = 15; + let len = proof_length(dim); + + let share = vec![Field64::from(0); len]; + unpack_proof(&share, dim).unwrap(); + + let short_share = vec![Field64::from(0); len - 1]; + assert_matches!( + unpack_proof(&short_share, dim), + Err(SerializeError::UnpackInputSizeMismatch) + ); + } + + #[test] + fn secret_sharing() { + let mut share1 = vec![Field32::zero(); 10]; + share1[3] = 21.into(); + share1[8] = 123.into(); + + let original_data = share1.clone(); + + let share2 = secret_share(&mut share1); + + let reconstructed = reconstruct_shares(&share1, &share2).unwrap(); + assert_eq!(reconstructed, original_data); + } +} diff --git a/third_party/rust/prio/src/vdaf.rs b/third_party/rust/prio/src/vdaf.rs new file mode 100644 index 0000000000..f75a2c488b --- /dev/null +++ b/third_party/rust/prio/src/vdaf.rs @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Verifiable Distributed Aggregation Functions (VDAFs) as described in +//! [[draft-irtf-cfrg-vdaf-03]]. +//! +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; +use crate::field::{FieldElement, FieldError}; +use crate::flp::FlpError; +use crate::prng::PrngError; +use crate::vdaf::prg::Seed; +use serde::{Deserialize, Serialize}; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io::Cursor; + +/// A component of the domain-separation tag, used to bind the VDAF operations to the document +/// version. This will be revised with each draft with breaking changes. +const VERSION: &[u8] = b"vdaf-03"; +/// Length of the domain-separation tag, including document version and algorithm ID. +const DST_LEN: usize = VERSION.len() + 4; + +/// Errors emitted by this module. +#[derive(Debug, thiserror::Error)] +pub enum VdafError { + /// An error occurred. + #[error("vdaf error: {0}")] + Uncategorized(String), + + /// Field error. + #[error("field error: {0}")] + Field(#[from] FieldError), + + /// An error occured while parsing a message. + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + + /// FLP error. + #[error("flp error: {0}")] + Flp(#[from] FlpError), + + /// PRNG error. + #[error("prng error: {0}")] + Prng(#[from] PrngError), + + /// failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), +} + +/// An additive share of a vector of field elements. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Share<F, const L: usize> { + /// An uncompressed share, typically sent to the leader. + Leader(Vec<F>), + + /// A compressed share, typically sent to the helper. + Helper(Seed<L>), +} + +impl<F: Clone, const L: usize> Share<F, L> { + /// Truncate the Leader's share to the given length. If this is the Helper's share, then this + /// method clones the input without modifying it. + #[cfg(feature = "prio2")] + pub(crate) fn truncated(&self, len: usize) -> Self { + match self { + Self::Leader(ref data) => Self::Leader(data[..len].to_vec()), + Self::Helper(ref seed) => Self::Helper(seed.clone()), + } + } +} + +/// Parameters needed to decode a [`Share`] +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum ShareDecodingParameter<const L: usize> { + Leader(usize), + Helper, +} + +impl<F: FieldElement, const L: usize> ParameterizedDecode<ShareDecodingParameter<L>> + for Share<F, L> +{ + fn decode_with_param( + decoding_parameter: &ShareDecodingParameter<L>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match decoding_parameter { + ShareDecodingParameter::Leader(share_length) => { + let mut data = Vec::with_capacity(*share_length); + for _ in 0..*share_length { + data.push(F::decode(bytes)?) + } + Ok(Self::Leader(data)) + } + ShareDecodingParameter::Helper => { + let seed = Seed::decode(bytes)?; + Ok(Self::Helper(seed)) + } + } + } +} + +impl<F: FieldElement, const L: usize> Encode for Share<F, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + match self { + Share::Leader(share_data) => { + for x in share_data { + x.encode(bytes); + } + } + Share::Helper(share_seed) => { + share_seed.encode(bytes); + } + } + } +} + +/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`], +/// and [`Collector`], which define the roles of the various parties involved in the execution of +/// the VDAF. +// TODO(brandon): once GATs are stabilized [https://github.com/rust-lang/rust/issues/44265], +// state the "&AggregateShare must implement Into<Vec<u8>>" constraint in terms of a where clause +// on the associated type instead of a where clause on the trait. +pub trait Vdaf: Clone + Debug +where + for<'a> &'a Self::AggregateShare: Into<Vec<u8>>, +{ + /// Algorithm identifier for this VDAF. + const ID: u32; + + /// The type of Client measurement to be aggregated. + type Measurement: Clone + Debug; + + /// The aggregate result of the VDAF execution. + type AggregateResult: Clone + Debug; + + /// The aggregation parameter, used by the Aggregators to map their input shares to output + /// shares. + type AggregationParam: Clone + Debug + Decode + Encode; + + /// A public share sent by a Client. + type PublicShare: Clone + Debug + for<'a> ParameterizedDecode<&'a Self> + Encode; + + /// An input share sent by a Client. + type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode; + + /// An output share recovered from an input share by an Aggregator. + type OutputShare: Clone + Debug; + + /// An Aggregator's share of the aggregate result. + type AggregateShare: Aggregatable<OutputShare = Self::OutputShare> + for<'a> TryFrom<&'a [u8]>; + + /// The number of Aggregators. The Client generates as many input shares as there are + /// Aggregators. + fn num_aggregators(&self) -> usize; +} + +/// The Client's role in the execution of a VDAF. +pub trait Client: Vdaf +where + for<'a> &'a Self::AggregateShare: Into<Vec<u8>>, +{ + /// Shards a measurement into a public share and a sequence of input shares, one for each + /// Aggregator. + fn shard( + &self, + measurement: &Self::Measurement, + ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>; +} + +/// The Aggregator's role in the execution of a VDAF. +pub trait Aggregator<const L: usize>: Vdaf +where + for<'a> &'a Self::AggregateShare: Into<Vec<u8>>, +{ + /// State of the Aggregator during the Prepare process. + type PrepareState: Clone + Debug; + + /// The type of messages broadcast by each aggregator at each round of the Prepare Process. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode; + + /// Result of preprocessing a round of preparation shares. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareMessage: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode; + + /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned + /// is passed to [`Aggregator::prepare_step`] to get this aggregator's first-round prepare + /// message. + fn prepare_init( + &self, + verify_key: &[u8; L], + agg_id: usize, + agg_param: &Self::AggregationParam, + nonce: &[u8], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>; + + /// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`]. + fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>( + &self, + inputs: M, + ) -> Result<Self::PrepareMessage, VdafError>; + + /// Compute the next state transition from the current state and the previous round of input + /// messages. If this returns [`PrepareTransition::Continue`], then the returned + /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from + /// this round and passed into another call to this method. This continues until this method + /// returns [`PrepareTransition::Finish`], at which point the returned output share may be + /// aggregated. If the method returns an error, the aggregator should consider its input share + /// invalid and not attempt to process it any further. + fn prepare_step( + &self, + state: Self::PrepareState, + input: Self::PrepareMessage, + ) -> Result<PrepareTransition<Self, L>, VdafError>; + + /// Aggregates a sequence of output shares into an aggregate share. + fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( + &self, + agg_param: &Self::AggregationParam, + output_shares: M, + ) -> Result<Self::AggregateShare, VdafError>; +} + +/// The Collector's role in the execution of a VDAF. +pub trait Collector: Vdaf +where + for<'a> &'a Self::AggregateShare: Into<Vec<u8>>, +{ + /// Combines aggregate shares into the aggregate result. + fn unshard<M: IntoIterator<Item = Self::AggregateShare>>( + &self, + agg_param: &Self::AggregationParam, + agg_shares: M, + num_measurements: usize, + ) -> Result<Self::AggregateResult, VdafError>; +} + +/// A state transition of an Aggregator during the Prepare process. +#[derive(Debug)] +pub enum PrepareTransition<V: Aggregator<L>, const L: usize> +where + for<'a> &'a V::AggregateShare: Into<Vec<u8>>, +{ + /// Continue processing. + Continue(V::PrepareState, V::PrepareShare), + + /// Finish processing and return the output share. + Finish(V::OutputShare), +} + +/// An aggregate share resulting from aggregating output shares together that +/// can merged with aggregate shares of the same type. +pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> { + /// Type of output shares that can be accumulated into an aggregate share. + type OutputShare; + + /// Update an aggregate share by merging it with another (`agg_share`). + fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>; + + /// Update an aggregate share by adding `output share` + fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>; +} + +/// An output share comprised of a vector of `F` elements. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OutputShare<F>(Vec<F>); + +impl<F> AsRef<[F]> for OutputShare<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F> From<Vec<F>> for OutputShare<F> { + fn from(other: Vec<F>) -> Self { + Self(other) + } +} + +impl<F: FieldElement> TryFrom<&[u8]> for OutputShare<F> { + type Error = FieldError; + + fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { + fieldvec_try_from_bytes(bytes) + } +} + +impl<F: FieldElement> From<&OutputShare<F>> for Vec<u8> { + fn from(output_share: &OutputShare<F>) -> Self { + fieldvec_to_vec(&output_share.0) + } +} + +/// An aggregate share suitable for VDAFs whose output shares and aggregate +/// shares are vectors of `F` elements, and an output share needs no special +/// transformation to be merged into an aggregate share. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct AggregateShare<F>(Vec<F>); + +impl<F> AsRef<[F]> for AggregateShare<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F> From<OutputShare<F>> for AggregateShare<F> { + fn from(other: OutputShare<F>) -> Self { + Self(other.0) + } +} + +impl<F> From<Vec<F>> for AggregateShare<F> { + fn from(other: Vec<F>) -> Self { + Self(other) + } +} + +impl<F: FieldElement> Aggregatable for AggregateShare<F> { + type OutputShare = OutputShare<F>; + + fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> { + self.sum(agg_share.as_ref()) + } + + fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> { + // For prio3 and poplar1, no conversion is needed between output shares and aggregation + // shares. + self.sum(output_share.as_ref()) + } +} + +impl<F: FieldElement> AggregateShare<F> { + fn sum(&mut self, other: &[F]) -> Result<(), VdafError> { + if self.0.len() != other.len() { + return Err(VdafError::Uncategorized(format!( + "cannot sum shares of different lengths (left = {}, right = {}", + self.0.len(), + other.len() + ))); + } + + for (x, y) in self.0.iter_mut().zip(other) { + *x += *y; + } + + Ok(()) + } +} + +impl<F: FieldElement> TryFrom<&[u8]> for AggregateShare<F> { + type Error = FieldError; + + fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { + fieldvec_try_from_bytes(bytes) + } +} + +impl<F: FieldElement> From<&AggregateShare<F>> for Vec<u8> { + fn from(aggregate_share: &AggregateShare<F>) -> Self { + fieldvec_to_vec(&aggregate_share.0) + } +} + +/// fieldvec_try_from_bytes converts a slice of bytes to a type that is equivalent to a vector of +/// field elements. +#[inline(always)] +fn fieldvec_try_from_bytes<F: FieldElement, T: From<Vec<F>>>( + bytes: &[u8], +) -> Result<T, FieldError> { + F::byte_slice_into_vec(bytes).map(T::from) +} + +/// fieldvec_to_vec converts a type that is equivalent to a vector of field elements into a vector +/// of bytes. +#[inline(always)] +fn fieldvec_to_vec<F: FieldElement, T: AsRef<[F]>>(val: T) -> Vec<u8> { + F::slice_into_byte_vec(val.as_ref()) +} + +#[cfg(test)] +pub(crate) fn run_vdaf<V, M, const L: usize>( + vdaf: &V, + agg_param: &V::AggregationParam, + measurements: M, +) -> Result<V::AggregateResult, VdafError> +where + V: Client + Aggregator<L> + Collector, + for<'a> &'a V::AggregateShare: Into<Vec<u8>>, + M: IntoIterator<Item = V::Measurement>, +{ + use rand::prelude::*; + let mut verify_key = [0; L]; + thread_rng().fill(&mut verify_key[..]); + + // NOTE Here we use the same nonce for each measurement for testing purposes. However, this is + // not secure. In use, the Aggregators MUST ensure that nonces are unique for each measurement. + let nonce = b"this is a nonce"; + + let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()]; + let mut num_measurements: usize = 0; + for measurement in measurements.into_iter() { + num_measurements += 1; + let (public_share, input_shares) = vdaf.shard(&measurement)?; + let out_shares = run_vdaf_prepare( + vdaf, + &verify_key, + agg_param, + nonce, + public_share, + input_shares, + )?; + for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) { + if let Some(ref mut inner) = agg_share { + inner.merge(&out_share.into())?; + } else { + *agg_share = Some(out_share.into()); + } + } + } + + let res = vdaf.unshard( + agg_param, + agg_shares.into_iter().map(|option| option.unwrap()), + num_measurements, + )?; + Ok(res) +} + +#[cfg(test)] +pub(crate) fn run_vdaf_prepare<V, M, const L: usize>( + vdaf: &V, + verify_key: &[u8; L], + agg_param: &V::AggregationParam, + nonce: &[u8], + public_share: V::PublicShare, + input_shares: M, +) -> Result<Vec<V::OutputShare>, VdafError> +where + V: Client + Aggregator<L> + Collector, + for<'a> &'a V::AggregateShare: Into<Vec<u8>>, + M: IntoIterator<Item = V::InputShare>, +{ + let input_shares = input_shares + .into_iter() + .map(|input_share| input_share.get_encoded()); + + let mut states = Vec::new(); + let mut outbound = Vec::new(); + for (agg_id, input_share) in input_shares.enumerate() { + let (state, msg) = vdaf.prepare_init( + verify_key, + agg_id, + agg_param, + nonce, + &public_share, + &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share) + .expect("failed to decode input share"), + )?; + states.push(state); + outbound.push(msg.get_encoded()); + } + + let mut inbound = vdaf + .prepare_preprocess(outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }))? + .get_encoded(); + + let mut out_shares = Vec::new(); + loop { + let mut outbound = Vec::new(); + for state in states.iter_mut() { + match vdaf.prepare_step( + state.clone(), + V::PrepareMessage::get_decoded_with_param(state, &inbound) + .expect("failed to decode prep message"), + )? { + PrepareTransition::Continue(new_state, msg) => { + outbound.push(msg.get_encoded()); + *state = new_state + } + PrepareTransition::Finish(out_share) => { + out_shares.push(out_share); + } + } + } + + if outbound.len() == vdaf.num_aggregators() { + // Another round is required before output shares are computed. + inbound = vdaf + .prepare_preprocess(outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }))? + .get_encoded(); + } else if outbound.is_empty() { + // Each Aggregator recovered an output share. + break; + } else { + panic!("Aggregators did not finish the prepare phase at the same time"); + } + } + + Ok(out_shares) +} + +#[cfg(test)] +mod tests { + use super::{AggregateShare, OutputShare}; + use crate::field::{Field128, Field64, FieldElement}; + use itertools::iterate; + use std::convert::TryFrom; + use std::fmt::Debug; + + fn fieldvec_roundtrip_test<F, T>() + where + F: FieldElement, + for<'a> T: Debug + PartialEq + From<Vec<F>> + TryFrom<&'a [u8]>, + for<'a> <T as TryFrom<&'a [u8]>>::Error: Debug, + for<'a> Vec<u8>: From<&'a T>, + { + // Generate a value based on an arbitrary vector of field elements. + let g = F::generator(); + let want_value = T::from(iterate(F::one(), |&v| g * v).take(10).collect()); + + // Round-trip the value through a byte-vector. + let buf: Vec<u8> = (&want_value).into(); + let got_value = T::try_from(&buf).unwrap(); + + assert_eq!(want_value, got_value); + } + + #[test] + fn roundtrip_output_share() { + fieldvec_roundtrip_test::<Field64, OutputShare<Field64>>(); + fieldvec_roundtrip_test::<Field128, OutputShare<Field128>>(); + } + + #[test] + fn roundtrip_aggregate_share() { + fieldvec_roundtrip_test::<Field64, AggregateShare<Field64>>(); + fieldvec_roundtrip_test::<Field128, AggregateShare<Field128>>(); + } +} + +#[cfg(feature = "crypto-dependencies")] +pub mod poplar1; +pub mod prg; +#[cfg(feature = "prio2")] +pub mod prio2; +pub mod prio3; +#[cfg(test)] +mod prio3_test; diff --git a/third_party/rust/prio/src/vdaf/poplar1.rs b/third_party/rust/prio/src/vdaf/poplar1.rs new file mode 100644 index 0000000000..f6ab110ebb --- /dev/null +++ b/third_party/rust/prio/src/vdaf/poplar1.rs @@ -0,0 +1,933 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! **(NOTE: This module is experimental. Applications should not use it yet.)** This module +//! partially implements the core component of the Poplar protocol [[BBCG+21]]. Named for the +//! Poplar1 section of [[draft-irtf-cfrg-vdaf-03]], the specification of this VDAF is under active +//! development. Thus this code should be regarded as experimental and not compliant with any +//! existing speciication. +//! +//! TODO Make the input shares stateful so that applications can efficiently evaluate the IDPF over +//! multiple rounds. Question: Will this require API changes to [`crate::vdaf::Vdaf`]? +//! +//! TODO Update trait [`Idpf`] so that the IDPF can have a different field type at the leaves than +//! at the inner nodes. +//! +//! TODO Implement the efficient IDPF of [[BBCG+21]]. [`ToyIdpf`] is not space efficient and is +//! merely intended as a proof-of-concept. +//! +//! [BBCG+21]: https://eprint.iacr.org/2021/017 +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use std::cmp::Ordering; +use std::collections::{BTreeMap, BTreeSet}; +use std::convert::{TryFrom, TryInto}; +use std::fmt::Debug; +use std::io::Cursor; +use std::iter::FromIterator; +use std::marker::PhantomData; + +use crate::codec::{ + decode_u16_items, decode_u24_items, encode_u16_items, encode_u24_items, CodecError, Decode, + Encode, ParameterizedDecode, +}; +use crate::field::{split_vector, FieldElement}; +use crate::fp::log2; +use crate::prng::Prng; +use crate::vdaf::prg::{Prg, Seed}; +use crate::vdaf::{ + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, + Share, ShareDecodingParameter, Vdaf, VdafError, +}; + +/// An input for an IDPF ([`Idpf`]). +/// +/// TODO Make this an associated type of `Idpf`. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct IdpfInput { + index: usize, + level: usize, +} + +impl IdpfInput { + /// Constructs an IDPF input using the first `level` bits of `data`. + pub fn new(data: &[u8], level: usize) -> Result<Self, VdafError> { + if level > data.len() << 3 { + return Err(VdafError::Uncategorized(format!( + "desired bit length ({} bits) exceeds data length ({} bytes)", + level, + data.len() + ))); + } + + let mut index = 0; + let mut i = 0; + for byte in data { + for j in 0..8 { + let bit = (byte >> j) & 1; + if i < level { + index |= (bit as usize) << i; + } + i += 1; + } + } + + Ok(Self { index, level }) + } + + /// Construct a new input that is a prefix of `self`. Bounds checking is performed by the + /// caller. + fn prefix(&self, level: usize) -> Self { + let index = self.index & ((1 << level) - 1); + Self { index, level } + } + + /// Return the position of `self` in the look-up table of `ToyIdpf`. + fn data_index(&self) -> usize { + self.index | (1 << self.level) + } +} + +impl Ord for IdpfInput { + fn cmp(&self, other: &Self) -> Ordering { + match self.level.cmp(&other.level) { + Ordering::Equal => self.index.cmp(&other.index), + ord => ord, + } + } +} + +impl PartialOrd for IdpfInput { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Encode for IdpfInput { + fn encode(&self, bytes: &mut Vec<u8>) { + (self.index as u64).encode(bytes); + (self.level as u64).encode(bytes); + } +} + +impl Decode for IdpfInput { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let index = u64::decode(bytes)? as usize; + let level = u64::decode(bytes)? as usize; + + Ok(Self { index, level }) + } +} + +/// An Incremental Distributed Point Function (IDPF), as defined by [[BBCG+21]]. +/// +/// [BBCG+21]: https://eprint.iacr.org/2021/017 +// +// NOTE(cjpatton) The real IDPF API probably needs to be stateful. +pub trait Idpf<const KEY_LEN: usize, const OUT_LEN: usize>: + Sized + Clone + Debug + Encode + Decode +{ + /// The finite field over which the IDPF is defined. + // + // NOTE(cjpatton) The IDPF of [BBCG+21] might use different fields for different levels of the + // prefix tree. + type Field: FieldElement; + + /// Generate and return a sequence of IDPF shares for `input`. Parameter `output` is an + /// iterator that is invoked to get the output value for each successive level of the prefix + /// tree. + fn gen<M: IntoIterator<Item = [Self::Field; OUT_LEN]>>( + input: &IdpfInput, + values: M, + ) -> Result<[Self; KEY_LEN], VdafError>; + + /// Evaluate an IDPF share on `prefix`. + fn eval(&self, prefix: &IdpfInput) -> Result<[Self::Field; OUT_LEN], VdafError>; +} + +/// A "toy" IDPF used for demonstration purposes. The space consumed by each share is `O(2^n)`, +/// where `n` is the length of the input. The size of each share is restricted to 1MB, so this IDPF +/// is only suitable for very short inputs. +// +// NOTE(cjpatton) It would be straight-forward to generalize this construction to any `KEY_LEN` and +// `OUT_LEN`. +#[derive(Debug, Clone)] +pub struct ToyIdpf<F> { + data0: Vec<F>, + data1: Vec<F>, + level: usize, +} + +impl<F: FieldElement> Idpf<2, 2> for ToyIdpf<F> { + type Field = F; + + fn gen<M: IntoIterator<Item = [Self::Field; 2]>>( + input: &IdpfInput, + values: M, + ) -> Result<[Self; 2], VdafError> { + const MAX_DATA_BYTES: usize = 1024 * 1024; // 1MB + + let max_input_len = + usize::try_from(log2((MAX_DATA_BYTES / F::ENCODED_SIZE) as u128)).unwrap(); + if input.level > max_input_len { + return Err(VdafError::Uncategorized(format!( + "input length ({}) exceeds maximum of ({})", + input.level, max_input_len + ))); + } + + let data_len = 1 << (input.level + 1); + let mut data0 = vec![F::zero(); data_len]; + let mut data1 = vec![F::zero(); data_len]; + let mut values = values.into_iter(); + for level in 0..input.level + 1 { + let value = values.next().unwrap(); + let index = input.prefix(level).data_index(); + data0[index] = value[0]; + data1[index] = value[1]; + } + + let mut data0 = split_vector(&data0, 2)?.into_iter(); + let mut data1 = split_vector(&data1, 2)?.into_iter(); + Ok([ + ToyIdpf { + data0: data0.next().unwrap(), + data1: data1.next().unwrap(), + level: input.level, + }, + ToyIdpf { + data0: data0.next().unwrap(), + data1: data1.next().unwrap(), + level: input.level, + }, + ]) + } + + fn eval(&self, prefix: &IdpfInput) -> Result<[F; 2], VdafError> { + if prefix.level > self.level { + return Err(VdafError::Uncategorized(format!( + "prefix length ({}) exceeds input length ({})", + prefix.level, self.level + ))); + } + + let index = prefix.data_index(); + Ok([self.data0[index], self.data1[index]]) + } +} + +impl<F: FieldElement> Encode for ToyIdpf<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + encode_u24_items(bytes, &(), &self.data0); + encode_u24_items(bytes, &(), &self.data1); + (self.level as u64).encode(bytes); + } +} + +impl<F: FieldElement> Decode for ToyIdpf<F> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let data0 = decode_u24_items(&(), bytes)?; + let data1 = decode_u24_items(&(), bytes)?; + let level = u64::decode(bytes)? as usize; + + Ok(Self { + data0, + data1, + level, + }) + } +} + +impl Encode for BTreeSet<IdpfInput> { + fn encode(&self, bytes: &mut Vec<u8>) { + // Encodes the aggregation parameter as a variable length vector of + // [`IdpfInput`], because the size of the aggregation parameter is not + // determined by the VDAF. + let items: Vec<IdpfInput> = self.iter().map(IdpfInput::clone).collect(); + encode_u24_items(bytes, &(), &items); + } +} + +impl Decode for BTreeSet<IdpfInput> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let inputs = decode_u24_items(&(), bytes)?; + Ok(Self::from_iter(inputs.into_iter())) + } +} + +/// An input share for the `poplar1` VDAF. +#[derive(Debug, Clone)] +pub struct Poplar1InputShare<I: Idpf<2, 2>, const L: usize> { + /// IDPF share of input + idpf: I, + + /// PRNG seed used to generate the aggregator's share of the randomness used in the first part + /// of the sketching protocol. + sketch_start_seed: Seed<L>, + + /// Aggregator's share of the randomness used in the second part of the sketching protocol. + sketch_next: Share<I::Field, L>, +} + +impl<I: Idpf<2, 2>, const L: usize> Encode for Poplar1InputShare<I, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + self.idpf.encode(bytes); + self.sketch_start_seed.encode(bytes); + self.sketch_next.encode(bytes); + } +} + +impl<'a, I, P, const L: usize> ParameterizedDecode<(&'a Poplar1<I, P, L>, usize)> + for Poplar1InputShare<I, L> +where + I: Idpf<2, 2>, +{ + fn decode_with_param( + (poplar1, agg_id): &(&'a Poplar1<I, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let idpf = I::decode(bytes)?; + let sketch_start_seed = Seed::decode(bytes)?; + let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; + + let share_decoding_parameter = if is_leader { + // The sketch is two field elements for every bit of input, plus two more, corresponding + // to construction of shares in `Poplar1::shard`. + ShareDecodingParameter::Leader((poplar1.input_length + 1) * 2) + } else { + ShareDecodingParameter::Helper + }; + + let sketch_next = + <Share<I::Field, L>>::decode_with_param(&share_decoding_parameter, bytes)?; + + Ok(Self { + idpf, + sketch_start_seed, + sketch_next, + }) + } +} + +/// The poplar1 VDAF. +#[derive(Debug)] +pub struct Poplar1<I, P, const L: usize> { + input_length: usize, + phantom: PhantomData<(I, P)>, +} + +impl<I, P, const L: usize> Poplar1<I, P, L> { + /// Create an instance of the poplar1 VDAF. The caller provides a cipher suite `suite` used for + /// deriving pseudorandom sequences of field elements, and a input length in bits, corresponding + /// to `BITS` as defined in the [VDAF specification][1]. + /// + /// [1]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + pub fn new(bits: usize) -> Self { + Self { + input_length: bits, + phantom: PhantomData, + } + } +} + +impl<I, P, const L: usize> Clone for Poplar1<I, P, L> { + fn clone(&self) -> Self { + Self::new(self.input_length) + } +} +impl<I, P, const L: usize> Vdaf for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + // TODO: This currently uses a codepoint reserved for testing purposes. Replace it with + // 0x00001000 once the implementation is updated to match draft-irtf-cfrg-vdaf-03. + const ID: u32 = 0xFFFF0000; + type Measurement = IdpfInput; + type AggregateResult = BTreeMap<IdpfInput, u64>; + type AggregationParam = BTreeSet<IdpfInput>; + type PublicShare = (); // TODO: Replace this when the IDPF from [BBCGGI21] is implemented. + type InputShare = Poplar1InputShare<I, L>; + type OutputShare = OutputShare<I::Field>; + type AggregateShare = AggregateShare<I::Field>; + + fn num_aggregators(&self) -> usize { + 2 + } +} + +impl<I, P, const L: usize> Client for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + #[allow(clippy::many_single_char_names)] + fn shard(&self, input: &IdpfInput) -> Result<((), Vec<Poplar1InputShare<I, L>>), VdafError> { + let idpf_values: Vec<[I::Field; 2]> = Prng::new()? + .take(input.level + 1) + .map(|k| [I::Field::one(), k]) + .collect(); + + // For each level of the prefix tree, generate correlated randomness that the aggregators use + // to validate the output. See [BBCG+21, Appendix C.4]. + let leader_sketch_start_seed = Seed::generate()?; + let helper_sketch_start_seed = Seed::generate()?; + let helper_sketch_next_seed = Seed::generate()?; + let mut leader_sketch_start_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&leader_sketch_start_seed, b"")); + let mut helper_sketch_start_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper_sketch_start_seed, b"")); + let mut helper_sketch_next_prng: Prng<I::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper_sketch_next_seed, b"")); + let mut leader_sketch_next: Vec<I::Field> = Vec::with_capacity(2 * idpf_values.len()); + for value in idpf_values.iter() { + let k = value[1]; + + // [BBCG+21, Appendix C.4] + // + // $(a, b, c)$ + let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + + // $A = -2a + k$ + // $B = a^2 + b + -ak + c$ + let d = k - (a + a); + let e = (a * a) + b - (a * k) + c; + leader_sketch_next.push(d - helper_sketch_next_prng.get()); + leader_sketch_next.push(e - helper_sketch_next_prng.get()); + } + + // Generate IDPF shares of the data and authentication vectors. + let idpf_shares = I::gen(input, idpf_values)?; + + Ok(( + (), + vec![ + Poplar1InputShare { + idpf: idpf_shares[0].clone(), + sketch_start_seed: leader_sketch_start_seed, + sketch_next: Share::Leader(leader_sketch_next), + }, + Poplar1InputShare { + idpf: idpf_shares[1].clone(), + sketch_start_seed: helper_sketch_start_seed, + sketch_next: Share::Helper(helper_sketch_next_seed), + }, + ], + )) + } +} + +fn get_level(agg_param: &BTreeSet<IdpfInput>) -> Result<usize, VdafError> { + let mut level = None; + for prefix in agg_param { + if let Some(l) = level { + if prefix.level != l { + return Err(VdafError::Uncategorized( + "prefixes must all have the same length".to_string(), + )); + } + } else { + level = Some(prefix.level); + } + } + + match level { + Some(level) => Ok(level), + None => Err(VdafError::Uncategorized("prefix set is empty".to_string())), + } +} + +impl<I, P, const L: usize> Aggregator<L> for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + type PrepareState = Poplar1PrepareState<I::Field>; + type PrepareShare = Poplar1PrepareMessage<I::Field>; + type PrepareMessage = Poplar1PrepareMessage<I::Field>; + + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; L], + agg_id: usize, + agg_param: &BTreeSet<IdpfInput>, + nonce: &[u8], + _public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result< + ( + Poplar1PrepareState<I::Field>, + Poplar1PrepareMessage<I::Field>, + ), + VdafError, + > { + let level = get_level(agg_param)?; + let is_leader = role_try_from(agg_id)?; + + // Derive the verification randomness. + let mut p = P::init(verify_key); + p.update(nonce); + let mut verify_rand_prng: Prng<I::Field, _> = Prng::from_seed_stream(p.into_seed_stream()); + + // Evaluate the IDPF shares and compute the polynomial coefficients. + let mut z = [I::Field::zero(); 3]; + let mut output_share = Vec::with_capacity(agg_param.len()); + for prefix in agg_param.iter() { + let value = input_share.idpf.eval(prefix)?; + let (v, k) = (value[0], value[1]); + let r = verify_rand_prng.get(); + + // [BBCG+21, Appendix C.4] + // + // $(z_\sigma, z^*_\sigma, z^{**}_\sigma)$ + let tmp = r * v; + z[0] += tmp; + z[1] += r * tmp; + z[2] += r * k; + output_share.push(v); + } + + // [BBCG+21, Appendix C.4] + // + // Add blind shares $(a_\sigma b_\sigma, c_\sigma)$ + // + // NOTE(cjpatton) We can make this faster by a factor of 3 by using three seed shares instead + // of one. On the other hand, if the input shares are made stateful, then we could store + // the PRNG state theire and avoid fast-forwarding. + let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream( + &input_share.sketch_start_seed, + b"", + )) + .skip(3 * level); + z[0] += prng.next().unwrap(); + z[1] += prng.next().unwrap(); + z[2] += prng.next().unwrap(); + + let (d, e) = match &input_share.sketch_next { + Share::Leader(data) => (data[2 * level], data[2 * level + 1]), + Share::Helper(seed) => { + let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(seed, b"")) + .skip(2 * level); + (prng.next().unwrap(), prng.next().unwrap()) + } + }; + + let x = if is_leader { + I::Field::one() + } else { + I::Field::zero() + }; + + Ok(( + Poplar1PrepareState { + sketch: SketchState::RoundOne, + output_share: OutputShare(output_share), + d, + e, + x, + }, + Poplar1PrepareMessage(z.to_vec()), + )) + } + + fn prepare_preprocess<M: IntoIterator<Item = Poplar1PrepareMessage<I::Field>>>( + &self, + inputs: M, + ) -> Result<Poplar1PrepareMessage<I::Field>, VdafError> { + let mut output: Option<Vec<I::Field>> = None; + let mut count = 0; + for data_share in inputs.into_iter() { + count += 1; + if let Some(ref mut data) = output { + if data_share.0.len() != data.len() { + return Err(VdafError::Uncategorized(format!( + "unexpected message length: got {}; want {}", + data_share.0.len(), + data.len(), + ))); + } + + for (x, y) in data.iter_mut().zip(data_share.0.iter()) { + *x += *y; + } + } else { + output = Some(data_share.0); + } + } + + if count != 2 { + return Err(VdafError::Uncategorized(format!( + "unexpected message count: got {}; want 2", + count, + ))); + } + + Ok(Poplar1PrepareMessage(output.unwrap())) + } + + fn prepare_step( + &self, + mut state: Poplar1PrepareState<I::Field>, + msg: Poplar1PrepareMessage<I::Field>, + ) -> Result<PrepareTransition<Self, L>, VdafError> { + match &state.sketch { + SketchState::RoundOne => { + if msg.0.len() != 3 { + return Err(VdafError::Uncategorized(format!( + "unexpected message length ({:?}): got {}; want 3", + state.sketch, + msg.0.len(), + ))); + } + + // Compute polynomial coefficients. + let z: [I::Field; 3] = msg.0.try_into().unwrap(); + let y_share = + vec![(state.d * z[0]) + state.e + state.x * ((z[0] * z[0]) - z[1] - z[2])]; + + state.sketch = SketchState::RoundTwo; + Ok(PrepareTransition::Continue( + state, + Poplar1PrepareMessage(y_share), + )) + } + + SketchState::RoundTwo => { + if msg.0.len() != 1 { + return Err(VdafError::Uncategorized(format!( + "unexpected message length ({:?}): got {}; want 1", + state.sketch, + msg.0.len(), + ))); + } + + let y = msg.0[0]; + if y != I::Field::zero() { + return Err(VdafError::Uncategorized(format!( + "output is invalid: polynomial evaluated to {}; want {}", + y, + I::Field::zero(), + ))); + } + + Ok(PrepareTransition::Finish(state.output_share)) + } + } + } + + fn aggregate<M: IntoIterator<Item = OutputShare<I::Field>>>( + &self, + agg_param: &BTreeSet<IdpfInput>, + output_shares: M, + ) -> Result<AggregateShare<I::Field>, VdafError> { + let mut agg_share = AggregateShare(vec![I::Field::zero(); agg_param.len()]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + + Ok(agg_share) + } +} + +/// A prepare message sent exchanged between Poplar1 aggregators +#[derive(Clone, Debug)] +pub struct Poplar1PrepareMessage<F>(Vec<F>); + +impl<F> AsRef<[F]> for Poplar1PrepareMessage<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F: FieldElement> Encode for Poplar1PrepareMessage<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + // TODO: This is encoded as a variable length vector of F, but we may + // be able to make this a fixed-length vector for specific Poplar1 + // instantations + encode_u16_items(bytes, &(), &self.0); + } +} + +impl<F: FieldElement> ParameterizedDecode<Poplar1PrepareState<F>> for Poplar1PrepareMessage<F> { + fn decode_with_param( + _decoding_parameter: &Poplar1PrepareState<F>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + // TODO: This is decoded as a variable length vector of F, but we may be + // able to make this a fixed-length vector for specific Poplar1 + // instantiations. + let items = decode_u16_items(&(), bytes)?; + + Ok(Self(items)) + } +} + +/// The state of each Aggregator during the Prepare process. +#[derive(Clone, Debug)] +pub struct Poplar1PrepareState<F> { + /// State of the secure sketching protocol. + sketch: SketchState, + + /// The output share. + output_share: OutputShare<F>, + + /// Aggregator's share of $A = -2a + k$. + d: F, + + /// Aggregator's share of $B = a^2 + b -ak + c$. + e: F, + + /// Equal to 1 if this Aggregator is the "leader" and 0 otherwise. + x: F, +} + +#[derive(Clone, Debug)] +enum SketchState { + RoundOne, + RoundTwo, +} + +impl<I, P, const L: usize> Collector for Poplar1<I, P, L> +where + I: Idpf<2, 2>, + P: Prg<L>, +{ + fn unshard<M: IntoIterator<Item = AggregateShare<I::Field>>>( + &self, + agg_param: &BTreeSet<IdpfInput>, + agg_shares: M, + _num_measurements: usize, + ) -> Result<BTreeMap<IdpfInput, u64>, VdafError> { + let mut agg_data = AggregateShare(vec![I::Field::zero(); agg_param.len()]); + for agg_share in agg_shares.into_iter() { + agg_data.merge(&agg_share)?; + } + + let mut agg = BTreeMap::new(); + for (prefix, count) in agg_param.iter().zip(agg_data.as_ref()) { + let count = <I::Field as FieldElement>::Integer::from(*count); + let count: u64 = count + .try_into() + .map_err(|_| VdafError::Uncategorized("aggregate overflow".to_string()))?; + agg.insert(*prefix, count); + } + Ok(agg) + } +} + +fn role_try_from(agg_id: usize) -> Result<bool, VdafError> { + match agg_id { + 0 => Ok(true), + 1 => Ok(false), + _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::field::Field128; + use crate::vdaf::prg::PrgAes128; + use crate::vdaf::{run_vdaf, run_vdaf_prepare}; + use rand::prelude::*; + + #[test] + fn test_idpf() { + // IDPF input equality tests. + assert_eq!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hello", 40).unwrap() + ); + assert_eq!( + IdpfInput::new(b"hi", 9).unwrap(), + IdpfInput::new(b"ha", 9).unwrap(), + ); + assert_eq!( + IdpfInput::new(b"hello", 25).unwrap(), + IdpfInput::new(b"help", 25).unwrap() + ); + assert_ne!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hello", 39).unwrap() + ); + assert_ne!( + IdpfInput::new(b"hello", 40).unwrap(), + IdpfInput::new(b"hell-", 40).unwrap() + ); + + // IDPF uniqueness tests + let mut unique = BTreeSet::new(); + assert!(unique.insert(IdpfInput::new(b"hello", 40).unwrap())); + assert!(!unique.insert(IdpfInput::new(b"hello", 40).unwrap())); + assert!(unique.insert(IdpfInput::new(b"hello", 39).unwrap())); + assert!(unique.insert(IdpfInput::new(b"bye", 20).unwrap())); + + // Generate IDPF keys. + let input = IdpfInput::new(b"hi", 16).unwrap(); + let keys = ToyIdpf::<Field128>::gen( + &input, + std::iter::repeat([Field128::one(), Field128::one()]), + ) + .unwrap(); + + // Try evaluating the IDPF keys on all prefixes. + for prefix_len in 0..input.level + 1 { + let res = eval_idpf( + &keys, + &input.prefix(prefix_len), + &[Field128::one(), Field128::one()], + ); + assert!(res.is_ok(), "prefix_len={} error: {:?}", prefix_len, res); + } + + // Try evaluating the IDPF keys on incorrect prefixes. + eval_idpf( + &keys, + &IdpfInput::new(&[2], 2).unwrap(), + &[Field128::zero(), Field128::zero()], + ) + .unwrap(); + + eval_idpf( + &keys, + &IdpfInput::new(&[23, 1], 12).unwrap(), + &[Field128::zero(), Field128::zero()], + ) + .unwrap(); + } + + fn eval_idpf<I, const KEY_LEN: usize, const OUT_LEN: usize>( + keys: &[I; KEY_LEN], + input: &IdpfInput, + expected_output: &[I::Field; OUT_LEN], + ) -> Result<(), VdafError> + where + I: Idpf<KEY_LEN, OUT_LEN>, + { + let mut output = [I::Field::zero(); OUT_LEN]; + for key in keys { + let output_share = key.eval(input)?; + for (x, y) in output.iter_mut().zip(output_share) { + *x += y; + } + } + + if expected_output != &output { + return Err(VdafError::Uncategorized(format!( + "eval_idpf(): unexpected output: got {:?}; want {:?}", + output, expected_output + ))); + } + + Ok(()) + } + + #[test] + fn test_poplar1() { + const INPUT_LEN: usize = 8; + + let vdaf: Poplar1<ToyIdpf<Field128>, PrgAes128, 16> = Poplar1::new(INPUT_LEN); + assert_eq!(vdaf.num_aggregators(), 2); + + // Run the VDAF input-distribution algorithm. + let input = vec![IdpfInput::new(&[0b0110_1000], INPUT_LEN).unwrap()]; + + let mut agg_param = BTreeSet::new(); + agg_param.insert(input[0]); + check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]); + + // Try evaluating the VDAF on each prefix of the input. + for prefix_len in 0..input[0].level + 1 { + let mut agg_param = BTreeSet::new(); + agg_param.insert(input[0].prefix(prefix_len)); + check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]); + } + + // Try various prefixes. + let prefix_len = 4; + let mut agg_param = BTreeSet::new(); + // At length 4, the next two prefixes are equal. Neither one matches the input. + agg_param.insert(IdpfInput::new(&[0b0000_0000], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0001_0000], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_0001], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap()); + // At length 4, the next two prefixes are equal. Both match the input. + agg_param.insert(IdpfInput::new(&[0b0111_1101], prefix_len).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap()); + let aggregate = run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(); + assert_eq!(aggregate.len(), agg_param.len()); + check_btree( + &aggregate, + // We put six prefixes in the aggregation parameter, but the vector we get back is only + // 4 elements because at the given prefix length, some of the prefixes are equal. + &[0, 0, 0, 1], + ); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"this is a nonce"; + + // Try evaluating the VDAF with an invalid aggregation parameter. (It's an error to have a + // mixture of prefix lengths.) + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 6).unwrap()); + agg_param.insert(IdpfInput::new(&[0b0000_1000], 7).unwrap()); + let (public_share, input_shares) = vdaf.shard(&input[0]).unwrap(); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + + // Try evaluating the VDAF with malformed inputs. + // + // This IDPF key pair evaluates to 1 everywhere, which is illegal. + let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap(); + for (i, x) in input_shares[0].idpf.data0.iter_mut().enumerate() { + if i != input[0].index { + *x += Field128::one(); + } + } + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap()); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + + // This IDPF key pair has a garbled authentication vector. + let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap(); + for x in input_shares[0].idpf.data1.iter_mut() { + *x = Field128::zero(); + } + let mut agg_param = BTreeSet::new(); + agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap()); + run_vdaf_prepare( + &vdaf, + &verify_key, + &agg_param, + nonce, + public_share, + input_shares, + ) + .unwrap_err(); + } + + fn check_btree(btree: &BTreeMap<IdpfInput, u64>, counts: &[u64]) { + for (got, want) in btree.values().zip(counts.iter()) { + assert_eq!(got, want, "got {:?} want {:?}", btree.values(), counts); + } + } +} diff --git a/third_party/rust/prio/src/vdaf/prg.rs b/third_party/rust/prio/src/vdaf/prg.rs new file mode 100644 index 0000000000..a5930f1283 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prg.rs @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementations of PRGs specified in [[draft-irtf-cfrg-vdaf-03]]. +//! +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +use crate::vdaf::{CodecError, Decode, Encode}; +#[cfg(feature = "crypto-dependencies")] +use aes::{ + cipher::{KeyIvInit, StreamCipher}, + Aes128, +}; +#[cfg(feature = "crypto-dependencies")] +use cmac::{Cmac, Mac}; +#[cfg(feature = "crypto-dependencies")] +use ctr::Ctr64BE; +#[cfg(feature = "crypto-dependencies")] +use std::fmt::Formatter; +use std::{ + fmt::Debug, + io::{Cursor, Read}, +}; + +/// Function pointer to fill a buffer with random bytes. Under normal operation, +/// `getrandom::getrandom()` will be used, but other implementations can be used to control +/// randomness when generating or verifying test vectors. +pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>; + +/// Input of [`Prg`]. +#[derive(Clone, Debug, Eq)] +pub struct Seed<const L: usize>(pub(crate) [u8; L]); + +impl<const L: usize> Seed<L> { + /// Generate a uniform random seed. + pub fn generate() -> Result<Self, getrandom::Error> { + Self::from_rand_source(getrandom::getrandom) + } + + pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> { + let mut seed = [0; L]; + rand_source(&mut seed)?; + Ok(Self(seed)) + } +} + +impl<const L: usize> AsRef<[u8; L]> for Seed<L> { + fn as_ref(&self) -> &[u8; L] { + &self.0 + } +} + +impl<const L: usize> PartialEq for Seed<L> { + fn eq(&self, other: &Self) -> bool { + // Do constant-time compare. + let mut r = 0; + for (x, y) in self.0[..].iter().zip(&other.0[..]) { + r |= x ^ y; + } + r == 0 + } +} + +impl<const L: usize> Encode for Seed<L> { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&self.0[..]); + } +} + +impl<const L: usize> Decode for Seed<L> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut seed = [0; L]; + bytes.read_exact(&mut seed)?; + Ok(Seed(seed)) + } +} + +/// A stream of pseudorandom bytes derived from a seed. +pub trait SeedStream { + /// Fill `buf` with the next `buf.len()` bytes of output. + fn fill(&mut self, buf: &mut [u8]); +} + +/// A pseudorandom generator (PRG) with the interface specified in [[draft-irtf-cfrg-vdaf-03]]. +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +pub trait Prg<const L: usize>: Clone + Debug { + /// The type of stream produced by this PRG. + type SeedStream: SeedStream; + + /// Construct an instance of [`Prg`] with the given seed. + fn init(seed_bytes: &[u8; L]) -> Self; + + /// Update the PRG state by passing in the next fragment of the info string. The final info + /// string is assembled from the concatenation of sequence of fragments passed to this method. + fn update(&mut self, data: &[u8]); + + /// Finalize the PRG state, producing a seed stream. + fn into_seed_stream(self) -> Self::SeedStream; + + /// Finalize the PRG state, producing a seed. + fn into_seed(self) -> Seed<L> { + let mut new_seed = [0; L]; + let mut seed_stream = self.into_seed_stream(); + seed_stream.fill(&mut new_seed); + Seed(new_seed) + } + + /// Construct a seed stream from the given seed and info string. + fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream { + let mut prg = Self::init(seed.as_ref()); + prg.update(info); + prg.into_seed_stream() + } +} + +/// The PRG based on AES128 as specified in [[draft-irtf-cfrg-vdaf-03]]. +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +#[derive(Clone, Debug)] +#[cfg(feature = "crypto-dependencies")] +pub struct PrgAes128(Cmac<Aes128>); + +#[cfg(feature = "crypto-dependencies")] +impl Prg<16> for PrgAes128 { + type SeedStream = SeedStreamAes128; + + fn init(seed_bytes: &[u8; 16]) -> Self { + Self(Cmac::new_from_slice(seed_bytes).unwrap()) + } + + fn update(&mut self, data: &[u8]) { + self.0.update(data); + } + + fn into_seed_stream(self) -> SeedStreamAes128 { + let key = self.0.finalize().into_bytes(); + SeedStreamAes128::new(&key, &[0; 16]) + } +} + +/// The key stream produced by AES128 in CTR-mode. +#[cfg(feature = "crypto-dependencies")] +pub struct SeedStreamAes128(Ctr64BE<Aes128>); + +#[cfg(feature = "crypto-dependencies")] +impl SeedStreamAes128 { + pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self { + SeedStreamAes128(Ctr64BE::<Aes128>::new(key.into(), iv.into())) + } +} + +#[cfg(feature = "crypto-dependencies")] +impl SeedStream for SeedStreamAes128 { + fn fill(&mut self, buf: &mut [u8]) { + buf.fill(0); + self.0.apply_keystream(buf); + } +} + +#[cfg(feature = "crypto-dependencies")] +impl Debug for SeedStreamAes128 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that + // with [`cipher::StreamCipherCoreWrapper::get_core`][2]. + // + // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html + // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html + self.0.get_core().fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{field::Field128, prng::Prng}; + use serde::{Deserialize, Serialize}; + use std::convert::TryInto; + + #[derive(Deserialize, Serialize)] + struct PrgTestVector { + #[serde(with = "hex")] + seed: Vec<u8>, + #[serde(with = "hex")] + info: Vec<u8>, + length: usize, + #[serde(with = "hex")] + derived_seed: Vec<u8>, + #[serde(with = "hex")] + expanded_vec_field128: Vec<u8>, + } + + // Test correctness of dervied methods. + fn test_prg<P, const L: usize>() + where + P: Prg<L>, + { + let seed = Seed::generate().unwrap(); + let info = b"info string"; + + let mut prg = P::init(seed.as_ref()); + prg.update(info); + + let mut want = Seed([0; L]); + prg.clone().into_seed_stream().fill(&mut want.0[..]); + let got = prg.clone().into_seed(); + assert_eq!(got, want); + + let mut want = [0; 45]; + prg.clone().into_seed_stream().fill(&mut want); + let mut got = [0; 45]; + P::seed_stream(&seed, info).fill(&mut got); + assert_eq!(got, want); + } + + #[test] + fn prg_aes128() { + let t: PrgTestVector = + serde_json::from_str(include_str!("test_vec/03/PrgAes128.json")).unwrap(); + let mut prg = PrgAes128::init(&t.seed.try_into().unwrap()); + prg.update(&t.info); + + assert_eq!( + prg.clone().into_seed(), + Seed(t.derived_seed.try_into().unwrap()) + ); + + let mut bytes = std::io::Cursor::new(t.expanded_vec_field128.as_slice()); + let mut want = Vec::with_capacity(t.length); + while (bytes.position() as usize) < t.expanded_vec_field128.len() { + want.push(Field128::decode(&mut bytes).unwrap()) + } + let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream()) + .take(t.length) + .collect(); + assert_eq!(got, want); + + test_prg::<PrgAes128, 16>(); + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs new file mode 100644 index 0000000000..47fc076790 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2.rs @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with +//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server). + +use crate::{ + client as v2_client, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{FieldElement, FieldPrio2}, + prng::Prng, + server as v2_server, + util::proof_length, + vdaf::{ + prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, + PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError, + }, +}; +use ring::hmac; +use std::{ + convert::{TryFrom, TryInto}, + io::Cursor, +}; + +/// The Prio2 VDAF. It supports the same measurement type as +/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system +/// and finite field deployed in ENPA. +#[derive(Clone, Debug)] +pub struct Prio2 { + input_len: usize, +} + +impl Prio2 { + /// Returns an instance of the VDAF for the given input length. + pub fn new(input_len: usize) -> Result<Self, VdafError> { + let n = (input_len + 1).next_power_of_two(); + if let Ok(size) = u32::try_from(2 * n) { + if size > FieldPrio2::generator_order() { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + } else { + return Err(VdafError::Uncategorized( + "input size exceeds memory capacity".into(), + )); + } + + Ok(Prio2 { input_len }) + } + + /// Prepare an input share for aggregation using the given field element `query_rand` to + /// compute the verifier share. + /// + /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the + /// query randomness is computed jointly by the Aggregators. This method is designed to be used + /// in applications, like ENPA, in which the query randomness is instead chosen by a + /// third-party. + pub fn prepare_init_with_query_rand( + &self, + query_rand: FieldPrio2, + input_share: &Share<FieldPrio2, 32>, + is_leader: bool, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let expanded_data: Option<Vec<FieldPrio2>> = match input_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + Some(prng.take(proof_length(self.input_len)).collect()) + } + }; + let data = match input_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_data.as_ref().unwrap(), + }; + + let mut mem = v2_server::ValidationMemory::new(self.input_len); + let verifier_share = v2_server::generate_verification_message( + self.input_len, + query_rand, + data, // Combined input and proof shares + is_leader, + &mut mem, + ) + .map_err(|e| VdafError::Uncategorized(e.to_string()))?; + + Ok(( + Prio2PrepareState(input_share.truncated(self.input_len)), + Prio2PrepareShare(verifier_share), + )) + } +} + +impl Vdaf for Prio2 { + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<u32>; + type AggregateResult = Vec<u32>; + type AggregationParam = (); + type PublicShare = (); + type InputShare = Share<FieldPrio2, 32>; + type OutputShare = OutputShare<FieldPrio2>; + type AggregateShare = AggregateShare<FieldPrio2>; + + fn num_aggregators(&self) -> usize { + // Prio2 can easily be extended to support more than two Aggregators. + 2 + } +} + +impl Client for Prio2 { + fn shard(&self, measurement: &Vec<u32>) -> Result<((), Vec<Share<FieldPrio2, 32>>), VdafError> { + if measurement.len() != self.input_len { + return Err(VdafError::Uncategorized("incorrect input length".into())); + } + let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len()); + for int in measurement { + input.push((*int).into()); + } + + let mut mem = v2_client::ClientMemory::new(self.input_len)?; + let copy_data = |share_data: &mut [FieldPrio2]| { + share_data[..].clone_from_slice(&input); + }; + let mut leader_data = mem.prove_with(self.input_len, copy_data); + + let helper_seed = Seed::generate()?; + let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref()); + for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) { + *s1 -= d; + } + + Ok(( + (), + vec![Share::Leader(leader_data), Share::Helper(helper_seed)], + )) + } +} + +/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio2PrepareState(Share<FieldPrio2, 32>); + +impl Encode for Prio2PrepareState { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let share_decoder = if *agg_id == 0 { + ShareDecodingParameter::Leader(prio2.input_len) + } else { + ShareDecodingParameter::Helper + }; + let out_share = Share::decode_with_param(&share_decoder, bytes)?; + Ok(Self(out_share)) + } +} + +/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug)] +pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>); + +impl Encode for Prio2PrepareShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.f_r.encode(bytes); + self.0.g_r.encode(bytes); + self.0.h_r.encode(bytes); + } +} + +impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare { + fn decode_with_param( + _state: &Prio2PrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Ok(Self(v2_server::VerificationMessage { + f_r: FieldPrio2::decode(bytes)?, + g_r: FieldPrio2::decode(bytes)?, + h_r: FieldPrio2::decode(bytes)?, + })) + } +} + +impl Aggregator<32> for Prio2 { + type PrepareState = Prio2PrepareState; + type PrepareShare = Prio2PrepareShare; + type PrepareMessage = (); + + fn prepare_init( + &self, + agg_key: &[u8; 32], + agg_id: usize, + _agg_param: &(), + nonce: &[u8], + _public_share: &Self::PublicShare, + input_share: &Share<FieldPrio2, 32>, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let is_leader = role_try_from(agg_id)?; + + // In the ENPA Prio system, the query randomness is generated by a third party and + // distributed to the Aggregators after they receive their input shares. In a VDAF, shared + // randomness is derived from a nonce selected by the client. For Prio2 we compute the + // query using HMAC-SHA256 evaluated over the nonce. + let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key); + let hmac_tag = hmac::sign(&hmac_key, nonce); + let query_rand = Prng::from_prio2_seed(hmac_tag.as_ref().try_into().unwrap()) + .next() + .unwrap(); + + self.prepare_init_with_query_rand(query_rand, input_share, is_leader) + } + + fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>( + &self, + inputs: M, + ) -> Result<(), VdafError> { + let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> = + inputs.into_iter().map(|msg| msg.0).collect(); + if verifier_shares.len() != 2 { + return Err(VdafError::Uncategorized( + "wrong number of verifier shares".into(), + )); + } + + if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )); + } + + Ok(()) + } + + fn prepare_step( + &self, + state: Prio2PrepareState, + _input: (), + ) -> Result<PrepareTransition<Self, 32>, VdafError> { + let data = match state.0 { + Share::Leader(data) => data, + Share::Helper(seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + prng.take(self.input_len).collect() + } + }; + Ok(PrepareTransition::Finish(OutputShare::from(data))) + } + + fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>( + &self, + _agg_param: &(), + out_shares: M, + ) -> Result<AggregateShare<FieldPrio2>, VdafError> { + let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for out_share in out_shares.into_iter() { + agg_share.accumulate(&out_share)?; + } + + Ok(agg_share) + } +} + +impl Collector for Prio2 { + fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>( + &self, + _agg_param: &(), + agg_shares: M, + _num_measurements: usize, + ) -> Result<Vec<u32>, VdafError> { + let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(agg.0.into_iter().map(u32::from).collect()) + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; + let decoder = if is_leader { + ShareDecodingParameter::Leader(proof_length(prio2.input_len)) + } else { + ShareDecodingParameter::Helper + }; + + Share::decode_with_param(&decoder, bytes) + } +} + +fn role_try_from(agg_id: usize) -> Result<bool, VdafError> { + match agg_id { + 0 => Ok(true), + 1 => Ok(false), + _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + client::encode_simple, + encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey}, + field::random_vector, + server::Server, + vdaf::{run_vdaf, run_vdaf_prepare}, + }; + use rand::prelude::*; + + const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw=="; + const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w=="; + + #[test] + fn run_prio2() { + let prio2 = Prio2::new(6).unwrap(); + + assert_eq!( + run_vdaf( + &prio2, + &(), + [ + vec![0, 0, 0, 0, 1, 0], + vec![0, 1, 0, 0, 0, 0], + vec![0, 1, 1, 0, 0, 0], + vec![1, 1, 1, 0, 0, 0], + vec![0, 0, 0, 0, 1, 1], + ] + ) + .unwrap(), + vec![1, 3, 2, 0, 2, 1], + ); + } + + #[test] + fn enpa_client_interop() { + let mut rng = thread_rng(); + let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap(); + let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap(); + let pub_key1 = PublicKey::from(&priv_key1); + let pub_key2 = PublicKey::from(&priv_key2); + + let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0] + .iter() + .map(|x| FieldPrio2::from(*x)) + .collect(); + let (encrypted_input_share1, encrypted_input_share2) = + encode_simple(&data, pub_key1, pub_key2).unwrap(); + + let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap(); + let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap(); + + let prio2 = Prio2::new(data.len()).unwrap(); + let input_shares = vec![ + Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(), + Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(), + ]; + + let verify_key = rng.gen(); + let mut nonce = [0; 16]; + rng.fill(&mut nonce); + run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, (), input_shares).unwrap(); + } + + #[test] + fn enpa_server_interop() { + let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap(); + let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap(); + let pub_key1 = PublicKey::from(&priv_key1); + let pub_key2 = PublicKey::from(&priv_key2); + + let data = vec![0, 0, 1, 1, 0]; + let prio2 = Prio2::new(data.len()).unwrap(); + let (_public_share, input_shares) = prio2.shard(&data).unwrap(); + + let encrypted_input_share1 = + encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap(); + let encrypted_input_share2 = + encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap(); + + let mut server1 = Server::new(data.len(), true, priv_key1).unwrap(); + let mut server2 = Server::new(data.len(), false, priv_key2).unwrap(); + + let eval_at: FieldPrio2 = random_vector(1).unwrap()[0]; + let verifier1 = server1 + .generate_verification_message(eval_at, &encrypted_input_share1) + .unwrap(); + let verifier2 = server2 + .generate_verification_message(eval_at, &encrypted_input_share2) + .unwrap(); + + server1 + .aggregate(&encrypted_input_share1, &verifier1, &verifier2) + .unwrap(); + server2 + .aggregate(&encrypted_input_share2, &verifier1, &verifier2) + .unwrap(); + } + + #[test] + fn prepare_state_serialization() { + let mut verify_key = [0; 32]; + thread_rng().fill(&mut verify_key[..]); + let data = vec![0, 0, 1, 1, 0]; + let prio2 = Prio2::new(data.len()).unwrap(); + let (public_share, input_shares) = prio2.shard(&data).unwrap(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (want, _msg) = prio2 + .prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share) + .unwrap(); + let got = + Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded()) + .expect("failed to decode prepare step"); + assert_eq!(got, want); + } + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs new file mode 100644 index 0000000000..31853f15ab --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3.rs @@ -0,0 +1,1168 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]]. +//! +//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone +//! significant security analysis. Use at your own risk. +//! +//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented +//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO +//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication +//! cost. +//! +//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into +//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of +//! which are instantiated here: +//! +//! - [`Prio3Aes128Count`] for aggregating a counter (*) +//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters +//! - [`Prio3Aes128Sum`] for copmputing the sum of integers (*) +//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*) +//! +//! Additional types can be constructed from [`Prio3`] as needed. +//! +//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]]. +//! +//! [BBCG+19]: https://ia.cr/2019/188 +//! [CGB17]: https://crypto.stanford.edu/prio/ +//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ + +#[cfg(feature = "crypto-dependencies")] +use super::prg::PrgAes128; +use super::{DST_LEN, VERSION}; +use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; +use crate::field::FieldElement; +#[cfg(feature = "crypto-dependencies")] +use crate::field::{Field128, Field64}; +#[cfg(feature = "multithreaded")] +use crate::flp::gadgets::ParallelSumMultithreaded; +#[cfg(feature = "crypto-dependencies")] +use crate::flp::gadgets::{BlindPolyEval, ParallelSum}; +#[cfg(feature = "crypto-dependencies")] +use crate::flp::types::{Average, Count, CountVec, Histogram, Sum}; +use crate::flp::Type; +use crate::prng::Prng; +use crate::vdaf::prg::{Prg, RandSource, Seed}; +use crate::vdaf::{ + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, + Share, ShareDecodingParameter, Vdaf, VdafError, +}; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io::Cursor; +use std::iter::IntoIterator; +use std::marker::PhantomData; + +/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Count { + /// Construct an instance of Prio3Aes128Count with the given number of aggregators. + pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, Count::new()) + } +} + +/// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is +/// the element-wise sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128CountVec = + Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128CountVec { + /// Construct an instance of Prio3Aes1238CountVec with the given number of aggregators. `len` + /// defines the length of each measurement. + pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, CountVec::new(len)) + } +} + +/// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and +/// preparation time. Note that the improvement is only noticeable for very large input lengths, +/// e.g., 201 and up. (Your system's mileage may vary.) +#[cfg(feature = "multithreaded")] +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +pub type Prio3Aes128CountVecMultithreaded = Prio3< + CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>, + PrgAes128, + 16, +>; + +#[cfg(feature = "multithreaded")] +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +impl Prio3Aes128CountVecMultithreaded { + /// Construct an instance of Prio3Aes1238CountVecMultithreaded with the given number of + /// aggregators. `len` defines the length of each measurement. + pub fn new_aes128_count_vec_multithreaded( + num_aggregators: u8, + len: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, CountVec::new(len)) + } +} + +/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the sum. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Sum { + /// Construct an instance of Prio3Aes128Sum with the given number of aggregators and required + /// bit length. The bit length must not exceed 64. + pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> { + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({}) exceeds limit for aggregate type (64)", + bits + ))); + } + + Prio3::new(num_aggregators, Sum::new(bits as usize)?) + } +} + +/// The histogram type. Each measurement is an unsigned integer and the result is a histogram +/// representation of the distribution. The bucket boundaries are fixed in advance. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Histogram { + /// Constructs an instance of Prio3Aes128Histogram with the given number of aggregators and + /// desired histogram bucket boundaries. + pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> { + let buckets = buckets.iter().map(|bucket| *bucket as u128).collect(); + + Prio3::new(num_aggregators, Histogram::new(buckets)?) + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and +/// the aggregate is the arithmetic average. +#[cfg(feature = "crypto-dependencies")] +pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>; + +#[cfg(feature = "crypto-dependencies")] +impl Prio3Aes128Average { + /// Construct an instance of Prio3Aes128Average with the given number of aggregators and + /// required bit length. The bit length must not exceed 64. + pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({}) exceeds limit for aggregate type (64)", + bits + ))); + } + + Ok(Prio3 { + num_aggregators, + typ: Average::new(bits as usize)?, + phantom: PhantomData, + }) + } +} + +/// The base type for Prio3. +/// +/// An instance of Prio3 is determined by: +/// +/// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and +/// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds. +/// +/// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an +/// alias for `Prio3<Count<Field64>, PrgAes128, 16>`. +/// +/// ``` +/// use prio::vdaf::{ +/// Aggregator, Client, Collector, PrepareTransition, +/// prio3::Prio3, +/// }; +/// use rand::prelude::*; +/// +/// let num_shares = 2; +/// let vdaf = Prio3::new_aes128_count(num_shares).unwrap(); +/// +/// let mut out_shares = vec![vec![]; num_shares.into()]; +/// let mut rng = thread_rng(); +/// let verify_key = rng.gen(); +/// let measurements = [0, 1, 1, 1, 0]; +/// for measurement in measurements { +/// // Shard +/// let (public_share, input_shares) = vdaf.shard(&measurement).unwrap(); +/// let mut nonce = [0; 16]; +/// rng.fill(&mut nonce); +/// +/// // Prepare +/// let mut prep_states = vec![]; +/// let mut prep_shares = vec![]; +/// for (agg_id, input_share) in input_shares.iter().enumerate() { +/// let (state, share) = vdaf.prepare_init( +/// &verify_key, +/// agg_id, +/// &(), +/// &nonce, +/// &public_share, +/// input_share +/// ).unwrap(); +/// prep_states.push(state); +/// prep_shares.push(share); +/// } +/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); +/// +/// for (agg_id, state) in prep_states.into_iter().enumerate() { +/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { +/// PrepareTransition::Finish(out_share) => out_share, +/// _ => panic!("unexpected transition"), +/// }; +/// out_shares[agg_id].push(out_share); +/// } +/// } +/// +/// // Aggregate +/// let agg_shares = out_shares.into_iter() +/// .map(|o| vdaf.aggregate(&(), o).unwrap()); +/// +/// // Unshard +/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap(); +/// assert_eq!(agg_res, 3); +/// ``` +/// +/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +#[derive(Clone, Debug)] +pub struct Prio3<T, P, const L: usize> +where + T: Type, + P: Prg<L>, +{ + num_aggregators: u8, + typ: T, + phantom: PhantomData<P>, +} + +impl<T, P, const L: usize> Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the + /// underlying type. + pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + Ok(Self { + num_aggregators, + typ, + phantom: PhantomData, + }) + } + + /// The output length of the underlying FLP. + pub fn output_len(&self) -> usize { + self.typ.output_len() + } + + /// The verifier length of the underlying FLP. + pub fn verifier_len(&self) -> usize { + self.typ.verifier_len() + } + + fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> { + let mut info = [0; VERSION.len() + 5]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes()); + info[VERSION.len() + 4] = 255; + let mut deriver = P::init(&[0; L]); + deriver.update(&info); + for part in parts { + deriver.update(part.as_ref()); + } + deriver.into_seed() + } + + fn shard_with_rand_source( + &self, + measurement: &T::Measurement, + rand_source: RandSource, + ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> { + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + + let num_aggregators = self.num_aggregators; + let input = self.typ.encode_measurement(measurement)?; + + // Generate the input shares and compute the joint randomness. + let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1); + let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 { + Some(Vec::with_capacity(num_aggregators as usize - 1)) + } else { + None + }; + let mut leader_input_share = input.clone(); + for agg_id in 1..num_aggregators { + let helper = HelperShare::from_rand_source(rand_source)?; + + let mut deriver = P::init(helper.joint_rand_param.blind.as_ref()); + info[DST_LEN] = agg_id; + deriver.update(&info); + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info)); + for (x, y) in leader_input_share + .iter_mut() + .zip(prng) + .take(self.typ.input_len()) + { + *x -= y; + deriver.update(&y.into()); + } + + if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() { + helper_joint_rand_parts.push(deriver.into_seed()); + } + helper_shares.push(helper); + } + + let leader_blind = Seed::from_rand_source(rand_source)?; + + info[DST_LEN] = 0; // ID of the leader + let mut deriver = P::init(leader_blind.as_ref()); + deriver.update(&info); + for x in leader_input_share.iter() { + deriver.update(&(*x).into()); + } + + let leader_joint_rand_seed_part = deriver.into_seed(); + + // Compute the joint randomness seed. + let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| { + Self::derive_joint_randomness( + std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()), + ) + }); + + // Run the proof-generation algorithm. + let domain_separation_tag = &info[..DST_LEN]; + let joint_rand: Vec<T::Field> = joint_rand_seed + .map(|joint_rand_seed| { + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); + prng.take(self.typ.joint_rand_len()).collect() + }) + .unwrap_or_default(); + let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream( + &Seed::from_rand_source(rand_source)?, + domain_separation_tag, + )); + let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect(); + let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?; + + // Generate the proof shares and distribute the joint randomness seed hints. + for (j, helper) in helper_shares.iter_mut().enumerate() { + info[DST_LEN] = j as u8 + 1; + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info)); + for (x, y) in leader_proof_share + .iter_mut() + .zip(prng) + .take(self.typ.proof_len()) + { + *x -= y; + } + + if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() { + let mut hint = Vec::with_capacity(num_aggregators as usize - 1); + hint.push(leader_joint_rand_seed_part.clone()); + hint.extend(helper_joint_rand_parts[..j].iter().cloned()); + hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned()); + helper.joint_rand_param.seed_hint = hint; + } + } + + let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 { + Some(JointRandParam { + seed_hint: helper_joint_rand_parts.unwrap_or_default(), + blind: leader_blind, + }) + } else { + None + }; + + // Prep the output messages. + let mut out = Vec::with_capacity(num_aggregators as usize); + out.push(Prio3InputShare { + input_share: Share::Leader(leader_input_share), + proof_share: Share::Leader(leader_proof_share), + joint_rand_param: leader_joint_rand_param, + }); + + for helper in helper_shares.into_iter() { + let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 { + Some(helper.joint_rand_param) + } else { + None + }; + + out.push(Prio3InputShare { + input_share: Share::Helper(helper.input_share), + proof_share: Share::Helper(helper.proof_share), + joint_rand_param: helper_joint_rand_param, + }); + } + + Ok(out) + } + + /// Shard measurement with constant randomness of repeated bytes. + /// This method is not secure. It is used for running test vectors for Prio3. + #[cfg(feature = "test-util")] + pub fn test_vec_shard( + &self, + measurement: &T::Measurement, + ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> { + self.shard_with_rand_source(measurement, |buf| { + buf.fill(1); + Ok(()) + }) + } + + fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> { + if agg_id >= self.num_aggregators as usize { + return Err(VdafError::Uncategorized("unexpected aggregator id".into())); + } + Ok(u8::try_from(agg_id).unwrap()) + } +} + +impl<T, P, const L: usize> Vdaf for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + const ID: u32 = T::ID; + type Measurement = T::Measurement; + type AggregateResult = T::AggregateResult; + type AggregationParam = (); + type PublicShare = (); + type InputShare = Prio3InputShare<T::Field, L>; + type OutputShare = OutputShare<T::Field>; + type AggregateShare = AggregateShare<T::Field>; + + fn num_aggregators(&self) -> usize { + self.num_aggregators as usize + } +} + +/// Message sent by the [`Client`](crate::vdaf::Client) to each +/// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio3InputShare<F, const L: usize> { + /// The input share. + input_share: Share<F, L>, + + /// The proof share. + proof_share: Share<F, L>, + + /// Parameters used by the Aggregator to compute the joint randomness. This field is optional + /// because not every [`Type`](`crate::flp::Type`) requires joint randomness. + joint_rand_param: Option<JointRandParam<L>>, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + if matches!( + (&self.input_share, &self.proof_share), + (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) + ) { + panic!("tried to encode input share with ambiguous encoding") + } + + self.input_share.encode(bytes); + self.proof_share.encode(bytes); + if let Some(ref param) = self.joint_rand_param { + param.blind.encode(bytes); + for part in param.seed_hint.iter() { + part.encode(bytes); + } + } + } +} + +impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)> + for Prio3InputShare<T::Field, L> +where + T: Type, + P: Prg<L>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + let (input_decoder, proof_decoder) = if agg_id == 0 { + ( + ShareDecodingParameter::Leader(prio3.typ.input_len()), + ShareDecodingParameter::Leader(prio3.typ.proof_len()), + ) + } else { + ( + ShareDecodingParameter::Helper, + ShareDecodingParameter::Helper, + ) + }; + + let input_share = Share::decode_with_param(&input_decoder, bytes)?; + let proof_share = Share::decode_with_param(&proof_decoder, bytes)?; + let joint_rand_param = if prio3.typ.joint_rand_len() > 0 { + let num_aggregators = prio3.num_aggregators(); + let blind = Seed::decode(bytes)?; + let seed_hint = std::iter::repeat_with(|| Seed::decode(bytes)) + .take(num_aggregators - 1) + .collect::<Result<Vec<_>, _>>()?; + Some(JointRandParam { blind, seed_hint }) + } else { + None + }; + + Ok(Prio3InputShare { + input_share, + proof_share, + joint_rand_param, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +/// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the +/// Preparation phase. +pub struct Prio3PrepareShare<F, const L: usize> { + /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).) + verifier: Vec<F>, + + /// A part of the joint randomness seed. + joint_rand_part: Option<Seed<L>>, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> { + fn encode(&self, bytes: &mut Vec<u8>) { + for x in &self.verifier { + x.encode(bytes); + } + if let Some(ref seed) = self.joint_rand_part { + seed.encode(bytes); + } + } +} + +impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>> + for Prio3PrepareShare<F, L> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, L>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len); + for _ in 0..decoding_parameter.verifier_len { + verifier.push(F::decode(bytes)?); + } + + let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareShare { + verifier, + joint_rand_part, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +/// Result of combining a round of [`Prio3PrepareShare`] messages. +pub struct Prio3PrepareMessage<const L: usize> { + /// The joint randomness seed computed by the Aggregators. + joint_rand_seed: Option<Seed<L>>, +} + +impl<const L: usize> Encode for Prio3PrepareMessage<L> { + fn encode(&self, bytes: &mut Vec<u8>) { + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } +} + +impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>> + for Prio3PrepareMessage<L> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, L>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } +} + +impl<T, P, const L: usize> Client for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + #[allow(clippy::type_complexity)] + fn shard( + &self, + measurement: &T::Measurement, + ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> { + self.shard_with_rand_source(measurement, getrandom::getrandom) + .map(|input_shares| ((), input_shares)) + } +} + +/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Prio3PrepareState<F, const L: usize> { + input_share: Share<F, L>, + joint_rand_seed: Option<Seed<L>>, + agg_id: u8, + verifier_len: usize, +} + +impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + fn encode(&self, bytes: &mut Vec<u8>) { + self.input_share.encode(bytes); + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } +} + +impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)> + for Prio3PrepareState<T::Field, L> +where + T: Type, + P: Prg<L>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, L>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + + let share_decoder = if agg_id == 0 { + ShareDecodingParameter::Leader(prio3.typ.input_len()) + } else { + ShareDecodingParameter::Helper + }; + let input_share = Share::decode_with_param(&share_decoder, bytes)?; + + let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Self { + input_share, + joint_rand_seed, + agg_id, + verifier_len: prio3.typ.verifier_len(), + }) + } +} + +impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + type PrepareState = Prio3PrepareState<T::Field, L>; + type PrepareShare = Prio3PrepareShare<T::Field, L>; + type PrepareMessage = Prio3PrepareMessage<L>; + + /// Begins the Prep process with the other aggregators. The result of this process is + /// the aggregator's output share. + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; L], + agg_id: usize, + _agg_param: &(), + nonce: &[u8], + _public_share: &(), + msg: &Prio3InputShare<T::Field, L>, + ) -> Result< + ( + Prio3PrepareState<T::Field, L>, + Prio3PrepareShare<T::Field, L>, + ), + VdafError, + > { + let agg_id = self.role_try_from(agg_id)?; + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + info[DST_LEN] = agg_id; + let domain_separation_tag = &info[..DST_LEN]; + + let mut deriver = P::init(verify_key); + deriver.update(domain_separation_tag); + deriver.update(&[255]); + deriver.update(nonce); + let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream()); + + // Create a reference to the (expanded) input share. + let expanded_input_share: Option<Vec<T::Field>> = match msg.input_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); + Some(prng.take(self.typ.input_len()).collect()) + } + }; + let input_share = match msg.input_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_input_share.as_ref().unwrap(), + }; + + // Create a reference to the (expanded) proof share. + let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_seed_stream(P::seed_stream(seed, &info)); + Some(prng.take(self.typ.proof_len()).collect()) + } + }; + let proof_share = match msg.proof_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_proof_share.as_ref().unwrap(), + }; + + // Compute the joint randomness. + let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 { + let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref()); + deriver.update(&info); + for x in input_share { + deriver.update(&(*x).into()); + } + let joint_rand_seed_part = deriver.into_seed(); + + let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint; + let joint_rand_seed = Self::derive_joint_randomness( + hints[..agg_id as usize] + .iter() + .chain(std::iter::once(&joint_rand_seed_part)) + .chain(hints[agg_id as usize..].iter()), + ); + + let prng: Prng<T::Field, _> = + Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag)); + ( + Some(joint_rand_seed), + Some(joint_rand_seed_part), + prng.take(self.typ.joint_rand_len()).collect(), + ) + } else { + (None, None, Vec::new()) + }; + + // Compute the query randomness. + let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect(); + + // Run the query-generation algorithm. + let verifier_share = self.typ.query( + input_share, + proof_share, + &query_rand, + &joint_rand, + self.num_aggregators as usize, + )?; + + Ok(( + Prio3PrepareState { + input_share: msg.input_share.clone(), + joint_rand_seed, + agg_id, + verifier_len: verifier_share.len(), + }, + Prio3PrepareShare { + verifier: verifier_share, + joint_rand_part: joint_rand_seed_part, + }, + )) + } + + fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>( + &self, + inputs: M, + ) -> Result<Prio3PrepareMessage<L>, VdafError> { + let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; + let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators()); + let mut count = 0; + for share in inputs.into_iter() { + count += 1; + + if share.verifier.len() != verifier.len() { + return Err(VdafError::Uncategorized(format!( + "unexpected verifier share length: got {}; want {}", + share.verifier.len(), + verifier.len(), + ))); + } + + if self.typ.joint_rand_len() > 0 { + let joint_rand_seed_part = share.joint_rand_part.unwrap(); + joint_rand_parts.push(joint_rand_seed_part); + } + + for (x, y) in verifier.iter_mut().zip(share.verifier) { + *x += y; + } + } + + if count != self.num_aggregators { + return Err(VdafError::Uncategorized(format!( + "unexpected message count: got {}; want {}", + count, self.num_aggregators, + ))); + } + + // Check the proof verifier. + match self.typ.decide(&verifier) { + Ok(true) => (), + Ok(false) => { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )) + } + Err(err) => return Err(VdafError::from(err)), + }; + + let joint_rand_seed = if self.typ.joint_rand_len() > 0 { + Some(Self::derive_joint_randomness(joint_rand_parts.iter())) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } + + fn prepare_step( + &self, + step: Prio3PrepareState<T::Field, L>, + msg: Prio3PrepareMessage<L>, + ) -> Result<PrepareTransition<Self, L>, VdafError> { + if self.typ.joint_rand_len() > 0 { + // Check that the joint randomness was correct. + if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() { + return Err(VdafError::Uncategorized( + "joint randomness mismatch".to_string(), + )); + } + } + + // Compute the output share. + let input_share = match step.input_share { + Share::Leader(data) => data, + Share::Helper(seed) => { + let mut info = [0; DST_LEN + 1]; + info[..VERSION.len()].copy_from_slice(VERSION); + info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes()); + info[DST_LEN] = step.agg_id; + let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info)); + prng.take(self.typ.input_len()).collect() + } + }; + + let output_share = match self.typ.truncate(input_share) { + Ok(data) => OutputShare(data), + Err(err) => { + return Err(VdafError::from(err)); + } + }; + + Ok(PrepareTransition::Finish(output_share)) + } + + /// Aggregates a sequence of output shares into an aggregate share. + fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>( + &self, + _agg_param: &(), + output_shares: It, + ) -> Result<AggregateShare<T::Field>, VdafError> { + let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + + Ok(agg_share) + } +} + +impl<T, P, const L: usize> Collector for Prio3<T, P, L> +where + T: Type, + P: Prg<L>, +{ + /// Combines aggregate shares into the aggregate result. + fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>( + &self, + _agg_param: &(), + agg_shares: It, + num_measurements: usize, + ) -> Result<T::AggregateResult, VdafError> { + let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(self.typ.decode_result(&agg.0, num_measurements)?) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +struct JointRandParam<const L: usize> { + /// The joint randomness seed parts corresponding to the other Aggregators' shares. + seed_hint: Vec<Seed<L>>, + + /// The blinding factor, used to derive the aggregator's joint randomness seed part. + blind: Seed<L>, +} + +#[derive(Clone)] +struct HelperShare<const L: usize> { + input_share: Seed<L>, + proof_share: Seed<L>, + joint_rand_param: JointRandParam<L>, +} + +impl<const L: usize> HelperShare<L> { + fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> { + Ok(HelperShare { + input_share: Seed::from_rand_source(rand_source)?, + proof_share: Seed::from_rand_source(rand_source)?, + joint_rand_param: JointRandParam { + seed_hint: Vec::new(), + blind: Seed::from_rand_source(rand_source)?, + }, + }) + } +} + +fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> { + if num_aggregators == 0 { + return Err(VdafError::Uncategorized(format!( + "at least one aggregator is required; got {}", + num_aggregators + ))); + } else if num_aggregators > 254 { + return Err(VdafError::Uncategorized(format!( + "number of aggregators must not exceed 254; got {}", + num_aggregators + ))); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vdaf::{run_vdaf, run_vdaf_prepare}; + use assert_matches::assert_matches; + use rand::prelude::*; + + #[test] + fn test_prio3_count() { + let prio3 = Prio3::new_aes128_count(2).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"This is a good nonce."; + + let (public_share, input_shares) = prio3.shard(&0).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap(); + + let (public_share, input_shares) = prio3.shard(&1).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap(); + + test_prepare_state_serialization(&prio3, &1).unwrap(); + + let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap(); + assert_eq!( + run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(), + 3, + ); + } + + #[test] + fn test_prio3_sum() { + let prio3 = Prio3::new_aes128_sum(3, 16).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), + (1 << 16) + 1 + ); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = b"This is a good nonce."; + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255; + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255; + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1).unwrap(); + assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + test_prepare_state_serialization(&prio3, &1).unwrap(); + } + + #[test] + fn test_prio3_countvec() { + let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [vec![ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, + ]] + ) + .unwrap(), + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] + ); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_prio3_countvec_multithreaded() { + let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [vec![ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, + ]] + ) + .unwrap(), + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,] + ); + } + + #[test] + fn test_prio3_histogram() { + let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(), + vec![1, 1, 1, 1] + ); + assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]); + test_prepare_state_serialization(&prio3, &23).unwrap(); + } + + #[test] + fn test_prio3_average() { + let prio3 = Prio3::new_aes128_average(2, 64).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); + assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); + assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); + assert_eq!( + run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), + 207.5f64 + ); + } + + #[test] + fn test_prio3_input_share() { + let prio3 = Prio3::new_aes128_sum(5, 16).unwrap(); + let (_public_share, input_shares) = prio3.shard(&1).unwrap(); + + // Check that seed shares are distinct. + for (i, x) in input_shares.iter().enumerate() { + for (j, y) in input_shares.iter().enumerate() { + if i != j { + if let (Share::Helper(left), Share::Helper(right)) = + (&x.input_share, &y.input_share) + { + assert_ne!(left, right); + } + + if let (Share::Helper(left), Share::Helper(right)) = + (&x.proof_share, &y.proof_share) + { + assert_ne!(left, right); + } + + assert_ne!(x.joint_rand_param, y.joint_rand_param); + } + } + } + } + + fn test_prepare_state_serialization<T, P, const L: usize>( + prio3: &Prio3<T, P, L>, + measurement: &T::Measurement, + ) -> Result<(), VdafError> + where + T: Type, + P: Prg<L>, + { + let mut verify_key = [0; L]; + thread_rng().fill(&mut verify_key[..]); + let (public_share, input_shares) = prio3.shard(measurement)?; + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (want, _msg) = + prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?; + let got = + Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded()) + .expect("failed to decode prepare step"); + assert_eq!(got, want); + } + Ok(()) + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs new file mode 100644 index 0000000000..d4c9151ce0 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3_test.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{ + codec::{Encode, ParameterizedDecode}, + flp::Type, + vdaf::{ + prg::Prg, + prio3::{Prio3, Prio3InputShare, Prio3PrepareShare}, + Aggregator, PrepareTransition, + }, +}; +use serde::{Deserialize, Serialize}; +use std::{convert::TryInto, fmt::Debug}; + +#[derive(Debug, Deserialize, Serialize)] +struct TEncoded(#[serde(with = "hex")] Vec<u8>); + +impl AsRef<[u8]> for TEncoded { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Deserialize, Serialize)] +struct TPrio3Prep<M> { + measurement: M, + #[serde(with = "hex")] + nonce: Vec<u8>, + input_shares: Vec<TEncoded>, + prep_shares: Vec<Vec<TEncoded>>, + prep_messages: Vec<TEncoded>, + out_shares: Vec<Vec<M>>, +} + +#[derive(Deserialize, Serialize)] +struct TPrio3<M> { + verify_key: TEncoded, + prep: Vec<TPrio3Prep<M>>, +} + +macro_rules! err { + ( + $test_num:ident, + $error:expr, + $msg:expr + ) => { + panic!("test #{} failed: {} err: {}", $test_num, $msg, $error) + }; +} + +// TODO Generalize this method to work with any VDAF. To do so we would need to add +// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.) +fn check_prep_test_vec<M, T, P, const L: usize>( + prio3: &Prio3<T, P, L>, + verify_key: &[u8; L], + test_num: usize, + t: &TPrio3Prep<M>, +) where + T: Type<Measurement = M>, + P: Prg<L>, + M: From<<T as Type>::Field> + Debug + PartialEq, +{ + let input_shares = prio3 + .test_vec_shard(&t.measurement) + .expect("failed to generate input shares"); + + assert_eq!(2, t.input_shares.len(), "#{}", test_num); + for (agg_id, want) in t.input_shares.iter().enumerate() { + assert_eq!( + input_shares[agg_id], + Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")), + "#{}", + test_num + ); + assert_eq!( + input_shares[agg_id].get_encoded(), + want.as_ref(), + "#{}", + test_num + ) + } + + let mut states = Vec::new(); + let mut prep_shares = Vec::new(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (state, prep_share) = prio3 + .prepare_init(verify_key, agg_id, &(), &t.nonce, &(), input_share) + .unwrap_or_else(|e| err!(test_num, e, "prep state init")); + states.push(state); + prep_shares.push(prep_share); + } + + assert_eq!(1, t.prep_shares.len(), "#{}", test_num); + for (i, want) in t.prep_shares[0].iter().enumerate() { + assert_eq!( + prep_shares[i], + Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")), + "#{}", + test_num + ); + assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{}", test_num); + } + + let inbound = prio3 + .prepare_preprocess(prep_shares) + .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); + assert_eq!(t.prep_messages.len(), 1); + assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref()); + + let mut out_shares = Vec::new(); + for state in states.iter_mut() { + match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() { + PrepareTransition::Finish(out_share) => { + out_shares.push(out_share); + } + _ => panic!("unexpected transition"), + } + } + + for (got, want) in out_shares.iter().zip(t.out_shares.iter()) { + let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect(); + assert_eq!(&got, want); + } +} + +#[test] +fn test_vec_prio3_count() { + let t: TPrio3<u64> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Count_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_count(2).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} + +#[test] +fn test_vec_prio3_sum() { + let t: TPrio3<u128> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Sum_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_sum(2, 8).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} + +#[test] +fn test_vec_prio3_histogram() { + let t: TPrio3<u128> = + serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Histogram_0.json")).unwrap(); + let prio3 = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap(); + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + for (test_num, p) in t.prep.iter().enumerate() { + check_prep_test_vec(&prio3, &verify_key, test_num, p); + } +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json new file mode 100644 index 0000000000..e450665173 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json @@ -0,0 +1,7 @@ +{ + "derived_seed": "ccf3be704c982182ad2961e9795a88aa", + "expanded_vec_field128": "ccf3be704c982182ad2961e9795a88aa8df71c0b5ea5c13bcf3173c3f3626505e1bf4738874d5405805082cc38c55d1f04f85fbb88b8cf8592ffed8a4ac7f76991c58d850a15e8deb34fb289ab6fab584554ffef16c683228db2b76e792ca4f3c15760044d0703b438c2aefd7975c5dd4b9992ee6f87f20e570572dea18fa580ee17204903c1234f1332d47a442ea636580518ce7aa5943c415117460a049bc19cc81edbb0114d71890cbdbe4ea2664cd038e57b88fb7fd3557830ad363c20b9840d35fd6bee6c3c8424f026ee7fbca3daf3c396a4d6736d7bd3b65b2c228d22a40f4404e47c61b26ac3c88bebf2f268fa972f8831f18bee374a22af0f8bb94d9331a1584bdf8cf3e8a5318b546efee8acd28f6cba8b21b9d52acbae8e726500340da98d643d0a5f1270ecb94c574130cea61224b0bc6d438b2f4f74152e72b37e6a9541c9dc5515f8f98fd0d1bce8743f033ab3e8574180ffc3363f3a0490f6f9583bf73a87b9bb4b51bfd0ef260637a4288c37a491c6cbdc46b6a86cd26edf611793236e912e7227bfb85b560308b06238bbd978f72ed4a58583cf0c6e134066eb6b399ad2f26fa01d69a62d8a2d04b4b8acf82299b07a834d4c2f48fee23a24c20307f9cabcd34b6d69f1969588ebde777e46e9522e866e6dd1e14119a1cb4c0709fa9ea347d9f872e76a39313e7d49bfbf3e5ce807183f43271ba2b5c6aaeaef22da301327c1fd9fedde7c5a68d9b97fa6eb687ec8ca692cb0f631f46e6699a211a1254026c9a0a43eceb450dc97cfa923321baf1f4b6f233260d46182b844dccec153aaddd20f920e9e13ff11434bcd2aa632bf4f544f41b5ddced962939676476f70e0b8640c3471fc7af62d80053781295b070388f7b7f1fa66220cb3", + "info": "696e666f20737472696e67", + "length": 40, + "seed": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json new file mode 100644 index 0000000000..9e79888745 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json @@ -0,0 +1,37 @@ +{ + "agg_param": null, + "agg_result": 1, + "agg_shares": [ + "ad8bb894e3222b47", + "5274476a1cddd4bb" + ], + "prep": [ + { + "input_shares": [ + "ad8bb894e3222b47b70eb67d4f70cb78644826d67d31129e422b910cf0aab70c0b78fa57b4a7b3aaafae57bd1012e813", + "0101010101010101010101010101010101010101010101010101010101010101" + ], + "measurement": 1, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 12505291739929652039 + ], + [ + 5941452329484932283 + ] + ], + "prep_messages": [ + "" + ], + "prep_shares": [ + [ + "38d535dd68f3c02ed6681f7ff24239d46fde93c8402d24ebbafa25c77ca3535d", + "c72aca21970c3fd35274476a1cddd4bb4efa24ee0d71473e4a0a23713a347d78" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json new file mode 100644 index 0000000000..f5476455fa --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json @@ -0,0 +1,53 @@ +{ + "agg_param": null, + "agg_result": [ + 0, + 0, + 1, + 0 + ], + "agg_shares": [ + "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7", + "11ef893e143d2b59aa858efce43622a5632a16ee7f444ac4b53c996b9434056e46f786d42ea2bfb4b53725d9b1db5df39ba1097b8de7f38b6b4538fb51f98a2a" + ], + "buckets": [ + 1, + 10, + 100 + ], + "prep": [ + { + "input_shares": [ + "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7f16776df4f93852a40b514593a73be51ad64d8c28322a47af92c92223dd489998a3c6687861cdc2e4d834885d03d8d3273af0bf742c47985ae8fec6d16c31216792bb0cdca0d1d1fa2287414cd069f8caa42dc08f78dd43e14c4095e2ef9d9609937caebcd534e813136e79a4233e873397a6c7fd164928d43673b32e061139dc6650152d8433e2342f595149418929b74c9e23f1469ed1eebdaa57d0b5c62f90cb5a53dc68c8e030448bb2d9c07aeed50d82c93e1afe8febd68918933ed9b2dd36b9d8a35fd6c57cd76707011fca77526437aeb8392a2013f829c1e395f7f8ddef030f5bc869833f528ae2137a2e667aa648d8643f6c13e8d76e8832ab9ef7d0101010101010101010101010101010194c3f0f1061c8f440b51f806ad822510", + "0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101016195ec204fd5d65c14fac36b73723cde" + ], + "measurement": 50, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 316441748434879643753815489063091297628, + 208470253761472213750543248431791209107, + 245951175238245845331446316072865931791, + 133415875449384174923011884997795018199 + ], + [ + 23840618486058819193050284304809468581, + 131812113159466249196322524936109557102, + 94331191682692617615419457295034834419, + 206866491471554288023853888370105748010 + ] + ], + "prep_messages": [ + "7912f1157c2ce3a4dca6456224aeaeea" + ], + "prep_shares": [ + [ + "f2dc9e823b867d760b2169644633804eabec10e5869fe8f3030c5da6dc0fce03a433572cb8aaa7ca3559959f7bad68306195ec204fd5d65c14fac36b73723cde", + "0d23617dc479826df4de969bb9cc7fb3f5e934542e987db0271aee33551b28a4c16f7ad00127c43df9c433a1c224594d94c3f0f1061c8f440b51f806ad822510" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json new file mode 100644 index 0000000000..55d2a000db --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json @@ -0,0 +1,38 @@ +{ + "agg_param": null, + "agg_result": 100, + "agg_shares": [ + "ab3bcc5ef693737a7a3e76cd9face3e0", + "54c433a1096c8c6985c1893260531c85" + ], + "bits": 8, + "prep": [ + { + "input_shares": [ + "fc1e42a024e3d8b45f63f485ebe2dc8a356c5e5780a0bd751184e6a02a96c0767f518e87282ebdc039590aef02e40e5492c9eb69dd22b6b4f1d630e7ca8612b7a7e090b39460bc4036345f5ef537d691fd585bc05a2ea580c7e354680afd0fd49f3d083d5e383b97755a842cf5e69870a970b14a10595c0c639ad2e7bda42c7146c4b69fd79e7403d89dac5816d0dc6f2bb987fccca4c4aee64444b7f46431433c59c6e7f2839fe2b7ad9316d31a52dcc0df07f1da14aa38e0cd88de380fda29b33704e8c3439376762739aa5b5cff9e925939773d24ca0e75bcf87149c9bcc2f8462afa6b50513ab003ac00c9ae3685ea52bdee3c814ffd5afc8357d93454b3ffaf0b5e9fd351730f0d55aed54a9cfa86f9119601ce9857ee0af3f579251bcc7ffe51b8393adc36ab6142eb0e0d07c9b2d5ab71d8d5639f32c61f7d59b45a95129cbc76d7e30c02a1329454f843553413d4e84bcab2c3ba1a0150292026dfa37488da5dd639c53edd51bf4eb5aa54d5b165fcd55d10f3496008f4e3b6d3eb200c19c5b9c42ad4f12977a857d02f787b14ced27fc5eefb05722b372a7d48c1891d30a32d84ec8d1f9a783a38bfac2793f0da6796cff90521e1d73f497f7d2c910b7fbbea2ba4b906d437a53bcbed16986f5646fd238e736f1c3e9d3a910218ce7f48dea3e9a1a848c580a1c506a80edb0c0a973a269667475ce88f4424674b14a3a8f2b71ef529d2ca96a3c5e4da384545749a55188d4de0074ad601695e934c9fe71d27c139b7678ead7f904cd2ae2a3aafa96d8211579e391507df96bf42c383f2ac71d7a558ebf1e3d5ab086b65422415bd24be9c979ca5b4f381d51b06ec4f6740b1a084999cd95fe63fec4a019f635640ba18d42312de7d1994947502b9010101010101010101010101010101015f0721f50826593dc3908dad39353846", + "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101094240ceae2d63ba1bdda997fa0bcbd8" + ], + "measurement": 100, + "nonce": "01010101010101010101010101010101", + "out_shares": [ + [ + 227608477929192160221239678567201956832 + ], + [ + 112673888991746302725626094800698809477 + ] + ], + "prep_messages": [ + "60af733578d766f2305c1d53c840b4b5" + ], + "prep_shares": [ + [ + "0a85b5e51cacf514ee9e9bbe5d3ac023795e910b765411e5cea8ff187973640694bd740cc15bc9cad60bc85785206062094240ceae2d63ba1bdda997fa0bcbd8", + "f57a4a1ae3530acf11616441a2c53fde804d262dc42e15e556ee02c588c3ca9d924eefa735a95f6e420f2c5161706e025f0721f50826593dc3908dad39353846" + ] + ], + "public_share": "" + } + ], + "verify_key": "01010101010101010101010101010101" +} |