summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src')
-rw-r--r--third_party/rust/prio/src/benchmarked.rs56
-rw-r--r--third_party/rust/prio/src/client.rs264
-rw-r--r--third_party/rust/prio/src/codec.rs658
-rw-r--r--third_party/rust/prio/src/encrypt.rs232
-rw-r--r--third_party/rust/prio/src/fft.rs226
-rw-r--r--third_party/rust/prio/src/field.rs960
-rw-r--r--third_party/rust/prio/src/flp.rs1035
-rw-r--r--third_party/rust/prio/src/flp/gadgets.rs715
-rw-r--r--third_party/rust/prio/src/flp/types.rs1199
-rw-r--r--third_party/rust/prio/src/fp.rs561
-rw-r--r--third_party/rust/prio/src/lib.rs33
-rw-r--r--third_party/rust/prio/src/polynomial.rs384
-rw-r--r--third_party/rust/prio/src/prng.rs208
-rw-r--r--third_party/rust/prio/src/server.rs469
-rw-r--r--third_party/rust/prio/src/test_vector.rs244
-rw-r--r--third_party/rust/prio/src/util.rs201
-rw-r--r--third_party/rust/prio/src/vdaf.rs562
-rw-r--r--third_party/rust/prio/src/vdaf/poplar1.rs933
-rw-r--r--third_party/rust/prio/src/vdaf/prg.rs239
-rw-r--r--third_party/rust/prio/src/vdaf/prio2.rs425
-rw-r--r--third_party/rust/prio/src/vdaf/prio3.rs1168
-rw-r--r--third_party/rust/prio/src/vdaf/prio3_test.rs162
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json7
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json37
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json53
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json38
26 files changed, 11069 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/benchmarked.rs b/third_party/rust/prio/src/benchmarked.rs
new file mode 100644
index 0000000000..8811250f9a
--- /dev/null
+++ b/third_party/rust/prio/src/benchmarked.rs
@@ -0,0 +1,56 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! This module provides wrappers around internal components of this crate that we want to
+//! benchmark, but which we don't want to expose in the public API.
+
+#[cfg(feature = "prio2")]
+use crate::client::Client;
+use crate::fft::discrete_fourier_transform;
+use crate::field::FieldElement;
+use crate::flp::gadgets::Mul;
+use crate::flp::FlpError;
+use crate::polynomial::{poly_fft, PolyAuxMemory};
+
+/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
+pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
+ discrete_fourier_transform(outp, inp, inp.len()).unwrap();
+}
+
+/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
+pub fn benchmarked_recursive_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
+ let mut mem = PolyAuxMemory::new(inp.len() / 2);
+ poly_fft(
+ outp,
+ inp,
+ &mem.roots_2n,
+ inp.len(),
+ false,
+ &mut mem.fft_memory,
+ )
+}
+
+/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
+/// uses FFT for multiplication.
+pub fn benchmarked_gadget_mul_call_poly_fft<F: FieldElement>(
+ g: &mut Mul<F>,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError> {
+ g.call_poly_fft(outp, inp)
+}
+
+/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
+/// does the multiplication directly.
+pub fn benchmarked_gadget_mul_call_poly_direct<F: FieldElement>(
+ g: &mut Mul<F>,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError> {
+ g.call_poly_direct(outp, inp)
+}
+
+/// Returns a Prio v2 proof that `data` is a valid boolean vector.
+#[cfg(feature = "prio2")]
+pub fn benchmarked_v2_prove<F: FieldElement>(data: &[F], client: &mut Client<F>) -> Vec<F> {
+ client.gen_proof(data)
+}
diff --git a/third_party/rust/prio/src/client.rs b/third_party/rust/prio/src/client.rs
new file mode 100644
index 0000000000..3ed5ee66c7
--- /dev/null
+++ b/third_party/rust/prio/src/client.rs
@@ -0,0 +1,264 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! The Prio v2 client. Only 0 / 1 vectors are supported for now.
+
+use crate::{
+ encrypt::{encrypt_share, EncryptError, PublicKey},
+ field::FieldElement,
+ polynomial::{poly_fft, PolyAuxMemory},
+ prng::{Prng, PrngError},
+ util::{proof_length, unpack_proof_mut},
+ vdaf::{prg::SeedStreamAes128, VdafError},
+};
+
+use std::convert::TryFrom;
+
+/// The main object that can be used to create Prio shares
+///
+/// Client is used to create Prio shares.
+#[derive(Debug)]
+pub struct Client<F: FieldElement> {
+ dimension: usize,
+ mem: ClientMemory<F>,
+ public_key1: PublicKey,
+ public_key2: PublicKey,
+}
+
+/// Errors that might be emitted by the client.
+#[derive(Debug, thiserror::Error)]
+pub enum ClientError {
+ /// Encryption/decryption error
+ #[error("encryption/decryption error")]
+ Encrypt(#[from] EncryptError),
+ /// PRNG error
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+ /// VDAF error
+ #[error("vdaf error: {0}")]
+ Vdaf(#[from] VdafError),
+}
+
+impl<F: FieldElement> Client<F> {
+ /// Construct a new Prio client
+ pub fn new(
+ dimension: usize,
+ public_key1: PublicKey,
+ public_key2: PublicKey,
+ ) -> Result<Self, ClientError> {
+ Ok(Client {
+ dimension,
+ mem: ClientMemory::new(dimension)?,
+ public_key1,
+ public_key2,
+ })
+ }
+
+ /// Construct a pair of encrypted shares based on the input data.
+ pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
+ let copy_data = |share_data: &mut [F]| {
+ share_data[..].clone_from_slice(data);
+ };
+ Ok(self.encode_with(copy_data)?)
+ }
+
+ /// Construct a pair of encrypted shares using a initilization function.
+ ///
+ /// This might be slightly more efficient on large vectors, because one can
+ /// avoid copying the input data.
+ pub fn encode_with<G>(&mut self, init_function: G) -> Result<(Vec<u8>, Vec<u8>), EncryptError>
+ where
+ G: FnOnce(&mut [F]),
+ {
+ let mut proof = self.mem.prove_with(self.dimension, init_function);
+
+ // use prng to share the proof: share2 is the PRNG seed, and proof is mutated
+ // in-place
+ let mut share2 = [0; 32];
+ getrandom::getrandom(&mut share2)?;
+ let share2_prng = Prng::from_prio2_seed(&share2);
+ for (s1, d) in proof.iter_mut().zip(share2_prng.into_iter()) {
+ *s1 -= d;
+ }
+ let share1 = F::slice_into_byte_vec(&proof);
+ // encrypt shares with respective keys
+ let encrypted_share1 = encrypt_share(&share1, &self.public_key1)?;
+ let encrypted_share2 = encrypt_share(&share2, &self.public_key2)?;
+ Ok((encrypted_share1, encrypted_share2))
+ }
+
+ /// Generate a proof of the input's validity. The output is the encoded input and proof.
+ pub(crate) fn gen_proof(&mut self, input: &[F]) -> Vec<F> {
+ let copy_data = |share_data: &mut [F]| {
+ share_data[..].clone_from_slice(input);
+ };
+ self.mem.prove_with(self.dimension, copy_data)
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct ClientMemory<F> {
+ prng: Prng<F, SeedStreamAes128>,
+ points_f: Vec<F>,
+ points_g: Vec<F>,
+ evals_f: Vec<F>,
+ evals_g: Vec<F>,
+ poly_mem: PolyAuxMemory<F>,
+}
+
+impl<F: FieldElement> ClientMemory<F> {
+ pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> {
+ let n = (dimension + 1).next_power_of_two();
+ if let Ok(size) = F::Integer::try_from(2 * n) {
+ if size > F::generator_order() {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+ } else {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+
+ Ok(Self {
+ prng: Prng::new()?,
+ points_f: vec![F::zero(); n],
+ points_g: vec![F::zero(); n],
+ evals_f: vec![F::zero(); 2 * n],
+ evals_g: vec![F::zero(); 2 * n],
+ poly_mem: PolyAuxMemory::new(n),
+ })
+ }
+}
+
+impl<F: FieldElement> ClientMemory<F> {
+ pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F>
+ where
+ G: FnOnce(&mut [F]),
+ {
+ let mut proof = vec![F::zero(); proof_length(dimension)];
+ // unpack one long vector to different subparts
+ let unpacked = unpack_proof_mut(&mut proof, dimension).unwrap();
+ // initialize the data part
+ init_function(unpacked.data);
+ // fill in the rest
+ construct_proof(
+ unpacked.data,
+ dimension,
+ unpacked.f0,
+ unpacked.g0,
+ unpacked.h0,
+ unpacked.points_h_packed,
+ self,
+ );
+
+ proof
+ }
+}
+
+/// Convenience function if one does not want to reuse
+/// [`Client`](struct.Client.html).
+pub fn encode_simple<F: FieldElement>(
+ data: &[F],
+ public_key1: PublicKey,
+ public_key2: PublicKey,
+) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
+ let dimension = data.len();
+ let mut client_memory = Client::new(dimension, public_key1, public_key2)?;
+ client_memory.encode_simple(data)
+}
+
+fn interpolate_and_evaluate_at_2n<F: FieldElement>(
+ n: usize,
+ points_in: &[F],
+ evals_out: &mut [F],
+ mem: &mut PolyAuxMemory<F>,
+) {
+ // interpolate through roots of unity
+ poly_fft(
+ &mut mem.coeffs,
+ points_in,
+ &mem.roots_n_inverted,
+ n,
+ true,
+ &mut mem.fft_memory,
+ );
+ // evaluate at 2N roots of unity
+ poly_fft(
+ evals_out,
+ &mem.coeffs,
+ &mem.roots_2n,
+ 2 * n,
+ false,
+ &mut mem.fft_memory,
+ );
+}
+
+/// Proof construction
+///
+/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation
+/// This constructs the output \pi by doing the necessesary calculations
+fn construct_proof<F: FieldElement>(
+ data: &[F],
+ dimension: usize,
+ f0: &mut F,
+ g0: &mut F,
+ h0: &mut F,
+ points_h_packed: &mut [F],
+ mem: &mut ClientMemory<F>,
+) {
+ let n = (dimension + 1).next_power_of_two();
+
+ // set zero terms to random
+ *f0 = mem.prng.get();
+ *g0 = mem.prng.get();
+ mem.points_f[0] = *f0;
+ mem.points_g[0] = *g0;
+
+ // set zero term for the proof polynomial
+ *h0 = *f0 * *g0;
+
+ // set f_i = data_(i - 1)
+ // set g_i = f_i - 1
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..dimension {
+ mem.points_f[i + 1] = data[i];
+ mem.points_g[i + 1] = data[i] - F::one();
+ }
+
+ // interpolate and evaluate at roots of unity
+ interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem);
+ interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem);
+
+ // calculate the proof polynomial as evals_f(r) * evals_g(r)
+ // only add non-zero points
+ let mut j: usize = 0;
+ let mut i: usize = 1;
+ while i < 2 * n {
+ points_h_packed[j] = mem.evals_f[i] * mem.evals_g[i];
+ j += 1;
+ i += 2;
+ }
+}
+
+#[test]
+fn test_encode() {
+ use crate::field::Field32;
+ let pub_key1 = PublicKey::from_base64(
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
+ )
+ .unwrap();
+ let pub_key2 = PublicKey::from_base64(
+ "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=",
+ )
+ .unwrap();
+
+ let data_u32 = [0u32, 1, 0, 1, 1, 0, 0, 0, 1];
+ let data = data_u32
+ .iter()
+ .map(|x| Field32::from(*x))
+ .collect::<Vec<Field32>>();
+ let encoded_shares = encode_simple(&data, pub_key1, pub_key2);
+ assert!(encoded_shares.is_ok());
+}
diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs
new file mode 100644
index 0000000000..c3ee8e9db7
--- /dev/null
+++ b/third_party/rust/prio/src/codec.rs
@@ -0,0 +1,658 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Module `codec` provides support for encoding and decoding messages to or from the TLS wire
+//! encoding, as specified in [RFC 8446, Section 3][1]. It provides traits that can be implemented
+//! on values that need to be encoded or decoded, as well as utility functions for encoding
+//! sequences of values.
+//!
+//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3
+
+use byteorder::{BigEndian, ReadBytesExt};
+use std::{
+ convert::TryInto,
+ error::Error,
+ io::{Cursor, Read},
+ mem::size_of,
+ num::TryFromIntError,
+};
+
+#[allow(missing_docs)]
+#[derive(Debug, thiserror::Error)]
+pub enum CodecError {
+ #[error("I/O error")]
+ Io(#[from] std::io::Error),
+ #[error("{0} bytes left in buffer after decoding value")]
+ BytesLeftOver(usize),
+ #[error("length prefix of encoded vector overflows buffer: {0}")]
+ LengthPrefixTooBig(usize),
+ #[error("other error: {0}")]
+ Other(#[source] Box<dyn Error + 'static + Send + Sync>),
+ #[error("unexpected value")]
+ UnexpectedValue,
+}
+
+/// Describes how to decode an object from a byte sequence.
+pub trait Decode: Sized {
+ /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned
+ /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned
+ /// and no further attempt to read from `bytes` should be made.
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;
+
+ /// Convenience method to get decoded value. Returns an error if [`Self::decode`] fails, or if
+ /// there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
+ Self::get_decoded_with_param(&(), bytes)
+ }
+}
+
+/// Describes how to decode an object from a byte sequence, with a decoding parameter provided to
+/// provide additional data used in decoding.
+pub trait ParameterizedDecode<P>: Sized {
+ /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the
+ /// wire encoding such as lengths of different portions of the message. On success, the decoded
+ /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an
+ /// error is returned and no further attempt to read from `bytes` should be made.
+ fn decode_with_param(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError>;
+
+ /// Convenience method to get decoded value. Returns an error if [`Self::decode_with_param`]
+ /// fails, or if there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
+ let mut cursor = Cursor::new(bytes);
+ let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
+ if cursor.position() as usize != bytes.len() {
+ return Err(CodecError::BytesLeftOver(
+ bytes.len() - cursor.position() as usize,
+ ));
+ }
+
+ Ok(decoded)
+ }
+}
+
+// Provide a blanket implementation so that any Decode can be used as a ParameterizedDecode<T> for
+// any T.
+impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D {
+ fn decode_with_param(
+ _decoding_parameter: &T,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Self::decode(bytes)
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait Encode {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>);
+
+ /// Convenience method to get encoded value.
+ fn get_encoded(&self) -> Vec<u8> {
+ self.get_encoded_with_param(&())
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait ParameterizedEncode<P> {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ /// `encoding_parameter` provides details of the wire encoding, used to control how the value
+ /// is encoded.
+ fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>);
+
+ /// Convenience method to get encoded value.
+ fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> {
+ let mut ret = Vec::new();
+ self.encode_with_param(encoding_parameter, &mut ret);
+ ret
+ }
+}
+
+// Provide a blanket implementation so that any Encode can be used as a ParameterizedEncode<T> for
+// any T.
+impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
+ fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) {
+ self.encode(bytes)
+ }
+}
+
+impl Decode for () {
+ fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(())
+ }
+}
+
+impl Encode for () {
+ fn encode(&self, _bytes: &mut Vec<u8>) {}
+}
+
+impl Decode for u8 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; size_of::<u8>()];
+ bytes.read_exact(&mut value)?;
+ Ok(value[0])
+ }
+}
+
+impl Encode for u8 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.push(*self);
+ }
+}
+
+impl Decode for u16 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u16::<BigEndian>()?)
+ }
+}
+
+impl Encode for u16 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u16::to_be_bytes(*self));
+ }
+}
+
+/// 24 bit integer, per
+/// [RFC 8443, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3)
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+struct U24(pub u32);
+
+impl Decode for U24 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(U24(bytes.read_u24::<BigEndian>()?))
+ }
+}
+
+impl Encode for U24 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Encode lower three bytes of the u32 as u24
+ bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
+ }
+}
+
+impl Decode for u32 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u32::<BigEndian>()?)
+ }
+}
+
+impl Encode for u32 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u32::to_be_bytes(*self));
+ }
+}
+
+impl Decode for u64 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u64::<BigEndian>()?)
+ }
+}
+
+impl Encode for u64 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u64::to_be_bytes(*self));
+ }
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ bytes.push(0);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 1;
+ assert!(len <= u8::MAX.into());
+ bytes[len_offset] = len as u8;
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read one byte to get length of opaque byte vector
+ let length = usize::from(u8::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u16.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 2;
+ assert!(len <= u16::MAX.into());
+ for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read two bytes to get length of opaque byte vector
+ let length = usize::from(u16::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ U24(0).encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 3;
+ assert!(len <= 0xffffff);
+ for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read three bytes to get length of opaque byte vector
+ let length = U24::decode(bytes)?.0 as usize;
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u32.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 4;
+ let len: u32 = len.try_into().expect("Length too large");
+ for (offset, byte) in len.to_be_bytes().iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u32_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read four bytes to get length of opaque byte vector.
+ let len: usize = u32::decode(bytes)?
+ .try_into()
+ .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
+
+ decode_items(len, decoding_parameter, bytes)
+}
+
+/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible.
+fn decode_items<P, D: ParameterizedDecode<P>>(
+ length: usize,
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ let mut decoded = Vec::new();
+ let initial_position = bytes.position() as usize;
+
+ // Create cursor over specified portion of provided cursor to ensure we can't read past length.
+ let inner = bytes.get_ref();
+
+ // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
+ let (items_end, overflowed) = initial_position.overflowing_add(length);
+ if overflowed || items_end > inner.len() {
+ return Err(CodecError::LengthPrefixTooBig(length));
+ }
+
+ let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
+
+ while sub.position() < length as u64 {
+ decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
+ }
+
+ // Advance outer cursor by the amount read in the inner cursor
+ bytes.set_position(initial_position as u64 + sub.position());
+
+ Ok(decoded)
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn encode_nothing() {
+ let mut bytes = vec![];
+ ().encode(&mut bytes);
+ assert_eq!(bytes.len(), 0);
+ }
+
+ #[test]
+ fn roundtrip_u8() {
+ let value = 100u8;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 1);
+
+ let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u16() {
+ let value = 1000u16;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 2);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![3, 232]);
+
+ let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u24() {
+ let value = U24(1_000_000u32);
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 3);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![15, 66, 64]);
+
+ let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u32() {
+ let value = 134_217_728u32;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 4);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![8, 0, 0, 0]);
+
+ let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u64() {
+ let value = 137_438_953_472u64;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 8);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);
+
+ let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct TestMessage {
+ field_u8: u8,
+ field_u16: u16,
+ field_u24: U24,
+ field_u32: u32,
+ field_u64: u64,
+ }
+
+ impl Encode for TestMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.field_u8.encode(bytes);
+ self.field_u16.encode(bytes);
+ self.field_u24.encode(bytes);
+ self.field_u32.encode(bytes);
+ self.field_u64.encode(bytes);
+ }
+ }
+
+ impl Decode for TestMessage {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let field_u8 = u8::decode(bytes)?;
+ let field_u16 = u16::decode(bytes)?;
+ let field_u24 = U24::decode(bytes)?;
+ let field_u32 = u32::decode(bytes)?;
+ let field_u64 = u64::decode(bytes)?;
+
+ Ok(TestMessage {
+ field_u8,
+ field_u16,
+ field_u24,
+ field_u32,
+ field_u64,
+ })
+ }
+ }
+
+ impl TestMessage {
+ fn encoded_length() -> usize {
+ // u8 field
+ 1 +
+ // u16 field
+ 2 +
+ // u24 field
+ 3 +
+ // u32 field
+ 4 +
+ // u64 field
+ 8
+ }
+ }
+
+ #[test]
+ fn roundtrip_message() {
+ let value = TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ };
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), TestMessage::encoded_length());
+
+ let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ fn messages_vec() -> Vec<TestMessage> {
+ vec![
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ ]
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u8() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u8_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 1 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u16() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u16_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 2 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u24() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u24_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 3 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u32() {
+ let values = messages_vec();
+ let mut bytes = Vec::new();
+ encode_u32_items(&mut bytes, &(), &values);
+
+ assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length());
+
+ // Check endianness of encoded length.
+ assert_eq!(
+ bytes[0..4],
+ [0, 0, 0, 3 * TestMessage::encoded_length() as u8]
+ );
+
+ let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn decode_items_overflow() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(usize::MAX)
+ );
+ }
+
+ #[test]
+ fn decode_items_too_big() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(2)
+ );
+ }
+}
diff --git a/third_party/rust/prio/src/encrypt.rs b/third_party/rust/prio/src/encrypt.rs
new file mode 100644
index 0000000000..9429c0f2ce
--- /dev/null
+++ b/third_party/rust/prio/src/encrypt.rs
@@ -0,0 +1,232 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Utilities for ECIES encryption / decryption used by the Prio client and server.
+
+use crate::prng::PrngError;
+
+use aes_gcm::aead::generic_array::typenum::U16;
+use aes_gcm::aead::generic_array::GenericArray;
+use aes_gcm::{AeadInPlace, NewAead};
+use ring::agreement;
+type Aes128 = aes_gcm::AesGcm<aes_gcm::aes::Aes128, U16>;
+
+/// Length of the EC public key (X9.62 format)
+pub const PUBLICKEY_LENGTH: usize = 65;
+/// Length of the AES-GCM tag
+pub const TAG_LENGTH: usize = 16;
+/// Length of the symmetric AES-GCM key
+const KEY_LENGTH: usize = 16;
+
+/// Possible errors from encryption / decryption.
+#[derive(Debug, thiserror::Error)]
+pub enum EncryptError {
+ /// Base64 decoding error
+ #[error("base64 decoding error")]
+ DecodeBase64(#[from] base64::DecodeError),
+ /// Error in ECDH
+ #[error("error in ECDH")]
+ KeyAgreement,
+ /// Buffer for ciphertext was not large enough
+ #[error("buffer for ciphertext was not large enough")]
+ Encryption,
+ /// Authentication tags did not match.
+ #[error("authentication tags did not match")]
+ Decryption,
+ /// Input ciphertext was too small
+ #[error("input ciphertext was too small")]
+ DecryptionLength,
+ /// PRNG error
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+ /// failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+}
+
+/// NIST P-256, public key in X9.62 uncompressed format
+#[derive(Debug, Clone)]
+pub struct PublicKey(Vec<u8>);
+
+/// NIST P-256, private key
+///
+/// X9.62 uncompressed public key concatenated with the secret scalar.
+#[derive(Debug, Clone)]
+pub struct PrivateKey(Vec<u8>);
+
+impl PublicKey {
+ /// Load public key from a base64 encoded X9.62 uncompressed representation.
+ pub fn from_base64(key: &str) -> Result<Self, EncryptError> {
+ let keydata = base64::decode(key)?;
+ Ok(PublicKey(keydata))
+ }
+}
+
+/// Copy public key from a private key
+impl std::convert::From<&PrivateKey> for PublicKey {
+ fn from(pk: &PrivateKey) -> Self {
+ PublicKey(pk.0[..PUBLICKEY_LENGTH].to_owned())
+ }
+}
+
+impl PrivateKey {
+ /// Load private key from a base64 encoded string.
+ pub fn from_base64(key: &str) -> Result<Self, EncryptError> {
+ let keydata = base64::decode(key)?;
+ Ok(PrivateKey(keydata))
+ }
+}
+
+/// Encrypt a bytestring using the public key
+///
+/// This uses ECIES with X9.63 key derivation function and AES-GCM for the
+/// symmetic encryption and MAC.
+pub fn encrypt_share(share: &[u8], key: &PublicKey) -> Result<Vec<u8>, EncryptError> {
+ let rng = ring::rand::SystemRandom::new();
+ let ephemeral_priv = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &rng)
+ .map_err(|_| EncryptError::KeyAgreement)?;
+ let peer_public = agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, &key.0);
+ let ephemeral_pub = ephemeral_priv
+ .compute_public_key()
+ .map_err(|_| EncryptError::KeyAgreement)?;
+
+ let symmetric_key_bytes = agreement::agree_ephemeral(
+ ephemeral_priv,
+ &peer_public,
+ EncryptError::KeyAgreement,
+ |material| Ok(x963_kdf(material, ephemeral_pub.as_ref())),
+ )?;
+
+ let in_out = share.to_owned();
+ let encrypted = encrypt_aes_gcm(
+ &symmetric_key_bytes[..16],
+ &symmetric_key_bytes[16..],
+ in_out,
+ )?;
+
+ let mut output = Vec::with_capacity(encrypted.len() + ephemeral_pub.as_ref().len());
+ output.extend_from_slice(ephemeral_pub.as_ref());
+ output.extend_from_slice(&encrypted);
+
+ Ok(output)
+}
+
+/// Decrypt a bytestring using the private key
+///
+/// This uses ECIES with X9.63 key derivation function and AES-GCM for the
+/// symmetic encryption and MAC.
+pub fn decrypt_share(share: &[u8], key: &PrivateKey) -> Result<Vec<u8>, EncryptError> {
+ if share.len() < PUBLICKEY_LENGTH + TAG_LENGTH {
+ return Err(EncryptError::DecryptionLength);
+ }
+ let empheral_pub_bytes: &[u8] = &share[0..PUBLICKEY_LENGTH];
+
+ let ephemeral_pub =
+ agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, empheral_pub_bytes);
+
+ let fake_rng = ring::test::rand::FixedSliceRandom {
+ // private key consists of the public key + private scalar
+ bytes: &key.0[PUBLICKEY_LENGTH..],
+ };
+
+ let private_key = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &fake_rng)
+ .map_err(|_| EncryptError::KeyAgreement)?;
+
+ let symmetric_key_bytes = agreement::agree_ephemeral(
+ private_key,
+ &ephemeral_pub,
+ EncryptError::KeyAgreement,
+ |material| Ok(x963_kdf(material, empheral_pub_bytes)),
+ )?;
+
+ // in_out is the AES-GCM ciphertext+tag, wihtout the ephemeral EC pubkey
+ let in_out = share[PUBLICKEY_LENGTH..].to_owned();
+ decrypt_aes_gcm(
+ &symmetric_key_bytes[..KEY_LENGTH],
+ &symmetric_key_bytes[KEY_LENGTH..],
+ in_out,
+ )
+}
+
+fn x963_kdf(z: &[u8], shared_info: &[u8]) -> [u8; 32] {
+ let mut hasher = ring::digest::Context::new(&ring::digest::SHA256);
+ hasher.update(z);
+ hasher.update(&1u32.to_be_bytes());
+ hasher.update(shared_info);
+ let digest = hasher.finish();
+ use std::convert::TryInto;
+ // unwrap never fails because SHA256 output len is 32
+ digest.as_ref().try_into().unwrap()
+}
+
+fn decrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> {
+ let cipher = Aes128::new(GenericArray::from_slice(key));
+ cipher
+ .decrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
+ .map_err(|_| EncryptError::Decryption)?;
+ Ok(data)
+}
+
+fn encrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> {
+ let cipher = Aes128::new(GenericArray::from_slice(key));
+ cipher
+ .encrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
+ .map_err(|_| EncryptError::Encryption)?;
+ Ok(data)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_encrypt_decrypt() -> Result<(), EncryptError> {
+ let pub_key = PublicKey::from_base64(
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9\
+ HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
+ )?;
+ let priv_key = PrivateKey::from_base64(
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgN\
+ t9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
+ )?;
+ let data = (0..100).map(|_| rand::random::<u8>()).collect::<Vec<u8>>();
+ let encrypted = encrypt_share(&data, &pub_key)?;
+
+ let decrypted = decrypt_share(&encrypted, &priv_key)?;
+ assert_eq!(decrypted, data);
+ Ok(())
+ }
+
+ #[test]
+ fn test_interop() {
+ let share1 = base64::decode("Kbnd2ZWrsfLfcpuxHffMrJ1b7sCrAsNqlb6Y1eAMfwCVUNXt").unwrap();
+ let share2 = base64::decode("hu+vT3+8/taHP7B/dWXh/g==").unwrap();
+ let encrypted_share1 = base64::decode(
+ "BEWObg41JiMJglSEA6Ebk37xOeflD2a1t2eiLmX0OPccJhAER5NmOI+4r4Cfm7aJn141sGKnTbCuIB9+AeVuw\
+ MAQnzjsGPu5aNgkdpp+6VowAcVAV1DlzZvtwlQkCFlX4f3xmafTPFTPOokYi2a+H1n8GKwd",
+ )
+ .unwrap();
+ let encrypted_share2 = base64::decode(
+ "BNRzQ6TbqSc4pk0S8aziVRNjWm4DXQR5yCYTK2w22iSw4XAPW4OB9RxBpWVa1C/3ywVBT/3yLArOMXEsCEMOG\
+ 1+d2CiEvtuU1zADH2MVaCnXL/dVXkDchYZsvPWPkDcjQA==",
+ )
+ .unwrap();
+
+ let priv_key1 = PrivateKey::from_base64(
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOg\
+ Nt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
+ )
+ .unwrap();
+ let priv_key2 = PrivateKey::from_base64(
+ "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhF\
+ LMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==",
+ )
+ .unwrap();
+
+ let decrypted1 = decrypt_share(&encrypted_share1, &priv_key1).unwrap();
+ let decrypted2 = decrypt_share(&encrypted_share2, &priv_key2).unwrap();
+
+ assert_eq!(decrypted1, share1);
+ assert_eq!(decrypted2, share2);
+ }
+}
diff --git a/third_party/rust/prio/src/fft.rs b/third_party/rust/prio/src/fft.rs
new file mode 100644
index 0000000000..c7a1dfbb8b
--- /dev/null
+++ b/third_party/rust/prio/src/fft.rs
@@ -0,0 +1,226 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
+//! Transform (DFT) over a slice of field elements.
+
+use crate::field::FieldElement;
+use crate::fp::{log2, MAX_ROOTS};
+
+use std::convert::TryFrom;
+
+/// An error returned by an FFT operation.
+#[derive(Debug, PartialEq, Eq, thiserror::Error)]
+pub enum FftError {
+ /// The output is too small.
+ #[error("output slice is smaller than specified size")]
+ OutputTooSmall,
+ /// The specified size is too large.
+ #[error("size is larger than than maximum permitted")]
+ SizeTooLarge,
+ /// The specified size is not a power of 2.
+ #[error("size is not a power of 2")]
+ SizeInvalid,
+}
+
+/// Sets `outp` to the DFT of `inp`.
+///
+/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
+/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of
+/// unity.
+#[allow(clippy::many_single_char_names)]
+pub fn discrete_fourier_transform<F: FieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?;
+
+ if size > outp.len() {
+ return Err(FftError::OutputTooSmall);
+ }
+
+ if size > 1 << MAX_ROOTS {
+ return Err(FftError::SizeTooLarge);
+ }
+
+ if size != 1 << d {
+ return Err(FftError::SizeInvalid);
+ }
+
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..size {
+ let j = bitrev(d, i);
+ outp[i] = if j < inp.len() { inp[j] } else { F::zero() }
+ }
+
+ let mut w: F;
+ for l in 1..d + 1 {
+ w = F::one();
+ let r = F::root(l).unwrap();
+ let y = 1 << (l - 1);
+ for i in 0..y {
+ for j in 0..(size / y) >> 1 {
+ let x = (1 << l) * j + i;
+ let u = outp[x];
+ let v = w * outp[x + y];
+ outp[x] = u + v;
+ outp[x + y] = u - v;
+ }
+ w *= r;
+ }
+ }
+
+ Ok(())
+}
+
+/// Sets `outp` to the inverse of the DFT of `inp`.
+#[cfg(test)]
+pub(crate) fn discrete_fourier_transform_inv<F: FieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv();
+ discrete_fourier_transform(outp, inp, size)?;
+ discrete_fourier_transform_inv_finish(outp, size, size_inv);
+ Ok(())
+}
+
+/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to
+/// amortize the cost the modular inverse across multiple inverse DFT operations.
+pub(crate) fn discrete_fourier_transform_inv_finish<F: FieldElement>(
+ outp: &mut [F],
+ size: usize,
+ size_inv: F,
+) {
+ let mut tmp: F;
+ outp[0] *= size_inv;
+ outp[size >> 1] *= size_inv;
+ for i in 1..size >> 1 {
+ tmp = outp[i] * size_inv;
+ outp[i] = outp[size - i] * size_inv;
+ outp[size - i] = tmp;
+ }
+}
+
+// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
+fn bitrev(d: usize, x: usize) -> usize {
+ let mut y = 0;
+ for i in 0..d {
+ y += ((x >> i) & 1) << (d - i);
+ }
+ y >> 1
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{
+ random_vector, split_vector, Field128, Field32, Field64, Field96, FieldPrio2,
+ };
+ use crate::polynomial::{poly_fft, PolyAuxMemory};
+
+ fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
+ let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];
+
+ for size in test_sizes.iter() {
+ let mut tmp = vec![F::zero(); *size];
+ let mut got = vec![F::zero(); *size];
+ let want = random_vector(*size).unwrap();
+
+ discrete_fourier_transform(&mut tmp, &want, want.len())?;
+ discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?;
+ assert_eq!(got, want);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_field32() {
+ discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_priov2_field32() {
+ discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field64() {
+ discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field96() {
+ discrete_fourier_transform_then_inv_test::<Field96>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field128() {
+ discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_recursive_fft() {
+ let size = 128;
+ let mut mem = PolyAuxMemory::new(size / 2);
+
+ let inp = random_vector(size).unwrap();
+ let mut want = vec![Field32::zero(); size];
+ let mut got = vec![Field32::zero(); size];
+
+ discrete_fourier_transform::<Field32>(&mut want, &inp, inp.len()).unwrap();
+
+ poly_fft(
+ &mut got,
+ &inp,
+ &mem.roots_2n,
+ size,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(got, want);
+ }
+
+ // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial
+ // over secret shares and summing up the coefficients is equivalent to interpolating a
+ // polynomial over the plaintext data.
+ #[test]
+ fn test_fft_linearity() {
+ let len = 16;
+ let num_shares = 3;
+ let x: Vec<Field64> = random_vector(len).unwrap();
+ let mut x_shares = split_vector(&x, num_shares).unwrap();
+
+ // Just for fun, let's do something different with a subset of the inputs. For the first
+ // share, every odd element is set to the plaintext value. For all shares but the first,
+ // every odd element is set to 0.
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..len {
+ if i % 2 != 0 {
+ x_shares[0][i] = x[i];
+ }
+ for j in 1..num_shares {
+ if i % 2 != 0 {
+ x_shares[j][i] = Field64::zero();
+ }
+ }
+ }
+
+ let mut got = vec![Field64::zero(); len];
+ let mut buf = vec![Field64::zero(); len];
+ for share in x_shares {
+ discrete_fourier_transform_inv(&mut buf, &share, len).unwrap();
+ for i in 0..len {
+ got[i] += buf[i];
+ }
+ }
+
+ let mut want = vec![Field64::zero(); len];
+ discrete_fourier_transform_inv(&mut want, &x, len).unwrap();
+
+ assert_eq!(got, want);
+ }
+}
diff --git a/third_party/rust/prio/src/field.rs b/third_party/rust/prio/src/field.rs
new file mode 100644
index 0000000000..cbfe40826e
--- /dev/null
+++ b/third_party/rust/prio/src/field.rs
@@ -0,0 +1,960 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Finite field arithmetic.
+//!
+//! Each field has an associated parameter called the "generator" that generates a multiplicative
+//! subgroup of order `2^n` for some `n`.
+
+#[cfg(feature = "crypto-dependencies")]
+use crate::prng::{Prng, PrngError};
+use crate::{
+ codec::{CodecError, Decode, Encode},
+ fp::{FP128, FP32, FP64, FP96},
+};
+use serde::{
+ de::{DeserializeOwned, Visitor},
+ Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::{
+ cmp::min,
+ convert::{TryFrom, TryInto},
+ fmt::{self, Debug, Display, Formatter},
+ hash::{Hash, Hasher},
+ io::{Cursor, Read},
+ marker::PhantomData,
+ ops::{Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub, SubAssign},
+};
+
+/// Possible errors from finite field operations.
+#[derive(Debug, thiserror::Error)]
+pub enum FieldError {
+ /// Input sizes do not match.
+ #[error("input sizes do not match")]
+ InputSizeMismatch,
+ /// Returned when decoding a `FieldElement` from a short byte string.
+ #[error("short read from bytes")]
+ ShortRead,
+ /// Returned when decoding a `FieldElement` from a byte string encoding an integer larger than
+ /// or equal to the field modulus.
+ #[error("read from byte slice exceeds modulus")]
+ ModulusOverflow,
+ /// Error while performing I/O.
+ #[error("I/O error")]
+ Io(#[from] std::io::Error),
+ /// Error encoding or decoding a field.
+ #[error("Codec error")]
+ Codec(#[from] CodecError),
+ /// Error converting to `FieldElement::Integer`.
+ #[error("Integer TryFrom error")]
+ IntegerTryFrom,
+ /// Error converting `FieldElement::Integer` into something else.
+ #[error("Integer TryInto error")]
+ IntegerTryInto,
+}
+
+/// Byte order for encoding FieldElement values into byte sequences.
+#[derive(Clone, Copy, Debug)]
+enum ByteOrder {
+ /// Big endian byte order.
+ BigEndian,
+ /// Little endian byte order.
+ LittleEndian,
+}
+
+/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
+pub trait FieldElement:
+ Sized
+ + Debug
+ + Copy
+ + PartialEq
+ + Eq
+ + Add<Output = Self>
+ + AddAssign
+ + Sub<Output = Self>
+ + SubAssign
+ + Mul<Output = Self>
+ + MulAssign
+ + Div<Output = Self>
+ + DivAssign
+ + Neg<Output = Self>
+ + Display
+ + From<<Self as FieldElement>::Integer>
+ + for<'a> TryFrom<&'a [u8], Error = FieldError>
+ // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`,
+ // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that
+ // isn't possible yet[1]. However we can provide the impl on FieldElement implementations.
+ // [1]: https://github.com/rust-lang/rust/issues/60551
+ + Into<Vec<u8>>
+ + Serialize
+ + DeserializeOwned
+ + Encode
+ + Decode
+ + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type.
+{
+ /// Size in bytes of the encoding of a value.
+ const ENCODED_SIZE: usize;
+
+ /// The error returned if converting `usize` to an `Integer` fails.
+ type IntegerTryFromError: std::error::Error;
+
+ /// The error returend if converting an `Integer` to a `u64` fails.
+ type TryIntoU64Error: std::error::Error;
+
+ /// The integer representation of the field element.
+ type Integer: Copy
+ + Debug
+ + Eq
+ + Ord
+ + BitAnd<Output = <Self as FieldElement>::Integer>
+ + Div<Output = <Self as FieldElement>::Integer>
+ + Shl<Output = <Self as FieldElement>::Integer>
+ + Shr<Output = <Self as FieldElement>::Integer>
+ + Add<Output = <Self as FieldElement>::Integer>
+ + Sub<Output = <Self as FieldElement>::Integer>
+ + From<Self>
+ + TryFrom<usize, Error = Self::IntegerTryFromError>
+ + TryInto<u64, Error = Self::TryIntoU64Error>;
+
+ /// Modular exponentation, i.e., `self^exp (mod p)`.
+ fn pow(&self, exp: Self::Integer) -> Self;
+
+ /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
+ fn inv(&self) -> Self;
+
+ /// Returns the prime modulus `p`.
+ fn modulus() -> Self::Integer;
+
+ /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the
+ /// field. The `m` most significant bits are cleared, where `m` is equal to the length of
+ /// [`Self::Integer`] in bits minus the length of the modulus in bits.
+ ///
+ /// # Errors
+ ///
+ /// An error is returned if the provided slice is too small to encode a field element or if the
+ /// result encodes an integer larger than or equal to the field modulus.
+ ///
+ /// # Warnings
+ ///
+ /// This function should only be used within [`prng::Prng`] to convert a random byte string into
+ /// a field element. Use [`Self::decode`] to deserialize field elements. Use
+ /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements.
+ #[doc(hidden)]
+ fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>;
+
+ /// Returns the size of the multiplicative subgroup generated by `generator()`.
+ fn generator_order() -> Self::Integer;
+
+ /// Returns the generator of the multiplicative subgroup of size `generator_order()`.
+ fn generator() -> Self;
+
+ /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
+ /// prinicpal root of unity is 1 by definition.
+ fn root(l: usize) -> Option<Self>;
+
+ /// Returns the additive identity.
+ fn zero() -> Self;
+
+ /// Returns the multiplicative identity.
+ fn one() -> Self;
+
+ /// Convert a slice of field elements into a vector of bytes.
+ ///
+ /// # Notes
+ ///
+ /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding
+ /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
+ /// impossible.
+ fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> {
+ let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE);
+ for elem in values {
+ vec.append(&mut (*elem).into());
+ }
+ vec
+ }
+
+ /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a
+ /// sequence of [`Self::ENCODED_SIZE`]-byte sequences.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if the length of the provided byte slice is not a multiple of the size of a
+ /// field element, or if any of the values in the byte slice are invalid encodings of a field
+ /// element, because the encoded integer is larger than or equal to the field modulus.
+ ///
+ /// # Notes
+ ///
+ /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding
+ /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
+ /// impossible.
+ fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> {
+ if bytes.len() % Self::ENCODED_SIZE != 0 {
+ return Err(FieldError::ShortRead);
+ }
+ let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE);
+ for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) {
+ vec.push(Self::get_decoded(chunk)?);
+ }
+ Ok(vec)
+ }
+}
+
+/// Methods common to all `FieldElement` implementations that are private to the crate.
+pub(crate) trait FieldElementExt: FieldElement {
+ /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough
+ /// to be represented with that many bits.
+ ///
+ /// # Arguments
+ ///
+ /// * `input` - The field element to encode
+ /// * `bits` - The number of bits to use for the encoding
+ fn encode_into_bitvector_representation(
+ input: &Self::Integer,
+ bits: usize,
+ ) -> Result<Vec<Self>, FieldError> {
+ // Create a mutable copy of `input`. In each iteration of the following loop we take the
+ // least significant bit, and shift input to the right by one bit.
+ let mut i = *input;
+
+ let one = Self::Integer::from(Self::one());
+ let mut encoded = Vec::with_capacity(bits);
+ for _ in 0..bits {
+ let w = Self::from(i & one);
+ encoded.push(w);
+ i = i >> one;
+ }
+
+ // If `i` is still not zero, this means that it cannot be encoded by `bits` bits.
+ if i != Self::Integer::from(Self::zero()) {
+ return Err(FieldError::InputSizeMismatch);
+ }
+
+ Ok(encoded)
+ }
+
+ /// Decode the bitvector-represented value `input` into a simple representation as a single
+ /// field element.
+ ///
+ /// # Errors
+ ///
+ /// This function errors if `2^input.len() - 1` does not fit into the field `Self`.
+ fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> {
+ if !Self::valid_integer_bitlength(input.len()) {
+ return Err(FieldError::ModulusOverflow);
+ }
+
+ let mut decoded = Self::zero();
+ for (l, bit) in input.iter().enumerate() {
+ let w = Self::Integer::try_from(1 << l).map_err(|_| FieldError::IntegerTryFrom)?;
+ decoded += Self::from(w) * *bit;
+ }
+ Ok(decoded)
+ }
+
+ /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the
+ /// field modulus.
+ fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError>
+ where
+ Self::Integer: TryFrom<N>,
+ {
+ let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?;
+ if Self::modulus() <= i_int {
+ return Err(FieldError::ModulusOverflow);
+ }
+ Ok(i_int)
+ }
+
+ /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is
+ /// representable in this field.
+ fn valid_integer_bitlength(bits: usize) -> bool {
+ if let Ok(bits_int) = Self::Integer::try_from(bits) {
+ if Self::modulus() >> bits_int != Self::Integer::from(Self::zero()) {
+ return true;
+ }
+ }
+ false
+ }
+}
+
+impl<F: FieldElement> FieldElementExt for F {}
+
+/// serde Visitor implementation used to generically deserialize `FieldElement`
+/// values from byte arrays.
+struct FieldElementVisitor<F: FieldElement> {
+ phantom: PhantomData<F>,
+}
+
+impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> {
+ type Value = F;
+
+ fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
+ formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE))
+ }
+
+ fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ Self::Value::try_from(v).map_err(E::custom)
+ }
+
+ fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::SeqAccess<'de>,
+ {
+ let mut bytes = vec![];
+ while let Some(byte) = seq.next_element()? {
+ bytes.push(byte);
+ }
+
+ self.visit_bytes(&bytes)
+ }
+}
+
+macro_rules! make_field {
+ (
+ $(#[$meta:meta])*
+ $elem:ident, $int:ident, $fp:ident, $encoding_size:literal, $encoding_order:expr,
+ ) => {
+ $(#[$meta])*
+ ///
+ /// This structure represents a field element in a prime order field. The concrete
+ /// representation of the element is via the Montgomery domain. For an element n in GF(p),
+ /// we store n * R^-1 mod p (where R is a given power of two). This representation enables
+ /// using a more efficient (and branchless) multiplication algorithm, at the expense of
+ /// having to convert elements between their Montgomery domain representation and natural
+ /// representation. For calculations with many multiplications or exponentiations, this is
+ /// worthwhile.
+ ///
+ /// As an invariant, this integer representing the field element in the Montgomery domain
+ /// must be less than the prime p.
+ #[derive(Clone, Copy, PartialOrd, Ord, Default)]
+ pub struct $elem(u128);
+
+ impl $elem {
+ /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the
+ /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing
+ /// it using the field modulus.
+ ///
+ /// # Errors
+ ///
+ /// An error is returned if the provided slice is not long enough to encode a field
+ /// element or if the decoded value is greater than the field prime.
+ ///
+ /// # Notes
+ ///
+ /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions
+ /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is
+ /// more compact, and does not have to correspond to the size of an integer type. For
+ /// instance,`Field96`'s encoding is 12 bytes, even though it is a 16 byte `u128` in
+ /// memory.
+ fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> {
+ if Self::ENCODED_SIZE > bytes.len() {
+ return Err(FieldError::ShortRead);
+ }
+
+ let mut int = 0;
+ for i in 0..Self::ENCODED_SIZE {
+ let j = match $encoding_order {
+ ByteOrder::LittleEndian => i,
+ ByteOrder::BigEndian => Self::ENCODED_SIZE - i - 1,
+ };
+
+ int |= (bytes[j] as u128) << (i << 3);
+ }
+
+ int &= mask;
+
+ if int >= $fp.p {
+ return Err(FieldError::ModulusOverflow);
+ }
+ // FieldParameters::montgomery() will return a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Ok(Self($fp.montgomery(int)))
+ }
+ }
+
+ impl PartialEq for $elem {
+ fn eq(&self, rhs: &Self) -> bool {
+ // The fields included in this comparison MUST match the fields
+ // used in Hash::hash
+ // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
+
+ // Check the invariant that the integer representation is fully reduced.
+ debug_assert!(self.0 < $fp.p);
+ debug_assert!(rhs.0 < $fp.p);
+
+ self.0 == rhs.0
+ }
+ }
+
+ impl Hash for $elem {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ // The fields included in this hash MUST match the fields used
+ // in PartialEq::eq
+ // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
+
+ // Check the invariant that the integer representation is fully reduced.
+ debug_assert!(self.0 < $fp.p);
+
+ self.0.hash(state);
+ }
+ }
+
+ impl Eq for $elem {}
+
+ impl Add for $elem {
+ type Output = $elem;
+ fn add(self, rhs: Self) -> Self {
+ // FieldParameters::add() returns a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Self($fp.add(self.0, rhs.0))
+ }
+ }
+
+ impl Add for &$elem {
+ type Output = $elem;
+ fn add(self, rhs: Self) -> $elem {
+ *self + *rhs
+ }
+ }
+
+ impl AddAssign for $elem {
+ fn add_assign(&mut self, rhs: Self) {
+ *self = *self + rhs;
+ }
+ }
+
+ impl Sub for $elem {
+ type Output = $elem;
+ fn sub(self, rhs: Self) -> Self {
+ // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub()
+ // returns a value less than p, satisfying the invariant on Self.
+ Self($fp.sub(self.0, rhs.0))
+ }
+ }
+
+ impl Sub for &$elem {
+ type Output = $elem;
+ fn sub(self, rhs: Self) -> $elem {
+ *self - *rhs
+ }
+ }
+
+ impl SubAssign for $elem {
+ fn sub_assign(&mut self, rhs: Self) {
+ *self = *self - rhs;
+ }
+ }
+
+ impl Mul for $elem {
+ type Output = $elem;
+ fn mul(self, rhs: Self) -> Self {
+ // FieldParameters::mul() always returns a value less than p, so the invariant on
+ // Self is satisfied.
+ Self($fp.mul(self.0, rhs.0))
+ }
+ }
+
+ impl Mul for &$elem {
+ type Output = $elem;
+ fn mul(self, rhs: Self) -> $elem {
+ *self * *rhs
+ }
+ }
+
+ impl MulAssign for $elem {
+ fn mul_assign(&mut self, rhs: Self) {
+ *self = *self * rhs;
+ }
+ }
+
+ impl Div for $elem {
+ type Output = $elem;
+ #[allow(clippy::suspicious_arithmetic_impl)]
+ fn div(self, rhs: Self) -> Self {
+ self * rhs.inv()
+ }
+ }
+
+ impl Div for &$elem {
+ type Output = $elem;
+ fn div(self, rhs: Self) -> $elem {
+ *self / *rhs
+ }
+ }
+
+ impl DivAssign for $elem {
+ fn div_assign(&mut self, rhs: Self) {
+ *self = *self / rhs;
+ }
+ }
+
+ impl Neg for $elem {
+ type Output = $elem;
+ fn neg(self) -> Self {
+ // FieldParameters::neg() will return a value less than p because self.0 is less
+ // than p, and neg() dispatches to sub().
+ Self($fp.neg(self.0))
+ }
+ }
+
+ impl Neg for &$elem {
+ type Output = $elem;
+ fn neg(self) -> $elem {
+ -(*self)
+ }
+ }
+
+ impl From<$int> for $elem {
+ fn from(x: $int) -> Self {
+ // FieldParameters::montgomery() will return a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Self($fp.montgomery(u128::try_from(x).unwrap()))
+ }
+ }
+
+ impl From<$elem> for $int {
+ fn from(x: $elem) -> Self {
+ $int::try_from($fp.residue(x.0)).unwrap()
+ }
+ }
+
+ impl PartialEq<$int> for $elem {
+ fn eq(&self, rhs: &$int) -> bool {
+ $fp.residue(self.0) == u128::try_from(*rhs).unwrap()
+ }
+ }
+
+ impl<'a> TryFrom<&'a [u8]> for $elem {
+ type Error = FieldError;
+
+ fn try_from(bytes: &[u8]) -> Result<Self, FieldError> {
+ Self::try_from_bytes(bytes, u128::MAX)
+ }
+ }
+
+ impl From<$elem> for [u8; $elem::ENCODED_SIZE] {
+ fn from(elem: $elem) -> Self {
+ let int = $fp.residue(elem.0);
+ let mut slice = [0; $elem::ENCODED_SIZE];
+ for i in 0..$elem::ENCODED_SIZE {
+ let j = match $encoding_order {
+ ByteOrder::LittleEndian => i,
+ ByteOrder::BigEndian => $elem::ENCODED_SIZE - i - 1,
+ };
+
+ slice[j] = ((int >> (i << 3)) & 0xff) as u8;
+ }
+ slice
+ }
+ }
+
+ impl From<$elem> for Vec<u8> {
+ fn from(elem: $elem) -> Self {
+ <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec()
+ }
+ }
+
+ impl Display for $elem {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "{}", $fp.residue(self.0))
+ }
+ }
+
+ impl Debug for $elem {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", $fp.residue(self.0))
+ }
+ }
+
+ // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because
+ // the derived implementations would represent `FieldElement` values as the backing `u128`,
+ // which is not what we want because (1) we can be more efficient in all cases and (2) in
+ // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625).
+ impl Serialize for $elem {
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+ let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into();
+ serializer.serialize_bytes(&bytes)
+ }
+ }
+
+ impl<'de> Deserialize<'de> for $elem {
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> {
+ deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData })
+ }
+ }
+
+ impl Encode for $elem {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self);
+ bytes.extend_from_slice(&slice);
+ }
+ }
+
+ impl Decode for $elem {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; $elem::ENCODED_SIZE];
+ bytes.read_exact(&mut value)?;
+ $elem::try_from_bytes(&value, u128::MAX).map_err(|e| {
+ CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>)
+ })
+ }
+ }
+
+ impl FieldElement for $elem {
+ const ENCODED_SIZE: usize = $encoding_size;
+ type Integer = $int;
+ type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
+ type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;
+
+ fn pow(&self, exp: Self::Integer) -> Self {
+ // FieldParameters::pow() relies on mul(), and will always return a value less
+ // than p.
+ Self($fp.pow(self.0, u128::try_from(exp).unwrap()))
+ }
+
+ fn inv(&self) -> Self {
+ // FieldParameters::inv() ultimately relies on mul(), and will always return a
+ // value less than p.
+ Self($fp.inv(self.0))
+ }
+
+ fn modulus() -> Self::Integer {
+ $fp.p as $int
+ }
+
+ fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> {
+ $elem::try_from_bytes(bytes, $fp.bit_mask)
+ }
+
+ fn generator() -> Self {
+ Self($fp.g)
+ }
+
+ fn generator_order() -> Self::Integer {
+ 1 << (Self::Integer::try_from($fp.num_roots).unwrap())
+ }
+
+ fn root(l: usize) -> Option<Self> {
+ if l < min($fp.roots.len(), $fp.num_roots+1) {
+ Some(Self($fp.roots[l]))
+ } else {
+ None
+ }
+ }
+
+ fn zero() -> Self {
+ Self(0)
+ }
+
+ fn one() -> Self {
+ Self($fp.roots[0])
+ }
+ }
+ };
+}
+
+make_field!(
+ /// `GF(4293918721)`, a 32-bit field.
+ Field32,
+ u32,
+ FP32,
+ 4,
+ ByteOrder::BigEndian,
+);
+
+make_field!(
+ /// Same as Field32, but encoded in little endian for compatibility with Prio v2.
+ FieldPrio2,
+ u32,
+ FP32,
+ 4,
+ ByteOrder::LittleEndian,
+);
+
+make_field!(
+ /// `GF(18446744069414584321)`, a 64-bit field.
+ Field64,
+ u64,
+ FP64,
+ 8,
+ ByteOrder::BigEndian,
+);
+
+make_field!(
+ /// `GF(79228148845226978974766202881)`, a 96-bit field.
+ Field96,
+ u128,
+ FP96,
+ 12,
+ ByteOrder::BigEndian,
+);
+
+make_field!(
+ /// `GF(340282366920938462946865773367900766209)`, a 128-bit field.
+ Field128,
+ u128,
+ FP128,
+ 16,
+ ByteOrder::BigEndian,
+);
+
+/// Merge two vectors of fields by summing other_vector into accumulator.
+///
+/// # Errors
+///
+/// Fails if the two vectors do not have the same length.
+#[cfg(any(test, feature = "prio2"))]
+pub(crate) fn merge_vector<F: FieldElement>(
+ accumulator: &mut [F],
+ other_vector: &[F],
+) -> Result<(), FieldError> {
+ if accumulator.len() != other_vector.len() {
+ return Err(FieldError::InputSizeMismatch);
+ }
+ for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) {
+ *a += *o;
+ }
+
+ Ok(())
+}
+
+/// Outputs an additive secret sharing of the input.
+#[cfg(feature = "crypto-dependencies")]
+pub(crate) fn split_vector<F: FieldElement>(
+ inp: &[F],
+ num_shares: usize,
+) -> Result<Vec<Vec<F>>, PrngError> {
+ if num_shares == 0 {
+ return Ok(vec![]);
+ }
+
+ let mut outp = Vec::with_capacity(num_shares);
+ outp.push(inp.to_vec());
+
+ for _ in 1..num_shares {
+ let share: Vec<F> = random_vector(inp.len())?;
+ for (x, y) in outp[0].iter_mut().zip(&share) {
+ *x -= *y;
+ }
+ outp.push(share);
+ }
+
+ Ok(outp)
+}
+
+/// Generate a vector of uniform random field elements.
+#[cfg(feature = "crypto-dependencies")]
+pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> {
+ Ok(Prng::new()?.take(len).collect())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::fp::MAX_ROOTS;
+ use crate::prng::Prng;
+ use assert_matches::assert_matches;
+ use std::collections::hash_map::DefaultHasher;
+
+ #[test]
+ fn test_endianness() {
+ let little_endian_encoded: [u8; FieldPrio2::ENCODED_SIZE] =
+ FieldPrio2(0x12_34_56_78).into();
+
+ let mut big_endian_encoded: [u8; Field32::ENCODED_SIZE] = Field32(0x12_34_56_78).into();
+ big_endian_encoded.reverse();
+
+ assert_eq!(little_endian_encoded, big_endian_encoded);
+ }
+
+ #[test]
+ fn test_accumulate() {
+ let mut lhs = vec![Field32(1); 10];
+ let rhs = vec![Field32(2); 10];
+
+ merge_vector(&mut lhs, &rhs).unwrap();
+
+ lhs.iter().for_each(|f| assert_eq!(*f, Field32(3)));
+ rhs.iter().for_each(|f| assert_eq!(*f, Field32(2)));
+
+ let wrong_len = vec![Field32::zero(); 9];
+ let result = merge_vector(&mut lhs, &wrong_len);
+ assert_matches!(result, Err(FieldError::InputSizeMismatch));
+ }
+
+ fn hash_helper<H: Hash>(input: H) -> u64 {
+ let mut hasher = DefaultHasher::new();
+ input.hash(&mut hasher);
+ hasher.finish()
+ }
+
+ // Some of the checks in this function, like `assert_eq!(one - one, zero)`
+ // or `assert_eq!(two / two, one)` trip this clippy lint for tautological
+ // comparisons, but we have a legitimate need to verify these basics. We put
+ // the #[allow] on the whole function since "attributes on expressions are
+ // experimental" https://github.com/rust-lang/rust/issues/15701
+ #[allow(clippy::eq_op)]
+ fn field_element_test<F: FieldElement + Hash>() {
+ let mut prng: Prng<F, _> = Prng::new().unwrap();
+ let int_modulus = F::modulus();
+ let int_one = F::Integer::try_from(1).unwrap();
+ let zero = F::zero();
+ let one = F::one();
+ let two = F::from(F::Integer::try_from(2).unwrap());
+ let four = F::from(F::Integer::try_from(4).unwrap());
+
+ // add
+ assert_eq!(F::from(int_modulus - int_one) + one, zero);
+ assert_eq!(one + one, two);
+ assert_eq!(two + F::from(int_modulus), two);
+
+ // sub
+ assert_eq!(zero - one, F::from(int_modulus - int_one));
+ assert_eq!(one - one, zero);
+ assert_eq!(two - F::from(int_modulus), two);
+ assert_eq!(one - F::from(int_modulus - int_one), two);
+
+ // add + sub
+ for _ in 0..100 {
+ let f = prng.get();
+ let g = prng.get();
+ assert_eq!(f + g - f - g, zero);
+ assert_eq!(f + g - g, f);
+ assert_eq!(f + g - f, g);
+ }
+
+ // mul
+ assert_eq!(two * two, four);
+ assert_eq!(two * one, two);
+ assert_eq!(two * zero, zero);
+ assert_eq!(one * F::from(int_modulus), zero);
+
+ // div
+ assert_eq!(four / two, two);
+ assert_eq!(two / two, one);
+ assert_eq!(zero / two, zero);
+ assert_eq!(two / zero, zero); // Undefined behavior
+ assert_eq!(zero.inv(), zero); // Undefined behavior
+
+ // mul + div
+ for _ in 0..100 {
+ let f = prng.get();
+ if f == zero {
+ continue;
+ }
+ assert_eq!(f * f.inv(), one);
+ assert_eq!(f.inv() * f, one);
+ }
+
+ // pow
+ assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one);
+ assert_eq!(two.pow(int_one), two);
+ assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four);
+ assert_eq!(two.pow(int_modulus - int_one), one);
+ assert_eq!(two.pow(int_modulus), two);
+
+ // roots
+ let mut int_order = F::generator_order();
+ for l in 0..MAX_ROOTS + 1 {
+ assert_eq!(
+ F::generator().pow(int_order),
+ F::root(l).unwrap(),
+ "failure for F::root({})",
+ l
+ );
+ int_order = int_order >> int_one;
+ }
+
+ // serialization
+ let test_inputs = vec![zero, one, prng.get(), F::from(int_modulus - int_one)];
+ for want in test_inputs.iter() {
+ let mut bytes = vec![];
+ want.encode(&mut bytes);
+
+ assert_eq!(bytes.len(), F::ENCODED_SIZE);
+
+ let got = F::get_decoded(&bytes).unwrap();
+ assert_eq!(got, *want);
+ }
+
+ let serialized_vec = F::slice_into_byte_vec(&test_inputs);
+ let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap();
+ assert_eq!(deserialized, test_inputs);
+
+ // equality and hash: Generate many elements, confirm they are not equal, and confirm
+ // various products that should be equal have the same hash. Three is chosen as a generator
+ // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four
+ // primes.
+ let three = F::from(F::Integer::try_from(3).unwrap());
+ let mut powers_of_three = Vec::with_capacity(500);
+ let mut power = one;
+ for _ in 0..500 {
+ powers_of_three.push(power);
+ power *= three;
+ }
+ // Check all these elements are mutually not equal.
+ for i in 0..powers_of_three.len() {
+ let first = &powers_of_three[i];
+ for second in &powers_of_three[0..i] {
+ assert_ne!(first, second);
+ }
+ }
+
+ // Check that 3^i is the same whether it's calculated with pow() or repeated
+ // multiplication, with both equality and hash equality.
+ for (i, power) in powers_of_three.iter().enumerate() {
+ let result = three.pow(F::Integer::try_from(i).unwrap());
+ assert_eq!(result, *power);
+ let hash1 = hash_helper(power);
+ let hash2 = hash_helper(result);
+ assert_eq!(hash1, hash2);
+ }
+
+ // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality.
+ let expected_product = powers_of_three[powers_of_three.len() - 1];
+ let expected_hash = hash_helper(expected_product);
+ for i in 0..powers_of_three.len() {
+ let a = powers_of_three[i];
+ let b = powers_of_three[powers_of_three.len() - 1 - i];
+ let product = a * b;
+ assert_eq!(product, expected_product);
+ assert_eq!(hash_helper(product), expected_hash);
+ }
+
+ // Construct an element from a number that needs to be reduced, and test comparisons on it,
+ // confirming that FieldParameters::montgomery() reduced it correctly.
+ let p = F::from(int_modulus);
+ assert_eq!(p, zero);
+ assert_eq!(hash_helper(p), hash_helper(zero));
+ let p_plus_one = F::from(int_modulus + F::Integer::try_from(1).unwrap());
+ assert_eq!(p_plus_one, one);
+ assert_eq!(hash_helper(p_plus_one), hash_helper(one));
+ }
+
+ #[test]
+ fn test_field32() {
+ field_element_test::<Field32>();
+ }
+
+ #[test]
+ fn test_field_priov2() {
+ field_element_test::<FieldPrio2>();
+ }
+
+ #[test]
+ fn test_field64() {
+ field_element_test::<Field64>();
+ }
+
+ #[test]
+ fn test_field96() {
+ field_element_test::<Field96>();
+ }
+
+ #[test]
+ fn test_field128() {
+ field_element_test::<Field128>();
+ }
+}
diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs
new file mode 100644
index 0000000000..7f37347ca3
--- /dev/null
+++ b/third_party/rust/prio/src/flp.rs
@@ -0,0 +1,1035 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of the generic Fully Linear Proof (FLP) system specified in
+//! [[draft-irtf-cfrg-vdaf-03]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
+//!
+//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation
+//! specifies a validity circuit that defines the set of valid measurements, as well as the finite
+//! field in which the validity circuit is evaluated. It also determines how raw measurements are
+//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of
+//! measurements.
+//!
+//! # Overview
+//!
+//! The proof system is comprised of three algorithms. The first, `prove`, is run by the prover in
+//! order to generate a proof of a statement's validity. The second and third, `query` and
+//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input
+//! is an element of a language recognized by the arithmetic circuit. If an input is _not_ valid,
+//! then the verification step will fail with high probability:
+//!
+//! ```
+//! use prio::flp::types::Count;
+//! use prio::flp::Type;
+//! use prio::field::{random_vector, FieldElement, Field64};
+//!
+//! // The prover chooses a measurement.
+//! let count = Count::new();
+//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap();
+//!
+//! // The prover and verifier agree on "joint randomness" used to generate and
+//! // check the proof. The application needs to ensure that the prover
+//! // "commits" to the input before this point. In Prio3, the joint
+//! // randomness is derived from additive shares of the input.
+//! let joint_rand = random_vector(count.joint_rand_len()).unwrap();
+//!
+//! // The prover generates the proof.
+//! let prove_rand = random_vector(count.prove_rand_len()).unwrap();
+//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap();
+//!
+//! // The verifier checks the proof. In the first step, the verifier "queries"
+//! // the input and proof, getting the "verifier message" in response. It then
+//! // inspects the verifier to decide if the input is valid.
+//! let query_rand = random_vector(count.query_rand_len()).unwrap();
+//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap();
+//! assert!(count.decide(&verifier).unwrap());
+//! ```
+//!
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError};
+use crate::field::{FieldElement, FieldError};
+use crate::fp::log2;
+use crate::polynomial::poly_eval;
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+
+pub mod gadgets;
+pub mod types;
+
+/// Errors propagated by methods in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum FlpError {
+ /// Calling [`Type::prove`] returned an error.
+ #[error("prove error: {0}")]
+ Prove(String),
+
+ /// Calling [`Type::query`] returned an error.
+ #[error("query error: {0}")]
+ Query(String),
+
+ /// Calling [`Type::decide`] returned an error.
+ #[error("decide error: {0}")]
+ Decide(String),
+
+ /// Calling a gadget returned an error.
+ #[error("gadget error: {0}")]
+ Gadget(String),
+
+ /// Calling the validity circuit returned an error.
+ #[error("validity circuit error: {0}")]
+ Valid(String),
+
+ /// Calling [`Type::encode_measurement`] returned an error.
+ #[error("value error: {0}")]
+ Encode(String),
+
+ /// Calling [`Type::decode_result`] returned an error.
+ #[error("value error: {0}")]
+ Decode(String),
+
+ /// Calling [`Type::truncate`] returned an error.
+ #[error("truncate error: {0}")]
+ Truncate(String),
+
+ /// Returned if an FFT operation propagates an error.
+ #[error("FFT error: {0}")]
+ Fft(#[from] FftError),
+
+ /// Returned if a field operation encountered an error.
+ #[error("Field error: {0}")]
+ Field(#[from] FieldError),
+
+ /// Unit test error.
+ #[cfg(test)]
+ #[error("test failed: {0}")]
+ Test(String),
+}
+
+/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
+/// as a vector of field elements and how validity of the encoded measurement is determined.
+/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
+pub trait Type: Sized + Eq + Clone + Debug {
+ /// The Prio3 VDAF identifier corresponding to this type.
+ const ID: u32;
+
+ /// The type of raw measurement to be encoded.
+ type Measurement: Clone + Debug;
+
+ /// The type of aggregate result for this type.
+ type AggregateResult: Clone + Debug;
+
+ /// The finite field used for this type.
+ type Field: FieldElement;
+
+ /// Encodes a measurement as a vector of [`Self::input_len`] field elements.
+ fn encode_measurement(
+ &self,
+ measurement: &Self::Measurement,
+ ) -> Result<Vec<Self::Field>, FlpError>;
+
+ /// Decode an aggregate result.
+ fn decode_result(
+ &self,
+ data: &[Self::Field],
+ num_measurements: usize,
+ ) -> Result<Self::AggregateResult, FlpError>;
+
+ /// Returns the sequence of gadgets associated with the validity circuit.
+ ///
+ /// # Notes
+ ///
+ /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many. The
+ /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in
+ /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here
+ /// requires security analysis.
+ ///
+ /// [BBCG+19]: https://ia.cr/2019/188
+ fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>;
+
+ /// Evaluates the validity circuit on an input and returns the output.
+ ///
+ /// # Parameters
+ ///
+ /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`].
+ /// * `input` is the input to be validated.
+ /// * `joint_rand` is the joint randomness shared by the prover and verifier.
+ /// * `num_shares` is the number of input shares.
+ ///
+ /// # Example usage
+ ///
+ /// Applications typically do not call this method directly. It is used internally by
+ /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively.
+ ///
+ /// ```
+ /// use prio::flp::types::Count;
+ /// use prio::flp::Type;
+ /// use prio::field::{random_vector, FieldElement, Field64};
+ ///
+ /// let count = Count::new();
+ /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap();
+ /// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
+ /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
+ /// assert_eq!(v, Field64::zero());
+ /// ```
+ fn valid(
+ &self,
+ gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>,
+ input: &[Self::Field],
+ joint_rand: &[Self::Field],
+ num_shares: usize,
+ ) -> Result<Self::Field, FlpError>;
+
+ /// Constructs an aggregatable output from an encoded input. Calling this method is only safe
+ /// once `input` has been validated.
+ fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>;
+
+ /// The length in field elements of the encoded input returned by [`Self::encode_measurement`].
+ fn input_len(&self) -> usize;
+
+ /// The length in field elements of the proof generated for this type.
+ fn proof_len(&self) -> usize;
+
+ /// The length in field elements of the verifier message constructed by [`Self::query`].
+ fn verifier_len(&self) -> usize;
+
+ /// The length of the truncated output (i.e., the output of [`Type::truncate`]).
+ fn output_len(&self) -> usize;
+
+ /// The length of the joint random input.
+ fn joint_rand_len(&self) -> usize;
+
+ /// The length in field elements of the random input consumed by the prover to generate a
+ /// proof. This is the same as the sum of the arity of each gadget in the validity circuit.
+ fn prove_rand_len(&self) -> usize;
+
+ /// The length in field elements of the random input consumed by the verifier to make queries
+ /// against inputs and proofs. This is the same as the number of gadgets in the validity
+ /// circuit.
+ fn query_rand_len(&self) -> usize;
+
+ /// Generate a proof of an input's validity. The return value is a sequence of
+ /// [`Self::proof_len`] field elements.
+ ///
+ /// # Parameters
+ ///
+ /// * `input` is the input.
+ /// * `prove_rand` is the prover' randomness.
+ /// * `joint_rand` is the randomness shared by the prover and verifier.
+ #[allow(clippy::needless_range_loop)]
+ fn prove(
+ &self,
+ input: &[Self::Field],
+ prove_rand: &[Self::Field],
+ joint_rand: &[Self::Field],
+ ) -> Result<Vec<Self::Field>, FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ if prove_rand.len() != self.prove_rand_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected prove randomness length: got {}; want {}",
+ prove_rand.len(),
+ self.prove_rand_len()
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ let mut prove_rand_len = 0;
+ let mut shim = self
+ .gadget()
+ .into_iter()
+ .map(|inner| {
+ let inner_arity = inner.arity();
+ if prove_rand_len + inner_arity > prove_rand.len() {
+ return Err(FlpError::Prove(format!(
+ "short prove randomness: got {}; want {}",
+ prove_rand.len(),
+ self.prove_rand_len()
+ )));
+ }
+
+ let gadget = Box::new(ProveShimGadget::new(
+ inner,
+ &prove_rand[prove_rand_len..prove_rand_len + inner_arity],
+ )?) as Box<dyn Gadget<Self::Field>>;
+ prove_rand_len += inner_arity;
+
+ Ok(gadget)
+ })
+ .collect::<Result<Vec<_>, FlpError>>()?;
+ assert_eq!(prove_rand_len, self.prove_rand_len());
+
+ // Create a buffer for storing the proof. The buffer is longer than the proof itself; the extra
+ // length is to accommodate the computation of each gadget polynomial.
+ let data_len = (0..shim.len())
+ .map(|idx| {
+ let gadget_poly_len =
+ gadget_poly_len(shim[idx].degree(), wire_poly_len(shim[idx].calls()));
+
+ // Computing the gadget polynomial using FFT requires an amount of memory that is a
+ // power of 2. Thus we choose the smallest power of 2 that is at least as large as
+ // the gadget polynomial. The wire seeds are encoded in the proof, too, so we
+ // include the arity of the gadget to ensure there is always enough room at the end
+ // of the buffer to compute the next gadget polynomial. It's likely that the
+ // memory footprint here can be reduced, with a bit of care.
+ shim[idx].arity() + gadget_poly_len.next_power_of_two()
+ })
+ .sum();
+ let mut proof = vec![Self::Field::zero(); data_len];
+
+ // Run the validity circuit with a sequence of "shim" gadgets that record the value of each
+ // input wire of each gadget evaluation. These values are used to construct the wire
+ // polynomials for each gadget in the next step.
+ let _ = self.valid(&mut shim, input, joint_rand, 1)?;
+
+ // Construct the proof.
+ let mut proof_len = 0;
+ for idx in 0..shim.len() {
+ let gadget = shim[idx]
+ .as_any()
+ .downcast_mut::<ProveShimGadget<Self::Field>>()
+ .unwrap();
+
+ // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of each
+ // evaluation of the gadget.
+ let m = wire_poly_len(gadget.calls());
+ let m_inv =
+ Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
+ .inv();
+ let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()];
+ for wire in 0..gadget.arity() {
+ discrete_fourier_transform(&mut f[wire], &gadget.f_vals[wire], m)?;
+ discrete_fourier_transform_inv_finish(&mut f[wire], m, m_inv);
+
+ // The first point on each wire polynomial is a random value chosen by the prover. This
+ // point is stored in the proof so that the verifier can reconstruct the wire
+ // polynomials.
+ proof[proof_len + wire] = gadget.f_vals[wire][0];
+ }
+
+ // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to `proof`.
+ let gadget_poly_len = gadget_poly_len(gadget.degree(), m);
+ let start = proof_len + gadget.arity();
+ let end = start + gadget_poly_len.next_power_of_two();
+ gadget.call_poly(&mut proof[start..end], &f)?;
+ proof_len += gadget.arity() + gadget_poly_len;
+ }
+
+ // Truncate the buffer to the size of the proof.
+ assert_eq!(proof_len, self.proof_len());
+ proof.truncate(proof_len);
+ Ok(proof)
+ }
+
+ /// Query an input and proof and return the verifier message. The return value has length
+ /// [`Self::verifier_len`].
+ ///
+ /// # Parameters
+ ///
+ /// * `input` is the input or input share.
+ /// * `proof` is the proof or proof share.
+ /// * `query_rand` is the verifier's randomness.
+ /// * `joint_rand` is the randomness shared by the prover and verifier.
+ /// * `num_shares` is the total number of input shares.
+ fn query(
+ &self,
+ input: &[Self::Field],
+ proof: &[Self::Field],
+ query_rand: &[Self::Field],
+ joint_rand: &[Self::Field],
+ num_shares: usize,
+ ) -> Result<Vec<Self::Field>, FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ if proof.len() != self.proof_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected proof length: got {}; want {}",
+ proof.len(),
+ self.proof_len()
+ )));
+ }
+
+ if query_rand.len() != self.query_rand_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected query randomness length: got {}; want {}",
+ query_rand.len(),
+ self.query_rand_len()
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ let mut proof_len = 0;
+ let mut shim = self
+ .gadget()
+ .into_iter()
+ .enumerate()
+ .map(|(idx, gadget)| {
+ let gadget_degree = gadget.degree();
+ let gadget_arity = gadget.arity();
+ let m = (1 + gadget.calls()).next_power_of_two();
+ let r = query_rand[idx];
+
+ // Make sure the query randomness isn't a root of unity. Evaluating the gadget
+ // polynomial at any of these points would be a privacy violation, since these points
+ // were used by the prover to construct the wire polynomials.
+ if r.pow(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
+ == Self::Field::one()
+ {
+ return Err(FlpError::Query(format!(
+ "invalid query randomness: encountered 2^{}-th root of unity",
+ m
+ )));
+ }
+
+ // Compute the length of the sub-proof corresponding to the `idx`-th gadget.
+ let next_len = gadget_arity + gadget_degree * (m - 1) + 1;
+ let proof_data = &proof[proof_len..proof_len + next_len];
+ proof_len += next_len;
+
+ Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?)
+ as Box<dyn Gadget<Self::Field>>)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // Create a buffer for the verifier data. This includes the output of the validity circuit and,
+ // for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query randomness
+ // `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`.
+ let data_len = 1
+ + (0..shim.len())
+ .map(|idx| shim[idx].arity() + 1)
+ .sum::<usize>();
+ let mut verifier = Vec::with_capacity(data_len);
+
+ // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to each
+ // wire for each gadget call. Record the output of the circuit and append it to the verifier
+ // message.
+ //
+ // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit is
+ // equal to the output of the last gadget evaluation. Here we relax this assumption. This
+ // should be OK, since it's possible to transform any circuit into one for which this is true.
+ // (Needs security analysis.)
+ let validity = self.valid(&mut shim, input, joint_rand, num_shares)?;
+ verifier.push(validity);
+
+ // Fill the buffer with the verifier message.
+ for idx in 0..shim.len() {
+ let r = query_rand[idx];
+ let gadget = shim[idx]
+ .as_any()
+ .downcast_ref::<QueryShimGadget<Self::Field>>()
+ .unwrap();
+
+ // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire
+ // polynomial at query randomness `r`.
+ let m = (1 + gadget.calls()).next_power_of_two();
+ let m_inv =
+ Self::Field::from(<Self::Field as FieldElement>::Integer::try_from(m).unwrap())
+ .inv();
+ let mut f = vec![Self::Field::zero(); m];
+ for wire in 0..gadget.arity() {
+ discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?;
+ discrete_fourier_transform_inv_finish(&mut f, m, m_inv);
+ verifier.push(poly_eval(&f, r));
+ }
+
+ // Add the value of the gadget polynomial evaluated at `r`.
+ verifier.push(gadget.p_at_r);
+ }
+
+ assert_eq!(verifier.len(), self.verifier_len());
+ Ok(verifier)
+ }
+
+ /// Returns true if the verifier message indicates that the input from which it was generated is valid.
+ #[allow(clippy::needless_range_loop)]
+ fn decide(&self, verifier: &[Self::Field]) -> Result<bool, FlpError> {
+ if verifier.len() != self.verifier_len() {
+ return Err(FlpError::Decide(format!(
+ "unexpected verifier length: got {}; want {}",
+ verifier.len(),
+ self.verifier_len()
+ )));
+ }
+
+ // Check if the output of the circuit is 0.
+ if verifier[0] != Self::Field::zero() {
+ return Ok(false);
+ }
+
+ // Check that each of the proof polynomials are well-formed.
+ let mut gadgets = self.gadget();
+ let mut verifier_len = 1;
+ for idx in 0..gadgets.len() {
+ let next_len = 1 + gadgets[idx].arity();
+
+ let e = gadgets[idx].call(&verifier[verifier_len..verifier_len + next_len - 1])?;
+ if e != verifier[verifier_len + next_len - 1] {
+ return Ok(false);
+ }
+
+ verifier_len += next_len;
+ }
+
+ Ok(true)
+ }
+
+ /// Check whether `input` and `joint_rand` have the length expected by `self`,
+ /// return [`FlpError::Valid`] otherwise.
+ fn valid_call_check(
+ &self,
+ input: &[Self::Field],
+ joint_rand: &[Self::Field],
+ ) -> Result<(), FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Valid(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len(),
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Valid(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ Ok(())
+ }
+
+ /// Check if the length of `input` matches `self`'s `input_len()`,
+ /// return [`FlpError::Truncate`] otherwise.
+ fn truncate_call_check(&self, input: &[Self::Field]) -> Result<(), FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Truncate(format!(
+ "Unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ Ok(())
+ }
+}
+
+/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
+pub trait Gadget<F: FieldElement>: Debug {
+ /// Evaluates the gadget on input `inp` and returns the output.
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError>;
+
+ /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`.
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;
+
+ /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
+ /// `call_poly`.
+ fn arity(&self) -> usize;
+
+ /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp`
+ /// buffer passed to `call_poly`.
+ fn degree(&self) -> usize;
+
+ /// Returns the number of times the gadget is expected to be called.
+ fn calls(&self) -> usize;
+
+ /// This call is used to downcast a `Box<dyn Gadget<F>>` to a concrete type.
+ fn as_any(&mut self) -> &mut dyn Any;
+}
+
+// A "shim" gadget used during proof generation to record the input wires each time a gadget is
+// evaluated.
+#[derive(Debug)]
+struct ProveShimGadget<F: FieldElement> {
+ inner: Box<dyn Gadget<F>>,
+
+ /// Points at which the wire polynomials are interpolated.
+ f_vals: Vec<Vec<F>>,
+
+ /// The number of times the gadget has been called so far.
+ ct: usize,
+}
+
+impl<F: FieldElement> ProveShimGadget<F> {
+ fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> {
+ let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()];
+
+ #[allow(clippy::needless_range_loop)]
+ for wire in 0..f_vals.len() {
+ // Choose a random field element as the first point on the wire polynomial.
+ f_vals[wire][0] = prove_rand[wire];
+ }
+
+ Ok(Self {
+ inner,
+ f_vals,
+ ct: 1,
+ })
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for ProveShimGadget<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ #[allow(clippy::needless_range_loop)]
+ for wire in 0..inp.len() {
+ self.f_vals[wire][self.ct] = inp[wire];
+ }
+ self.ct += 1;
+ self.inner.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ self.inner.call_poly(outp, inp)
+ }
+
+ fn arity(&self) -> usize {
+ self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// A "shim" gadget used during proof verification to record the points at which the intermediate
+// proof polynomials are evaluated.
+#[derive(Debug)]
+struct QueryShimGadget<F: FieldElement> {
+ inner: Box<dyn Gadget<F>>,
+
+ /// Points at which intermediate proof polynomials are interpolated.
+ f_vals: Vec<Vec<F>>,
+
+ /// Points at which the gadget polynomial is interpolated.
+ p_vals: Vec<F>,
+
+ /// The gadget polynomial evaluated on a random input `r`.
+ p_at_r: F,
+
+ /// Used to compute an index into `p_val`.
+ step: usize,
+
+ /// The number of times the gadget has been called so far.
+ ct: usize,
+}
+
+impl<F: FieldElement> QueryShimGadget<F> {
+ fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> {
+ let gadget_degree = inner.degree();
+ let gadget_arity = inner.arity();
+ let m = (1 + inner.calls()).next_power_of_two();
+ let p = m * gadget_degree;
+
+ // Each call to this gadget records the values at which intermediate proof polynomials were
+ // interpolated. The first point was a random value chosen by the prover and transmitted in
+ // the proof.
+ let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; gadget_arity];
+ for wire in 0..gadget_arity {
+ f_vals[wire][0] = proof_data[wire];
+ }
+
+ // Evaluate the gadget polynomial at roots of unity.
+ let size = p.next_power_of_two();
+ let mut p_vals = vec![F::zero(); size];
+ discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?;
+
+ // The step is used to compute the element of `p_val` that will be returned by a call to
+ // the gadget.
+ let step = (1 << (log2(p as u128) - log2(m as u128))) as usize;
+
+ // Evaluate the gadget polynomial `p` at query randomness `r`.
+ let p_at_r = poly_eval(&proof_data[gadget_arity..], r);
+
+ Ok(Self {
+ inner,
+ f_vals,
+ p_vals,
+ p_at_r,
+ step,
+ ct: 1,
+ })
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for QueryShimGadget<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ #[allow(clippy::needless_range_loop)]
+ for wire in 0..inp.len() {
+ self.f_vals[wire][self.ct] = inp[wire];
+ }
+ let outp = self.p_vals[self.ct * self.step];
+ self.ct += 1;
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> {
+ panic!("no-op");
+ }
+
+ fn arity(&self) -> usize {
+ self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Compute the length of the wire polynomial constructed from the given number of gadget calls.
+#[inline]
+pub(crate) fn wire_poly_len(num_calls: usize) -> usize {
+ (1 + num_calls).next_power_of_two()
+}
+
+/// Compute the length of the gadget polynomial for a gadget with the given degree and from wire
+/// polynomials of the given length.
+#[inline]
+pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usize {
+ gadget_degree * (wire_poly_len - 1) + 1
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, split_vector, Field128};
+ use crate::flp::gadgets::{Mul, PolyEval};
+ use crate::polynomial::poly_range_check;
+
+ use std::marker::PhantomData;
+
+ // Simple integration test for the core FLP logic. You'll find more extensive unit tests for
+ // each implemented data type in src/types.rs.
+ #[test]
+ fn test_flp() {
+ const NUM_SHARES: usize = 2;
+
+ let typ: TestType<Field128> = TestType::new();
+ let input = typ.encode_measurement(&3).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+
+ let input_shares: Vec<Vec<Field128>> = split_vector(input.as_slice(), NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ assert_eq!(proof.len(), typ.proof_len());
+
+ let proof_shares: Vec<Vec<Field128>> = split_vector(&proof, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let verifier: Vec<Field128> = (0..NUM_SHARES)
+ .map(|i| {
+ typ.query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ NUM_SHARES,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ /// A toy type used for testing multiple gadgets. Valid inputs of this type consist of a pair
+ /// of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`.
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct TestType<F>(PhantomData<F>);
+
+ impl<F> TestType<F> {
+ fn new() -> Self {
+ Self(PhantomData)
+ }
+ }
+
+ impl<F: FieldElement> Type for TestType<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ let r = joint_rand[0];
+ let mut res = F::zero();
+
+ // Check that `data[0]^3 == data[1]`.
+ let mut inp = [input[0], input[0]];
+ inp[0] = g[0].call(&inp)?;
+ inp[0] = g[0].call(&inp)?;
+ let x3_diff = inp[0] - input[1];
+ res += r * x3_diff;
+
+ // Check that `data[0]` is in the correct range.
+ let x_checked = g[1].call(&[input[0]])?;
+ res += (r * r) * x_checked;
+
+ Ok(res)
+ }
+
+ fn input_len(&self) -> usize {
+ 2
+ }
+
+ fn proof_len(&self) -> usize {
+ // First chunk
+ let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + 2_usize /* gadget calls */).next_power_of_two() - 1) + 1;
+
+ // Second chunk
+ let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * (
+ (1 + 1_usize /* gadget calls */).next_power_of_two() - 1) + 1;
+
+ mul + poly
+ }
+
+ fn verifier_len(&self) -> usize {
+ // First chunk
+ let mul = 1 + 2 /* gadget arity */;
+
+ // Second chunk
+ let poly = 1 + 1 /* gadget arity */;
+
+ 1 + mul + poly
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 3
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![
+ Box::new(Mul::new(2)),
+ Box::new(PolyEval::new(poly_range_check(2, 5), 1)),
+ ]
+ }
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ Ok(vec![
+ F::from(*measurement),
+ F::from(*measurement).pow(F::Integer::try_from(3).unwrap()),
+ ])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ Ok(input)
+ }
+
+ fn decode_result(
+ &self,
+ _data: &[F],
+ _num_measurements: usize,
+ ) -> Result<F::Integer, FlpError> {
+ panic!("not implemented");
+ }
+ }
+
+ // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that
+ // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than
+ // needed for computing the gadget polynomial.
+ #[test]
+ fn issue254() {
+ let typ: Issue254Type<Field128> = Issue254Type::new();
+ let input = typ.encode_measurement(&0).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct Issue254Type<F> {
+ num_gadget_calls: [usize; 2],
+ phantom: PhantomData<F>,
+ }
+
+ impl<F> Issue254Type<F> {
+ fn new() -> Self {
+ Self {
+ // The bug is triggered when there are two gadgets, but it doesn't matter how many
+ // times the second gadget is called.
+ num_gadget_calls: [100, 0],
+ phantom: PhantomData,
+ }
+ }
+ }
+
+ impl<F: FieldElement> Type for Issue254Type<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ _joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the
+ // use of multiple gadgets, each of which is called an arbitrary number of times.
+ let mut res = F::zero();
+ for _ in 0..self.num_gadget_calls[0] {
+ res += g[0].call(&[input[0]])?;
+ }
+ for _ in 0..self.num_gadget_calls[1] {
+ res += g[1].call(&[input[0]])?;
+ }
+ Ok(res)
+ }
+
+ fn input_len(&self) -> usize {
+ 1
+ }
+
+ fn proof_len(&self) -> usize {
+ // First chunk
+ let first = 1 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + self.num_gadget_calls[0]).next_power_of_two() - 1) + 1;
+
+ // Second chunk
+ let second = 1 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + self.num_gadget_calls[1]).next_power_of_two() - 1) + 1;
+
+ first + second
+ }
+
+ fn verifier_len(&self) -> usize {
+ // First chunk
+ let first = 1 + 1 /* gadget arity */;
+
+ // Second chunk
+ let second = 1 + 1 /* gadget arity */;
+
+ 1 + first + second
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 0
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ // First chunk
+ let first = 1; // gadget arity
+
+ // Second chunk
+ let second = 1; // gadget arity
+
+ first + second
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 2 // number of gadgets
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ let poly = poly_range_check(0, 2); // A polynomial with degree 2
+ vec![
+ Box::new(PolyEval::new(poly.clone(), self.num_gadget_calls[0])),
+ Box::new(PolyEval::new(poly, self.num_gadget_calls[1])),
+ ]
+ }
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ Ok(vec![F::from(*measurement)])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ Ok(input)
+ }
+
+ fn decode_result(
+ &self,
+ _data: &[F],
+ _num_measurements: usize,
+ ) -> Result<F::Integer, FlpError> {
+ panic!("not implemented");
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs
new file mode 100644
index 0000000000..fd2be84eaa
--- /dev/null
+++ b/third_party/rust/prio/src/flp/gadgets.rs
@@ -0,0 +1,715 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of gadgets.
+
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
+use crate::field::FieldElement;
+use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
+use crate::polynomial::{poly_deg, poly_eval, poly_mul};
+
+#[cfg(feature = "multithreaded")]
+use rayon::prelude::*;
+
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
+/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
+const FFT_THRESHOLD: usize = 60;
+
+/// An arity-2 gadget that multiples its inputs.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Mul<F: FieldElement> {
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> Mul<F> {
+ /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
+ /// called by the validity circuit.
+ pub fn new(num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(2, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ // Multiply input polynomials directly.
+ pub(crate) fn call_poly_direct(
+ &mut self,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+ ) -> Result<(), FlpError> {
+ let v = poly_mul(&inp[0], &inp[1]);
+ outp[..v.len()].clone_from_slice(&v);
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let mut buf = vec![F::zero(); n];
+
+ discrete_fourier_transform(&mut buf, &inp[0], n)?;
+ discrete_fourier_transform(outp, &inp[1], n)?;
+
+ for i in 0..n {
+ buf[i] *= outp[i];
+ }
+
+ discrete_fourier_transform(outp, &buf, n)?;
+ discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for Mul<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[0] * inp[1])
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ 2
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-1 gadget that evaluates its input on some polynomial.
+//
+// TODO Make `poly` an array of length determined by a const generic.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct PolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
+ /// this gadget is called by the validity circuit.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+}
+
+impl<F: FieldElement> PolyEval<F> {
+ // Multiply input polynomials directly.
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ outp[0] = self.poly[0];
+ let mut x = inp[0].to_vec();
+ for i in 1..self.poly.len() {
+ for j in 0..x.len() {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ x = poly_mul(&x, &inp[0]);
+ }
+ }
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let inp = &inp[0];
+
+ let mut inp_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut inp_vals, inp, n)?;
+
+ let mut x_vals = inp_vals.clone();
+ let mut x = vec![F::zero(); n];
+ x[..inp.len()].clone_from_slice(inp);
+
+ outp[0] = self.poly[0];
+ for i in 1..self.poly.len() {
+ for j in 0..n {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ x_vals[j] *= inp_vals[j];
+ }
+
+ discrete_fourier_transform(&mut x, &x_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for PolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for item in outp.iter_mut() {
+ *item = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 1
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly)
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct BlindPolyEval<F: FieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for the outer FFT multiplication.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FieldElement> BlindPolyEval<F> {
+ /// Returns a `BlindPolyEval` gadget for polynomial `poly`.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut z = y.to_vec();
+ for i in 0..self.poly.len() {
+ for j in 0..z.len() {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ z = poly_mul(&z, x);
+ }
+ }
+ Ok(())
+ }
+
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let x = &inp[0];
+ let y = &inp[1];
+
+ let mut x_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut x_vals, x, n)?;
+
+ let mut z_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut z_vals, y, n)?;
+
+ let mut z = vec![F::zero(); n];
+ let mut z_len = y.len();
+ z[..y.len()].clone_from_slice(y);
+
+ for i in 0..self.poly.len() {
+ for j in 0..z_len {
+ outp[j] += self.poly[i] * z[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ z_vals[j] *= x_vals[j];
+ }
+
+ discrete_fourier_transform(&mut z, &z_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv);
+ z_len += x.len();
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[1] * poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly) + 1
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Marker trait for abstracting over [`ParallelSum`].
+pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug {
+ /// Wraps `inner` into a sum gadget with `chunks` chunks
+ fn new(inner: G, chunks: usize) -> Self;
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSum<F: FieldElement, G: Gadget<F>> {
+ inner: G,
+ chunks: usize,
+ phantom: PhantomData<F>,
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> {
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ inner,
+ chunks,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ let mut outp = F::zero();
+ for chunk in inp.chunks(self.inner.arity()) {
+ outp += self.inner.call(chunk)?;
+ }
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ let mut partial_outp = vec![F::zero(); outp.len()];
+
+ for chunk in inp.chunks(self.inner.arity()) {
+ self.inner.call_poly(&mut partial_outp, chunk)?;
+ for i in 0..outp.len() {
+ outp[i] += partial_outp[i]
+ }
+ }
+
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.chunks * self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
+/// evaluation is multithreaded.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> {
+ serial_sum: ParallelSum<F, G>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ serial_sum: ParallelSum::new(inner, chunks),
+ }
+ }
+}
+
+/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
+#[cfg(feature = "multithreaded")]
+struct ParallelSumFoldState<F, G> {
+ /// Inner gadget.
+ inner: G,
+ /// Output buffer for `call_poly()`.
+ partial_output: Vec<F>,
+ /// Sum accumulator.
+ partial_sum: Vec<F>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumFoldState<F, G> {
+ fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
+ where
+ G: Clone,
+ F: FieldElement,
+ {
+ ParallelSumFoldState {
+ inner: gadget.clone(),
+ partial_output: vec![F::zero(); length],
+ partial_sum: vec![F::zero(); length],
+ }
+ }
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
+where
+ F: FieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ self.serial_sum.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
+ // gadget on each input polynomial, using the first temporary buffer as an output buffer.
+ // Then accumulate that result into the second temporary buffer, which acts as a running
+ // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
+ // to the output parameter. This is equivalent to the single threaded calculation in
+ // ParallelSum, since we only rearrange additions, and field addition is associative.
+ let res = inp
+ .par_chunks(self.serial_sum.inner.arity())
+ .fold(
+ || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
+ |mut state, chunk| {
+ state
+ .inner
+ .call_poly(&mut state.partial_output, chunk)
+ .unwrap();
+ for (sum_elem, output_elem) in state
+ .partial_sum
+ .iter_mut()
+ .zip(state.partial_output.iter())
+ {
+ *sum_elem += *output_elem;
+ }
+ state
+ },
+ )
+ .map(|state| state.partial_sum)
+ .reduce(
+ || vec![F::zero(); outp.len()],
+ |mut x, y| {
+ for (xi, yi) in x.iter_mut().zip(y.iter()) {
+ *xi += *yi;
+ }
+ x
+ },
+ );
+
+ outp.copy_from_slice(&res[..]);
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.serial_sum.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.serial_sum.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.serial_sum.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// Check that the input parameters of g.call() are well-formed.
+fn gadget_call_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ in_len: usize,
+) -> Result<(), FlpError> {
+ if in_len != gadget.arity() {
+ return Err(FlpError::Gadget(format!(
+ "unexpected number of inputs: got {}; want {}",
+ in_len,
+ gadget.arity()
+ )));
+ }
+
+ if in_len == 0 {
+ return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
+ }
+
+ Ok(())
+}
+
+// Check that the input parameters of g.call_poly() are well-formed.
+fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>(
+ gadget: &G,
+ outp: &[F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError>
+where
+ G: Gadget<F>,
+{
+ gadget_call_check(gadget, inp.len())?;
+
+ for i in 1..inp.len() {
+ if inp[i].len() != inp[0].len() {
+ return Err(FlpError::Gadget(
+ "gadget called on wire polynomials with different lengths".to_string(),
+ ));
+ }
+ }
+
+ let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
+ if outp.len() != expected {
+ return Err(FlpError::Gadget(format!(
+ "incorrect output length: got {}; want {}",
+ outp.len(),
+ expected
+ )));
+ }
+
+ Ok(())
+}
+
+#[inline]
+fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
+ gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::field::{random_vector, Field96 as TestField};
+ use crate::prng::Prng;
+
+ #[test]
+ fn test_mul() {
+ // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
+ // naive multiplication code path.
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+
+ // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
+ // FFT-based polynomial multiplication.
+ let num_calls = FFT_THRESHOLD;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_blind_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_parallel_sum() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let num_calls = 10;
+ let chunks = 23;
+
+ let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_parallel_sum_multithreaded() {
+ use std::iter;
+
+ for num_calls in [1, 10, 100] {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+ let chunks = 23;
+
+ let mut g =
+ ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+
+ // Test that the multithreaded version has the same output as the normal version.
+ let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
+ assert_eq!(g.arity(), g_serial.arity());
+ assert_eq!(g.degree(), g_serial.degree());
+ assert_eq!(g.calls(), g_serial.calls());
+
+ let arity = g.arity();
+ let degree = g.degree();
+
+ // Test that both gadgets evaluate to the same value when run on scalar inputs.
+ let inp: Vec<TestField> = random_vector(arity).unwrap();
+ let result = g.call(&inp).unwrap();
+ let result_serial = g_serial.call(&inp).unwrap();
+ assert_eq!(result, result_serial);
+
+ // Test that both gadgets evaluate to the same value when run on polynomial inputs.
+ let mut poly_outp =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut poly_outp_serial =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut prng: Prng<TestField, _> = Prng::new().unwrap();
+ let poly_inp: Vec<_> = iter::repeat_with(|| {
+ iter::repeat_with(|| prng.get())
+ .take(1 + num_calls)
+ .collect::<Vec<_>>()
+ })
+ .take(arity)
+ .collect();
+
+ g.call_poly(&mut poly_outp, &poly_inp).unwrap();
+ g_serial
+ .call_poly(&mut poly_outp_serial, &poly_inp)
+ .unwrap();
+ assert_eq!(poly_outp, poly_outp_serial);
+ }
+ }
+
+ // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
+ // to evaluating each of the inputs at the same point and applying g.call() on the results.
+ fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
+ let wire_poly_len = (1 + num_calls).next_power_of_two();
+ let mut prng = Prng::new().unwrap();
+ let mut inp = vec![F::zero(); g.arity()];
+ let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
+ let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];
+
+ let r = prng.get();
+ for i in 0..g.arity() {
+ for j in 0..wire_poly_len {
+ wire_polys[i][j] = prng.get();
+ }
+ inp[i] = poly_eval(&wire_polys[i], r);
+ }
+
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ let want = g.call(&inp).unwrap();
+ assert_eq!(got, want);
+
+ // Repeat the call to make sure that the gadget's memory is reset properly between calls.
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ assert_eq!(got, want);
+ }
+}
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs
new file mode 100644
index 0000000000..83b9752f69
--- /dev/null
+++ b/third_party/rust/prio/src/flp/types.rs
@@ -0,0 +1,1199 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of [`Type`](crate::flp::Type) implementations.
+
+use crate::field::{FieldElement, FieldElementExt};
+use crate::flp::gadgets::{BlindPolyEval, Mul, ParallelSumGadget, PolyEval};
+use crate::flp::{FlpError, Gadget, Type};
+use crate::polynomial::poly_range_check;
+use std::convert::TryInto;
+use std::marker::PhantomData;
+
+/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of
+/// the measurements (i.e., the total number of `1s`).
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Count<F> {
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Count<F> {
+ /// Return a new [`Count`] type instance.
+ pub fn new() -> Self {
+ Self {
+ range_checker: poly_range_check(0, 2),
+ }
+ }
+}
+
+impl<F: FieldElement> Default for Count<F> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<F: FieldElement> Type for Count<F> {
+ const ID: u32 = 0x00000000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let max = F::valid_integer_try_from(1)?;
+ if *value > max {
+ return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
+ }
+
+ Ok(vec![F::from(*value)])
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(Mul::new(1))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ Ok(g[0].call(&[input[0], input[0]])? - input[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ 1
+ }
+
+ fn proof_len(&self) -> usize {
+ 5
+ }
+
+ fn verifier_len(&self) -> usize {
+ 4
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 0
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of
+/// the measurements.
+///
+/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3].
+///
+/// [BBCG+19]: https://ia.cr/2019/188
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Sum<F: FieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Sum<F> {
+ /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Sum<F> {
+ const ID: u32 = 0x00000001;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the arithmetic average.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Average<F: FieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Average<F> {
+ /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Average<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = f64;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> {
+ // Compute the average from the aggregated sum.
+ let data = decode_result(data)?;
+ let data: u64 = data.try_into().map_err(|err| {
+ FlpError::Decode(format!("failed to convert {:?} to u64: {}", data, err,))
+ })?;
+ let result = (data as f64) / (num_measurements as f64);
+ Ok(result)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The histogram type. Each measurement is a non-negative integer and the aggregate is a histogram
+/// approximating the distribution of the measurements.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Histogram<F: FieldElement> {
+ buckets: Vec<F::Integer>,
+ range_checker: Vec<F>,
+}
+
+impl<F: FieldElement> Histogram<F> {
+ /// Return a new [`Histogram`] type with the given buckets.
+ pub fn new(buckets: Vec<F::Integer>) -> Result<Self, FlpError> {
+ if buckets.len() >= u32::MAX as usize {
+ return Err(FlpError::Encode(
+ "invalid buckets: number of buckets exceeds maximum permitted".to_string(),
+ ));
+ }
+
+ if !buckets.is_empty() {
+ for i in 0..buckets.len() - 1 {
+ if buckets[i + 1] <= buckets[i] {
+ return Err(FlpError::Encode(
+ "invalid buckets: out-of-order boundary".to_string(),
+ ));
+ }
+ }
+ }
+
+ Ok(Self {
+ buckets,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FieldElement> Type for Histogram<F> {
+ const ID: u32 = 0x00000002;
+ type Measurement = F::Integer;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let mut data = vec![F::zero(); self.buckets.len() + 1];
+
+ let bucket = match self.buckets.binary_search(measurement) {
+ Ok(i) => i, // on a bucket boundary
+ Err(i) => i, // smaller than the i-th bucket boundary
+ };
+
+ data[bucket] = F::one();
+ Ok(data)
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.buckets.len() + 1)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.to_vec(),
+ self.input_len(),
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ // Check that each element of `input` is a 0 or 1.
+ let range_check = call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])?;
+
+ // Check that the elements of `input` sum to 1.
+ let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?));
+ for val in input.iter() {
+ sum_check += *val;
+ }
+
+ // Take a random linear combination of both checks.
+ let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check;
+ Ok(out)
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ self.buckets.len() + 1
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.input_len()).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// A sequence of counters. This type uses a neat trick from [[BBCG+19], Corollary 4.9] to reduce
+/// the proof size to roughly the square root of the input size.
+///
+/// [BBCG+19]: https://eprint.iacr.org/2019/188
+#[derive(Debug, PartialEq, Eq)]
+pub struct CountVec<F, S> {
+ range_checker: Vec<F>,
+ len: usize,
+ chunk_len: usize,
+ gadget_calls: usize,
+ phantom: PhantomData<S>,
+}
+
+impl<F: FieldElement, S: ParallelSumGadget<F, BlindPolyEval<F>>> CountVec<F, S> {
+ /// Returns a new [`CountVec`] with the given length.
+ pub fn new(len: usize) -> Self {
+ // The optimal chunk length is the square root of the input length. If the input length is
+ // not a perfect square, then round down. If the result is 0, then let the chunk length be
+ // 1 so that the underlying gadget can still be called.
+ let chunk_len = std::cmp::max(1, (len as f64).sqrt() as usize);
+
+ let mut gadget_calls = len / chunk_len;
+ if len % chunk_len != 0 {
+ gadget_calls += 1;
+ }
+
+ Self {
+ range_checker: poly_range_check(0, 2),
+ len,
+ chunk_len,
+ gadget_calls,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FieldElement, S> Clone for CountVec<F, S> {
+ fn clone(&self) -> Self {
+ Self {
+ range_checker: self.range_checker.clone(),
+ len: self.len,
+ chunk_len: self.chunk_len,
+ gadget_calls: self.gadget_calls,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F, S> Type for CountVec<F, S>
+where
+ F: FieldElement,
+ S: ParallelSumGadget<F, BlindPolyEval<F>> + Eq + 'static,
+{
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<F::Integer>;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> {
+ if measurement.len() != self.len {
+ return Err(FlpError::Encode(format!(
+ "unexpected measurement length: got {}; want {}",
+ measurement.len(),
+ self.len
+ )));
+ }
+
+ let max = F::Integer::from(F::one());
+ for value in measurement {
+ if *value > max {
+ return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
+ }
+ }
+
+ Ok(measurement.iter().map(|value| F::from(*value)).collect())
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.len)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(S::new(
+ BlindPolyEval::new(self.range_checker.clone(), self.gadget_calls),
+ self.chunk_len,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ let s = F::from(F::valid_integer_try_from(num_shares)?).inv();
+ let mut r = joint_rand[0];
+ let mut outp = F::zero();
+ let mut padded_chunk = vec![F::zero(); 2 * self.chunk_len];
+ for chunk in input.chunks(self.chunk_len) {
+ let d = chunk.len();
+ for i in 0..self.chunk_len {
+ if i < d {
+ padded_chunk[2 * i] = chunk[i];
+ } else {
+ // If the chunk is smaller than the chunk length, then copy the last element of
+ // the chunk into the remaining slots.
+ padded_chunk[2 * i] = chunk[d - 1];
+ }
+ padded_chunk[2 * i + 1] = r * s;
+ r *= joint_rand[0];
+ }
+
+ outp += g[0].call(&padded_chunk)?;
+ }
+
+ Ok(outp)
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ self.len
+ }
+
+ fn proof_len(&self) -> usize {
+ (self.chunk_len * 2) + 3 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
+ }
+
+ fn verifier_len(&self) -> usize {
+ 2 + self.chunk_len * 2
+ }
+
+ fn output_len(&self) -> usize {
+ self.len
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ self.chunk_len * 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// Compute a random linear combination of the result of calls of `g` on each element of `input`.
+///
+/// # Arguments
+///
+/// * `g` - The gadget to be applied elementwise
+/// * `input` - The vector on whose elements to apply `g`
+/// * `rnd` - The randomness used for the linear combination
+pub(crate) fn call_gadget_on_vec_entries<F: FieldElement>(
+ g: &mut Box<dyn Gadget<F>>,
+ input: &[F],
+ rnd: F,
+) -> Result<F, FlpError> {
+ let mut range_check = F::zero();
+ let mut r = rnd;
+ for chunk in input.chunks(1) {
+ range_check += r * g.call(chunk)?;
+ r *= rnd;
+ }
+ Ok(range_check)
+}
+
+/// Given a vector `data` of field elements which should contain exactly one entry, return the
+/// integer representation of that entry.
+pub(crate) fn decode_result<F: FieldElement>(data: &[F]) -> Result<F::Integer, FlpError> {
+ if data.len() != 1 {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(F::Integer::from(data[0]))
+}
+
+/// Given a vector `data` of field elements, return a vector containing the corresponding integer
+/// representations, if the number of entries matches `expected_len`.
+pub(crate) fn decode_result_vec<F: FieldElement>(
+ data: &[F],
+ expected_len: usize,
+) -> Result<Vec<F::Integer>, FlpError> {
+ if data.len() != expected_len {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, split_vector, Field64 as TestField};
+ use crate::flp::gadgets::ParallelSum;
+ #[cfg(feature = "multithreaded")]
+ use crate::flp::gadgets::ParallelSumMultithreaded;
+
+ // Number of shares to split input and proofs into in `flp_test`.
+ const NUM_SHARES: usize = 3;
+
+ struct ValidityTestCase<F> {
+ expect_valid: bool,
+ expected_output: Option<Vec<F>>,
+ }
+
+ #[test]
+ fn test_count() {
+ let count: Count<TestField> = Count::new();
+ let zero = TestField::zero();
+ let one = TestField::one();
+
+ // Round trip
+ assert_eq!(
+ count
+ .decode_result(
+ &count
+ .truncate(count.encode_measurement(&1).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 1,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&1).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &count,
+ &[TestField::from(1337)],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ // Try running the validity circuit on an input that's too short.
+ count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err();
+ count
+ .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1)
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_sum() {
+ let sum = Sum::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Round trip
+ assert_eq!(
+ sum.decode_result(
+ &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(),
+ 1
+ )
+ .unwrap(),
+ 27,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &sum,
+ &sum.encode_measurement(&1337).unwrap(),
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(1337)]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(0).unwrap(),
+ &[],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(2).unwrap(),
+ &[one, zero],
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(9).unwrap(),
+ &[one, zero, one, one, zero, one, one, one, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(237)]),
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &Sum::new(3).unwrap(),
+ &[one, nine, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(5).unwrap(),
+ &[zero, zero, zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_average() {
+ let average = Average::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let ten = TestField::from(10);
+
+ // Testing that average correctly quotients the sum of the measurements
+ // by the number of measurements.
+ assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0);
+ assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0);
+ assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5);
+ assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25);
+ assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25);
+
+ // round trip of 12 with `num_measurements`=1
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 12.0
+ );
+
+ // round trip of 12 with `num_measurements`=24
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 24
+ )
+ .unwrap(),
+ 0.5
+ );
+ }
+
+ #[test]
+ fn test_histogram() {
+ let hist = Histogram::new(vec![10, 20]).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ assert_eq!(&hist.encode_measurement(&7).unwrap(), &[one, zero, zero]);
+ assert_eq!(&hist.encode_measurement(&10).unwrap(), &[one, zero, zero]);
+ assert_eq!(&hist.encode_measurement(&17).unwrap(), &[zero, one, zero]);
+ assert_eq!(&hist.encode_measurement(&20).unwrap(), &[zero, one, zero]);
+ assert_eq!(&hist.encode_measurement(&27).unwrap(), &[zero, zero, one]);
+
+ // Round trip
+ assert_eq!(
+ hist.decode_result(
+ &hist
+ .truncate(hist.encode_measurement(&27).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ [0, 0, 1]
+ );
+
+ // Invalid bucket boundaries.
+ Histogram::<TestField>::new(vec![10, 0]).unwrap_err();
+ Histogram::<TestField>::new(vec![10, 10]).unwrap_err();
+
+ // Test valid inputs.
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one, zero, zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&17).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, one, zero]),
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&1337).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, zero, one]),
+ },
+ )
+ .unwrap();
+
+ // Test invalid inputs.
+ flp_validity_test(
+ &hist,
+ &[zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[one, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, zero, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ fn test_count_vec<F, S>(f: F)
+ where
+ F: Fn(usize) -> CountVec<TestField, S>,
+ S: 'static + ParallelSumGadget<TestField, BlindPolyEval<TestField>> + Eq,
+ {
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Test on valid inputs.
+ for len in 0..10 {
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &count_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ },
+ )
+ .unwrap();
+ }
+
+ let len = 100;
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &count_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ },
+ )
+ .unwrap();
+
+ // Test on invalid inputs.
+ for len in 1..10 {
+ let count_vec = f(len);
+ flp_validity_test(
+ &count_vec,
+ &vec![nine; len],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ },
+ )
+ .unwrap();
+ }
+
+ // Round trip
+ let want = vec![1; len];
+ assert_eq!(
+ count_vec
+ .decode_result(
+ &count_vec
+ .truncate(count_vec.encode_measurement(&want).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ want
+ );
+ }
+
+ #[test]
+ fn test_count_vec_serial() {
+ test_count_vec(CountVec::<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>>::new)
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_count_vec_parallel() {
+ test_count_vec(CountVec::<TestField, ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>>::new)
+ }
+
+ #[test]
+ fn count_vec_serial_long() {
+ let typ: CountVec<TestField, ParallelSum<TestField, BlindPolyEval<TestField>>> =
+ CountVec::new(1000);
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn count_vec_parallel_long() {
+ let typ: CountVec<
+ TestField,
+ ParallelSumMultithreaded<TestField, BlindPolyEval<TestField>>,
+ > = CountVec::new(1000);
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ fn flp_validity_test<T: Type>(
+ typ: &T,
+ input: &[T::Field],
+ t: &ValidityTestCase<T::Field>,
+ ) -> Result<(), FlpError> {
+ let mut gadgets = typ.gadget();
+
+ if input.len() != typ.input_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ typ.input_len()
+ )));
+ }
+
+ if typ.query_rand_len() != gadgets.len() {
+ return Err(FlpError::Test(format!(
+ "query rand length: got {}; want {}",
+ typ.query_rand_len(),
+ gadgets.len()
+ )));
+ }
+
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+
+ // Run the validity circuit.
+ let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?;
+ if v != T::Field::zero() && t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected valid input: valid() returned {}",
+ v
+ )));
+ }
+ if v == T::Field::zero() && !t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected invalid input: valid() returned {}",
+ v
+ )));
+ }
+
+ // Generate the proof.
+ let proof = typ.prove(input, &prove_rand, &joint_rand)?;
+ if proof.len() != typ.proof_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected proof length: got {}; want {}",
+ proof.len(),
+ typ.proof_len()
+ )));
+ }
+
+ // Query the proof.
+ let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?;
+ if verifier.len() != typ.verifier_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected verifier length: got {}; want {}",
+ verifier.len(),
+ typ.verifier_len()
+ )));
+ }
+
+ // Decide if the input is valid.
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Run distributed FLP.
+ let input_shares: Vec<Vec<T::Field>> = split_vector(input, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let verifier: Vec<T::Field> = (0..NUM_SHARES)
+ .map(|i| {
+ typ.query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ NUM_SHARES,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "distributed decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Try verifying various proof mutants.
+ for i in 0..proof.len() {
+ let mut mutated_proof = proof.clone();
+ mutated_proof[i] += T::Field::one();
+ let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?;
+ if typ.decide(&verifier)? {
+ return Err(FlpError::Test(format!(
+ "decision for proof mutant {} is {}; want {}",
+ i, true, false,
+ )));
+ }
+ }
+
+ // Try verifying a proof that is too short.
+ let mut mutated_proof = proof.clone();
+ mutated_proof.truncate(gadgets[0].arity() - 1);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on short proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ // Try verifying a proof that is too long.
+ let mut mutated_proof = proof;
+ mutated_proof.extend_from_slice(&[T::Field::one(); 17]);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on long proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ if let Some(ref want) = t.expected_output {
+ let got = typ.truncate(input.to_vec())?;
+
+ if got.len() != typ.output_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected output length: got {}; want {}",
+ got.len(),
+ typ.output_len()
+ )));
+ }
+
+ if &got != want {
+ return Err(FlpError::Test(format!(
+ "unexpected output: got {:?}; want {:?}",
+ got, want
+ )));
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/prio/src/fp.rs b/third_party/rust/prio/src/fp.rs
new file mode 100644
index 0000000000..d828fb7daf
--- /dev/null
+++ b/third_party/rust/prio/src/fp.rs
@@ -0,0 +1,561 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Finite field arithmetic for any field GF(p) for which p < 2^128.
+
+#[cfg(test)]
+use rand::{prelude::*, Rng};
+
+/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots
+/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This
+/// is the largest input size we would ever need for the cryptographic applications in this crate.
+pub(crate) const MAX_ROOTS: usize = 20;
+
+/// This structure represents the parameters of a finite field GF(p) for which p < 2^128.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) struct FieldParameters {
+ /// The prime modulus `p`.
+ pub p: u128,
+ /// `mu = -p^(-1) mod 2^64`.
+ pub mu: u64,
+ /// `r2 = (2^128)^2 mod p`.
+ pub r2: u128,
+ /// The `2^num_roots`-th -principal root of unity. This element is used to generate the
+ /// elements of `roots`.
+ pub g: u128,
+ /// The number of principal roots of unity in `roots`.
+ pub num_roots: usize,
+ /// Equal to `2^b - 1`, where `b` is the length of `p` in bits.
+ pub bit_mask: u128,
+ /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the
+ /// multiplicative group. `roots[0]` is equal to one by definition.
+ pub roots: [u128; MAX_ROOTS + 1],
+}
+
+impl FieldParameters {
+ /// Addition. The result will be in [0, p), so long as both x and y are as well.
+ pub fn add(&self, x: u128, y: u128) -> u128 {
+ // 0,x
+ // + 0,y
+ // =====
+ // c,z
+ let (z, carry) = x.overflowing_add(y);
+ // c, z
+ // - 0, p
+ // ========
+ // b1,s1,s0
+ let (s0, b0) = z.overflowing_sub(self.p);
+ let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128);
+ // if b1 == 1: return z
+ // else: return s0
+ let m = 0u128.wrapping_sub(b1 as u128);
+ (z & m) | (s0 & !m)
+ }
+
+ /// Subtraction. The result will be in [0, p), so long as both x and y are as well.
+ pub fn sub(&self, x: u128, y: u128) -> u128 {
+ // 0, x
+ // - 0, y
+ // ========
+ // b1,z1,z0
+ let (z0, b0) = x.overflowing_sub(y);
+ let (_z1, b1) = 0u128.overflowing_sub(b0 as u128);
+ let m = 0u128.wrapping_sub(b1 as u128);
+ // z1,z0
+ // + 0, p
+ // ========
+ // s1,s0
+ z0.wrapping_add(m & self.p)
+ // if b1 == 1: return s0
+ // else: return z0
+ }
+
+ /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm
+ /// described
+ /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf).
+ /// The result will be in [0, p).
+ ///
+ /// # Example usage
+ /// ```text
+ /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46);
+ /// ```
+ pub fn mul(&self, x: u128, y: u128) -> u128 {
+ let x = [lo64(x), hi64(x)];
+ let y = [lo64(y), hi64(y)];
+ let p = [lo64(self.p), hi64(self.p)];
+ let mut zz = [0; 4];
+
+ // Integer multiplication
+ // z = x * y
+
+ // x1,x0
+ // * y1,y0
+ // ===========
+ // z3,z2,z1,z0
+ let mut result = x[0] * y[0];
+ let mut carry = hi64(result);
+ zz[0] = lo64(result);
+ result = x[0] * y[1];
+ let mut hi = hi64(result);
+ let mut lo = lo64(result);
+ result = lo + carry;
+ zz[1] = lo64(result);
+ let mut cc = hi64(result);
+ result = hi + cc;
+ zz[2] = lo64(result);
+
+ result = x[1] * y[0];
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = x[1] * y[1];
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[2] + lo;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ zz[3] = lo64(result);
+
+ // Montgomery Reduction
+ // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64.
+
+ // z3,z2,z1,z0
+ // + p1,p0
+ // * w = mu*z0
+ // ===========
+ // z3,z2,z1, 0
+ let w = self.mu.wrapping_mul(zz[0] as u64);
+ result = p[0] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[0] + lo;
+ zz[0] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = p[1] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = zz[2] + hi + cc;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = zz[3] + cc;
+ zz[3] = lo64(result);
+
+ // z3,z2,z1
+ // + p1,p0
+ // * w = mu*z1
+ // ===========
+ // z3,z2, 0
+ let w = self.mu.wrapping_mul(zz[1] as u64);
+ result = p[0] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = p[1] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[2] + lo;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = zz[3] + hi + cc;
+ zz[3] = lo64(result);
+ cc = hi64(result);
+
+ // z = (z3,z2)
+ let prod = zz[2] | (zz[3] << 64);
+
+ // Final subtraction
+ // If z >= p, then z = z - p
+
+ // 0, z
+ // - 0, p
+ // ========
+ // b1,s1,s0
+ let (s0, b0) = prod.overflowing_sub(self.p);
+ let (_s1, b1) = (cc as u128).overflowing_sub(b0 as u128);
+ // if b1 == 1: return z
+ // else: return s0
+ let mask = 0u128.wrapping_sub(b1 as u128);
+ (prod & mask) | (s0 & !mask)
+ }
+
+ /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the
+ /// runtime of this algorithm is linear in the bit length of `exp`.
+ pub fn pow(&self, x: u128, exp: u128) -> u128 {
+ let mut t = self.montgomery(1);
+ for i in (0..128 - exp.leading_zeros()).rev() {
+ t = self.mul(t, t);
+ if (exp >> i) & 1 != 0 {
+ t = self.mul(t, x);
+ }
+ }
+ t
+ }
+
+ /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of
+ /// this algorithm is linear in the bit length of `p`.
+ pub fn inv(&self, x: u128) -> u128 {
+ self.pow(x, self.p - 2)
+ }
+
+ /// Negation, i.e., `-x (mod p)` where `p` is the modulus.
+ pub fn neg(&self, x: u128) -> u128 {
+ self.sub(0, x)
+ }
+
+ /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery
+ /// domain in order to carry out field arithmetic. The result will be in [0, p).
+ ///
+ /// # Example usage
+ /// ```text
+ /// let integer = 1; // Standard integer representation
+ /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain
+ /// assert_eq!(elem, 2564090464);
+ /// ```
+ pub fn montgomery(&self, x: u128) -> u128 {
+ modp(self.mul(x, self.r2), self.p)
+ }
+
+ /// Returns a random field element mapped.
+ #[cfg(test)]
+ pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 {
+ let uniform = rand::distributions::Uniform::from(0..self.p);
+ self.montgomery(uniform.sample(rng))
+ }
+
+ /// Maps a field element to its representation as an integer. The result will be in [0, p).
+ ///
+ /// #Example usage
+ /// ```text
+ /// let elem = 2564090464; // Internal representation in the Montgomery domain
+ /// let integer = fp.residue(elem); // Standard integer representation
+ /// assert_eq!(integer, 1);
+ /// ```
+ pub fn residue(&self, x: u128) -> u128 {
+ modp(self.mul(x, 1), self.p)
+ }
+
+ #[cfg(test)]
+ pub fn check(&self, p: u128, g: u128, order: u128) {
+ use modinverse::modinverse;
+ use num_bigint::{BigInt, ToBigInt};
+ use std::cmp::max;
+
+ assert_eq!(self.p, p, "p mismatch");
+
+ let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) {
+ Some(mu) => mu as u64,
+ None => panic!("inverse of -p (mod 2^64) is undefined"),
+ };
+ assert_eq!(self.mu, mu, "mu mismatch");
+
+ let big_p = &p.to_bigint().unwrap();
+ let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p);
+ let big_r2: &BigInt = &(&(big_r * big_r) % big_p);
+ let mut it = big_r2.iter_u64_digits();
+ let mut r2 = 0;
+ r2 |= it.next().unwrap() as u128;
+ if let Some(x) = it.next() {
+ r2 |= (x as u128) << 64;
+ }
+ assert_eq!(self.r2, r2, "r2 mismatch");
+
+ assert_eq!(self.g, self.montgomery(g), "g mismatch");
+ assert_eq!(
+ self.residue(self.pow(self.g, order)),
+ 1,
+ "g order incorrect"
+ );
+
+ let num_roots = log2(order) as usize;
+ assert_eq!(order, 1 << num_roots, "order not a power of 2");
+ assert_eq!(self.num_roots, num_roots, "num_roots mismatch");
+
+ let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1];
+ roots[num_roots] = self.montgomery(g);
+ for i in (0..num_roots).rev() {
+ roots[i] = self.mul(roots[i + 1], roots[i + 1]);
+ }
+ assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch");
+ assert_eq!(self.residue(self.roots[0]), 1, "first root is not one");
+
+ let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1);
+ assert_eq!(
+ self.bit_mask.to_bigint().unwrap(),
+ bit_mask,
+ "bit_mask mismatch"
+ );
+ }
+}
+
+fn lo64(x: u128) -> u128 {
+ x & ((1 << 64) - 1)
+}
+
+fn hi64(x: u128) -> u128 {
+ x >> 64
+}
+
+fn modp(x: u128, p: u128) -> u128 {
+ let (z, carry) = x.overflowing_sub(p);
+ let m = 0u128.wrapping_sub(carry as u128);
+ z.wrapping_add(m & p)
+}
+
+pub(crate) const FP32: FieldParameters = FieldParameters {
+ p: 4293918721, // 32-bit prime
+ mu: 17302828673139736575,
+ r2: 1676699750,
+ g: 1074114499,
+ num_roots: 20,
+ bit_mask: 4294967295,
+ roots: [
+ 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825,
+ 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415,
+ 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499,
+ ],
+};
+
+pub(crate) const FP64: FieldParameters = FieldParameters {
+ p: 18446744069414584321, // 64-bit prime
+ mu: 18446744069414584319,
+ r2: 4294967295,
+ g: 959634606461954525,
+ num_roots: 32,
+ bit_mask: 18446744073709551615,
+ roots: [
+ 18446744065119617025,
+ 4294967296,
+ 18446462594437939201,
+ 72057594037927936,
+ 1152921504338411520,
+ 16384,
+ 18446743519658770561,
+ 18446735273187346433,
+ 6519596376689022014,
+ 9996039020351967275,
+ 15452408553935940313,
+ 15855629130643256449,
+ 8619522106083987867,
+ 13036116919365988132,
+ 1033106119984023956,
+ 16593078884869787648,
+ 16980581328500004402,
+ 12245796497946355434,
+ 8709441440702798460,
+ 8611358103550827629,
+ 8120528636261052110,
+ ],
+};
+
+pub(crate) const FP96: FieldParameters = FieldParameters {
+ p: 79228148845226978974766202881, // 96-bit prime
+ mu: 18446744073709551615,
+ r2: 69162923446439011319006025217,
+ g: 11329412859948499305522312170,
+ num_roots: 64,
+ bit_mask: 79228162514264337593543950335,
+ roots: [
+ 10128756682736510015896859,
+ 79218020088544242464750306022,
+ 9188608122889034248261485869,
+ 10170869429050723924726258983,
+ 36379376833245035199462139324,
+ 20898601228930800484072244511,
+ 2845758484723985721473442509,
+ 71302585629145191158180162028,
+ 76552499132904394167108068662,
+ 48651998692455360626769616967,
+ 36570983454832589044179852640,
+ 72716740645782532591407744342,
+ 73296872548531908678227377531,
+ 14831293153408122430659535205,
+ 61540280632476003580389854060,
+ 42256269782069635955059793151,
+ 51673352890110285959979141934,
+ 43102967204983216507957944322,
+ 3990455111079735553382399289,
+ 68042997008257313116433801954,
+ 44344622755749285146379045633,
+ ],
+};
+
+pub(crate) const FP128: FieldParameters = FieldParameters {
+ p: 340282366920938462946865773367900766209, // 128-bit prime
+ mu: 18446744073709551615,
+ r2: 403909908237944342183153,
+ g: 107630958476043550189608038630704257141,
+ num_roots: 66,
+ bit_mask: 340282366920938463463374607431768211455,
+ roots: [
+ 516508834063867445247,
+ 340282366920938462430356939304033320962,
+ 129526470195413442198896969089616959958,
+ 169031622068548287099117778531474117974,
+ 81612939378432101163303892927894236156,
+ 122401220764524715189382260548353967708,
+ 199453575871863981432000940507837456190,
+ 272368408887745135168960576051472383806,
+ 24863773656265022616993900367764287617,
+ 257882853788779266319541142124730662203,
+ 323732363244658673145040701829006542956,
+ 57532865270871759635014308631881743007,
+ 149571414409418047452773959687184934208,
+ 177018931070866797456844925926211239962,
+ 268896136799800963964749917185333891349,
+ 244556960591856046954834420512544511831,
+ 118945432085812380213390062516065622346,
+ 202007153998709986841225284843501908420,
+ 332677126194796691532164818746739771387,
+ 258279638927684931537542082169183965856,
+ 148221243758794364405224645520862378432,
+ ],
+};
+
+// Compute the ceiling of the base-2 logarithm of `x`.
+pub(crate) fn log2(x: u128) -> u128 {
+ let y = (127 - x.leading_zeros()) as u128;
+ y + ((x > 1 << y) as u128)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use num_bigint::ToBigInt;
+
+ #[test]
+ fn test_log2() {
+ assert_eq!(log2(1), 0);
+ assert_eq!(log2(2), 1);
+ assert_eq!(log2(3), 2);
+ assert_eq!(log2(4), 2);
+ assert_eq!(log2(15), 4);
+ assert_eq!(log2(16), 4);
+ assert_eq!(log2(30), 5);
+ assert_eq!(log2(32), 5);
+ assert_eq!(log2(1 << 127), 127);
+ assert_eq!(log2((1 << 127) + 13), 128);
+ }
+
+ struct TestFieldParametersData {
+ fp: FieldParameters, // The paramters being tested
+ expected_p: u128, // Expected fp.p
+ expected_g: u128, // Expected fp.residue(fp.g)
+ expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1
+ }
+
+ #[test]
+ fn test_fp() {
+ let test_fps = vec![
+ TestFieldParametersData {
+ fp: FP32,
+ expected_p: 4293918721,
+ expected_g: 3925978153,
+ expected_order: 1 << 20,
+ },
+ TestFieldParametersData {
+ fp: FP64,
+ expected_p: 18446744069414584321,
+ expected_g: 1753635133440165772,
+ expected_order: 1 << 32,
+ },
+ TestFieldParametersData {
+ fp: FP96,
+ expected_p: 79228148845226978974766202881,
+ expected_g: 34233996298771126927060021012,
+ expected_order: 1 << 64,
+ },
+ TestFieldParametersData {
+ fp: FP128,
+ expected_p: 340282366920938462946865773367900766209,
+ expected_g: 145091266659756586618791329697897684742,
+ expected_order: 1 << 66,
+ },
+ ];
+
+ for t in test_fps.into_iter() {
+ // Check that the field parameters have been constructed properly.
+ t.fp.check(t.expected_p, t.expected_g, t.expected_order);
+
+ // Check that the generator has the correct order.
+ assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1);
+ assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1);
+
+ // Test arithmetic using the field parameters.
+ arithmetic_test(&t.fp);
+ }
+ }
+
+ fn arithmetic_test(fp: &FieldParameters) {
+ let mut rng = rand::thread_rng();
+ let big_p = &fp.p.to_bigint().unwrap();
+
+ for _ in 0..100 {
+ let x = fp.rand_elem(&mut rng);
+ let y = fp.rand_elem(&mut rng);
+ let big_x = &fp.residue(x).to_bigint().unwrap();
+ let big_y = &fp.residue(y).to_bigint().unwrap();
+
+ // Test addition.
+ let got = fp.add(x, y);
+ let want = (big_x + big_y) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test subtraction.
+ let got = fp.sub(x, y);
+ let want = if big_x >= big_y {
+ big_x - big_y
+ } else {
+ big_p - big_y + big_x
+ };
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test multiplication.
+ let got = fp.mul(x, y);
+ let want = (big_x * big_y) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test inversion.
+ let got = fp.inv(x);
+ let want = big_x.modpow(&(big_p - 2u128), big_p);
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+ assert_eq!(fp.residue(fp.mul(got, x)), 1);
+
+ // Test negation.
+ let got = fp.neg(x);
+ let want = (big_p - big_x) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+ assert_eq!(fp.residue(fp.add(got, x)), 0);
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/lib.rs b/third_party/rust/prio/src/lib.rs
new file mode 100644
index 0000000000..f0e00de3c0
--- /dev/null
+++ b/third_party/rust/prio/src/lib.rs
@@ -0,0 +1,33 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+#![warn(missing_docs)]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+
+//! Libprio-rs
+//!
+//! Implementation of the [Prio](https://crypto.stanford.edu/prio/) private data aggregation
+//! protocol. For now we only support 0 / 1 vectors.
+
+pub mod benchmarked;
+#[cfg(feature = "prio2")]
+pub mod client;
+#[cfg(feature = "prio2")]
+pub mod encrypt;
+#[cfg(feature = "prio2")]
+pub mod server;
+
+pub mod codec;
+mod fft;
+pub mod field;
+pub mod flp;
+mod fp;
+mod polynomial;
+mod prng;
+// Module test_vector depends on crate `rand` so we make it an optional feature
+// to spare most clients the extra dependency.
+#[cfg(all(any(feature = "test-util", test), feature = "prio2"))]
+pub mod test_vector;
+#[cfg(feature = "prio2")]
+pub mod util;
+pub mod vdaf;
diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs
new file mode 100644
index 0000000000..7c38341e36
--- /dev/null
+++ b/third_party/rust/prio/src/polynomial.rs
@@ -0,0 +1,384 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Functions for polynomial interpolation and evaluation
+
+use crate::field::FieldElement;
+
+use std::convert::TryFrom;
+
+/// Temporary memory used for FFT
+#[derive(Clone, Debug)]
+pub struct PolyFFTTempMemory<F> {
+ fft_tmp: Vec<F>,
+ fft_y_sub: Vec<F>,
+ fft_roots_sub: Vec<F>,
+}
+
+impl<F: FieldElement> PolyFFTTempMemory<F> {
+ fn new(length: usize) -> Self {
+ PolyFFTTempMemory {
+ fft_tmp: vec![F::zero(); length],
+ fft_y_sub: vec![F::zero(); length],
+ fft_roots_sub: vec![F::zero(); length],
+ }
+ }
+}
+
+/// Auxiliary memory for polynomial interpolation and evaluation
+#[derive(Clone, Debug)]
+pub struct PolyAuxMemory<F> {
+ pub roots_2n: Vec<F>,
+ pub roots_2n_inverted: Vec<F>,
+ pub roots_n: Vec<F>,
+ pub roots_n_inverted: Vec<F>,
+ pub coeffs: Vec<F>,
+ pub fft_memory: PolyFFTTempMemory<F>,
+}
+
+impl<F: FieldElement> PolyAuxMemory<F> {
+ pub fn new(n: usize) -> Self {
+ PolyAuxMemory {
+ roots_2n: fft_get_roots(2 * n, false),
+ roots_2n_inverted: fft_get_roots(2 * n, true),
+ roots_n: fft_get_roots(n, false),
+ roots_n_inverted: fft_get_roots(n, true),
+ coeffs: vec![F::zero(); 2 * n],
+ fft_memory: PolyFFTTempMemory::new(2 * n),
+ }
+ }
+}
+
+fn fft_recurse<F: FieldElement>(
+ out: &mut [F],
+ n: usize,
+ roots: &[F],
+ ys: &[F],
+ tmp: &mut [F],
+ y_sub: &mut [F],
+ roots_sub: &mut [F],
+) {
+ if n == 1 {
+ out[0] = ys[0];
+ return;
+ }
+
+ let half_n = n / 2;
+
+ let (tmp_first, tmp_second) = tmp.split_at_mut(half_n);
+ let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n);
+ let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n);
+
+ // Recurse on the first half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] + ys[i + half_n];
+ roots_sub_first[i] = roots[2 * i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i] = tmp_first[i];
+ }
+
+ // Recurse on the second half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] - ys[i + half_n];
+ y_sub_first[i] *= roots[i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i + 1] = tmp[i];
+ }
+}
+
+/// Calculate `count` number of roots of unity of order `count`
+fn fft_get_roots<F: FieldElement>(count: usize, invert: bool) -> Vec<F> {
+ let mut roots = vec![F::zero(); count];
+ let mut gen = F::generator();
+ if invert {
+ gen = gen.inv();
+ }
+
+ roots[0] = F::one();
+ let step_size = F::generator_order() / F::Integer::try_from(count).unwrap();
+ // generator for subgroup of order count
+ gen = gen.pow(step_size);
+
+ roots[1] = gen;
+
+ for i in 2..count {
+ roots[i] = gen * roots[i - 1];
+ }
+
+ roots
+}
+
+fn fft_interpolate_raw<F: FieldElement>(
+ out: &mut [F],
+ ys: &[F],
+ n_points: usize,
+ roots: &[F],
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_recurse(
+ out,
+ n_points,
+ roots,
+ ys,
+ &mut mem.fft_tmp,
+ &mut mem.fft_y_sub,
+ &mut mem.fft_roots_sub,
+ );
+ if invert {
+ let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv();
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..n_points {
+ out[i] *= n_inverse;
+ }
+ }
+}
+
+pub fn poly_fft<F: FieldElement>(
+ points_out: &mut [F],
+ points_in: &[F],
+ scaled_roots: &[F],
+ n_points: usize,
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem)
+}
+
+// Evaluate a polynomial using Horner's method.
+pub fn poly_eval<F: FieldElement>(poly: &[F], eval_at: F) -> F {
+ if poly.is_empty() {
+ return F::zero();
+ }
+
+ let mut result = poly[poly.len() - 1];
+ for i in (0..poly.len() - 1).rev() {
+ result *= eval_at;
+ result += poly[i];
+ }
+
+ result
+}
+
+// Returns the degree of polynomial `p`.
+pub fn poly_deg<F: FieldElement>(p: &[F]) -> usize {
+ let mut d = p.len();
+ while d > 0 && p[d - 1] == F::zero() {
+ d -= 1;
+ }
+ d.saturating_sub(1)
+}
+
+// Multiplies polynomials `p` and `q` and returns the result.
+pub fn poly_mul<F: FieldElement>(p: &[F], q: &[F]) -> Vec<F> {
+ let p_size = poly_deg(p) + 1;
+ let q_size = poly_deg(q) + 1;
+ let mut out = vec![F::zero(); p_size + q_size];
+ for i in 0..p_size {
+ for j in 0..q_size {
+ out[i + j] += p[i] * q[j];
+ }
+ }
+ out.truncate(poly_deg(&out) + 1);
+ out
+}
+
+#[cfg(feature = "prio2")]
+pub fn poly_interpret_eval<F: FieldElement>(
+ points: &[F],
+ roots: &[F],
+ eval_at: F,
+ tmp_coeffs: &mut [F],
+ fft_memory: &mut PolyFFTTempMemory<F>,
+) -> F {
+ poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory);
+ poly_eval(&tmp_coeffs[..points.len()], eval_at)
+}
+
+// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise,
+// the output is not `0`.
+pub(crate) fn poly_range_check<F: FieldElement>(start: usize, end: usize) -> Vec<F> {
+ let mut p = vec![F::one()];
+ let mut q = [F::zero(), F::one()];
+ for i in start..end {
+ q[0] = -F::from(F::Integer::try_from(i).unwrap());
+ p = poly_mul(&p, &q);
+ }
+ p
+}
+
+#[test]
+fn test_roots() {
+ use crate::field::Field32;
+
+ let count = 128;
+ let roots = fft_get_roots::<Field32>(count, false);
+ let roots_inv = fft_get_roots::<Field32>(count, true);
+
+ for i in 0..count {
+ assert_eq!(roots[i] * roots_inv[i], 1);
+ assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1);
+ assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1);
+ }
+}
+
+#[test]
+fn test_eval() {
+ use crate::field::Field32;
+
+ let mut poly = vec![Field32::from(0); 4];
+ poly[0] = 2.into();
+ poly[1] = 1.into();
+ poly[2] = 5.into();
+ // 5*3^2 + 3 + 2 = 50
+ assert_eq!(poly_eval(&poly[..3], 3.into()), 50);
+ poly[3] = 4.into();
+ // 4*3^3 + 5*3^2 + 3 + 2 = 158
+ assert_eq!(poly_eval(&poly[..4], 3.into()), 158);
+}
+
+#[test]
+fn test_poly_deg() {
+ use crate::field::Field32;
+
+ let zero = Field32::zero();
+ let one = Field32::root(0).unwrap();
+ assert_eq!(poly_deg(&[zero]), 0);
+ assert_eq!(poly_deg(&[one]), 0);
+ assert_eq!(poly_deg(&[zero, one]), 1);
+ assert_eq!(poly_deg(&[zero, zero, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one, one]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3);
+}
+
+#[test]
+fn test_poly_mul() {
+ use crate::field::Field64;
+
+ let p = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ ];
+
+ let q = [
+ Field64::one(),
+ Field64::zero(),
+ Field64::from(u64::try_from(5).unwrap()),
+ ];
+
+ let want = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ Field64::from(u64::try_from(10).unwrap()),
+ Field64::from(u64::try_from(15).unwrap()),
+ ];
+
+ let got = poly_mul(&p, &q);
+ assert_eq!(&got, &want);
+}
+
+#[test]
+fn test_poly_range_check() {
+ use crate::field::Field64;
+
+ let start = 74;
+ let end = 112;
+ let p = poly_range_check(start, end);
+
+ // Check each number in the range.
+ for i in start..end {
+ let x = Field64::from(i as u64);
+ let y = poly_eval(&p, x);
+ assert_eq!(y, Field64::zero(), "range check failed for {}", i);
+ }
+
+ // Check the number below the range.
+ let x = Field64::from((start - 1) as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+
+ // Check a number above the range.
+ let x = Field64::from(end as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+}
+
+#[test]
+fn test_fft() {
+ use crate::field::Field32;
+
+ use rand::prelude::*;
+ use std::convert::TryFrom;
+
+ let count = 128;
+ let mut mem = PolyAuxMemory::new(count / 2);
+
+ let mut poly = vec![Field32::from(0); count];
+ let mut points2 = vec![Field32::from(0); count];
+
+ let points = (0..count)
+ .into_iter()
+ .map(|_| Field32::from(random::<u32>()))
+ .collect::<Vec<Field32>>();
+
+ // From points to coeffs and back
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+ poly_fft(
+ &mut points2,
+ &poly,
+ &mem.roots_2n_inverted,
+ count,
+ true,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(points, points2);
+
+ // interpolation
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..count {
+ let mut should_be = Field32::from(0);
+ for j in 0..count {
+ should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be;
+ }
+ assert_eq!(should_be, poly[i]);
+ }
+}
diff --git a/third_party/rust/prio/src/prng.rs b/third_party/rust/prio/src/prng.rs
new file mode 100644
index 0000000000..764cd7b025
--- /dev/null
+++ b/third_party/rust/prio/src/prng.rs
@@ -0,0 +1,208 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Tool for generating pseudorandom field elements.
+//!
+//! NOTE: The public API for this module is a work in progress.
+
+use crate::field::{FieldElement, FieldError};
+use crate::vdaf::prg::SeedStream;
+#[cfg(feature = "crypto-dependencies")]
+use crate::vdaf::prg::SeedStreamAes128;
+#[cfg(feature = "crypto-dependencies")]
+use getrandom::getrandom;
+
+use std::marker::PhantomData;
+
+const BUFFER_SIZE_IN_ELEMENTS: usize = 128;
+
+/// Errors propagated by methods in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum PrngError {
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+}
+
+/// This type implements an iterator that generates a pseudorandom sequence of field elements. The
+/// sequence is derived from the key stream of AES-128 in CTR mode with a random IV.
+#[derive(Debug)]
+pub(crate) struct Prng<F, S> {
+ phantom: PhantomData<F>,
+ seed_stream: S,
+ buffer: Vec<u8>,
+ buffer_index: usize,
+ output_written: usize,
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl<F: FieldElement> Prng<F, SeedStreamAes128> {
+ /// Create a [`Prng`] from a seed for Prio 2. The first 16 bytes of the seed and the last 16
+ /// bytes of the seed are used, respectively, for the key and initialization vector for AES128
+ /// in CTR mode.
+ pub(crate) fn from_prio2_seed(seed: &[u8; 32]) -> Self {
+ let seed_stream = SeedStreamAes128::new(&seed[..16], &seed[16..]);
+ Self::from_seed_stream(seed_stream)
+ }
+
+ /// Create a [`Prng`] from a randomly generated seed.
+ pub(crate) fn new() -> Result<Self, PrngError> {
+ let mut seed = [0; 32];
+ getrandom(&mut seed)?;
+ Ok(Self::from_prio2_seed(&seed))
+ }
+}
+
+impl<F, S> Prng<F, S>
+where
+ F: FieldElement,
+ S: SeedStream,
+{
+ pub(crate) fn from_seed_stream(mut seed_stream: S) -> Self {
+ let mut buffer = vec![0; BUFFER_SIZE_IN_ELEMENTS * F::ENCODED_SIZE];
+ seed_stream.fill(&mut buffer);
+
+ Self {
+ phantom: PhantomData::<F>,
+ seed_stream,
+ buffer,
+ buffer_index: 0,
+ output_written: 0,
+ }
+ }
+
+ pub(crate) fn get(&mut self) -> F {
+ loop {
+ // Seek to the next chunk of the buffer that encodes an element of F.
+ for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) {
+ let j = i + F::ENCODED_SIZE;
+ if let Some(x) = match F::try_from_random(&self.buffer[i..j]) {
+ Ok(x) => Some(x),
+ Err(FieldError::ModulusOverflow) => None, // reject this sample
+ Err(err) => panic!("unexpected error: {}", err),
+ } {
+ // Set the buffer index to the next chunk.
+ self.buffer_index = j;
+ self.output_written += 1;
+ return x;
+ }
+ }
+
+ // Refresh buffer with the next chunk of PRG output.
+ self.seed_stream.fill(&mut self.buffer);
+ self.buffer_index = 0;
+ }
+ }
+}
+
+impl<F, S> Iterator for Prng<F, S>
+where
+ F: FieldElement,
+ S: SeedStream,
+{
+ type Item = F;
+
+ fn next(&mut self) -> Option<F> {
+ Some(self.get())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ codec::Decode,
+ field::{Field96, FieldPrio2},
+ vdaf::prg::{Prg, PrgAes128, Seed},
+ };
+ use std::convert::TryInto;
+
+ #[test]
+ fn secret_sharing_interop() {
+ let seed = [
+ 0xcd, 0x85, 0x5b, 0xd4, 0x86, 0x48, 0xa4, 0xce, 0x52, 0x5c, 0x36, 0xee, 0x5a, 0x71,
+ 0xf3, 0x0f, 0x66, 0x80, 0xd3, 0x67, 0x53, 0x9a, 0x39, 0x6f, 0x12, 0x2f, 0xad, 0x94,
+ 0x4d, 0x34, 0xcb, 0x58,
+ ];
+
+ let reference = [
+ 0xd0056ec5, 0xe23f9c52, 0x47e4ddb4, 0xbe5dacf6, 0x4b130aba, 0x530c7a90, 0xe8fc4ee5,
+ 0xb0569cb7, 0x7774cd3c, 0x7f24e6a5, 0xcc82355d, 0xc41f4f13, 0x67fe193c, 0xc94d63a4,
+ 0x5d7b474c, 0xcc5c9f5f, 0xe368e1d5, 0x020fa0cf, 0x9e96aa2a, 0xe924137d, 0xfa026ab9,
+ 0x8ebca0cc, 0x26fc58a5, 0x10a7b173, 0xb9c97291, 0x53ef0e28, 0x069cfb8e, 0xe9383cae,
+ 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58,
+ ];
+
+ let share2 = extract_share_from_seed::<FieldPrio2>(reference.len(), &seed);
+
+ assert_eq!(share2, reference);
+ }
+
+ /// takes a seed and hash as base64 encoded strings
+ #[cfg(feature = "prio2")]
+ fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) {
+ let seed = base64::decode(seed_base64).unwrap();
+ let random_data = extract_share_from_seed::<FieldPrio2>(len, &seed);
+
+ let random_bytes = FieldPrio2::slice_into_byte_vec(&random_data);
+
+ let digest = ring::digest::digest(&ring::digest::SHA256, &random_bytes);
+ assert_eq!(base64::encode(digest), hash_base64);
+ }
+
+ #[test]
+ #[cfg(feature = "prio2")]
+ fn test_hash_interop() {
+ random_data_interop(
+ "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
+ "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
+ 100_000,
+ );
+
+ // zero seed
+ random_data_interop(
+ "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
+ "3wHQbSwAn9GPfoNkKe1qSzWdKnu/R+hPPyRwwz6Di+w=",
+ 100_000,
+ );
+ // 0, 1, 2 ... seed
+ random_data_interop(
+ "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
+ "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
+ 100_000,
+ );
+ // one arbirtary fixed seed
+ random_data_interop(
+ "rkLrnVcU8ULaiuXTvR3OKrfpMX0kQidqVzta1pleKKg=",
+ "b1fMXYrGUNR3wOZ/7vmUMmY51QHoPDBzwok0fz6xC0I=",
+ 100_000,
+ );
+ // all bits set seed
+ random_data_interop(
+ "//////////////////////////////////////////8=",
+ "iBiDaqLrv7/rX/+vs6akPiprGgYfULdh/XhoD61HQXA=",
+ 100_000,
+ );
+ }
+
+ fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> {
+ assert_eq!(seed.len(), 32);
+ Prng::from_prio2_seed(seed.try_into().unwrap())
+ .take(length)
+ .collect()
+ }
+
+ #[test]
+ fn rejection_sampling_test_vector() {
+ // These constants were found in a brute-force search, and they test that the PRG performs
+ // rejection sampling correctly when raw AES-CTR output exceeds the prime modulus.
+ let seed_stream = PrgAes128::seed_stream(
+ &Seed::get_decoded(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 95]).unwrap(),
+ b"",
+ );
+ let mut prng = Prng::<Field96, _>::from_seed_stream(seed_stream);
+ let expected = Field96::from(39729620190871453347343769187);
+ let actual = prng.nth(145).unwrap();
+ assert_eq!(actual, expected);
+ }
+}
diff --git a/third_party/rust/prio/src/server.rs b/third_party/rust/prio/src/server.rs
new file mode 100644
index 0000000000..01f309e797
--- /dev/null
+++ b/third_party/rust/prio/src/server.rs
@@ -0,0 +1,469 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! The Prio v2 server. Only 0 / 1 vectors are supported for now.
+use crate::{
+ encrypt::{decrypt_share, EncryptError, PrivateKey},
+ field::{merge_vector, FieldElement, FieldError},
+ polynomial::{poly_interpret_eval, PolyAuxMemory},
+ prng::{Prng, PrngError},
+ util::{proof_length, unpack_proof, SerializeError},
+ vdaf::prg::SeedStreamAes128,
+};
+use serde::{Deserialize, Serialize};
+use std::convert::TryInto;
+
+/// Possible errors from server operations
+#[derive(Debug, thiserror::Error)]
+pub enum ServerError {
+ /// Unexpected Share Length
+ #[error("unexpected share length")]
+ ShareLength,
+ /// Encryption/decryption error
+ #[error("encryption/decryption error")]
+ Encrypt(#[from] EncryptError),
+ /// Finite field operation error
+ #[error("finite field operation error")]
+ Field(#[from] FieldError),
+ /// Serialization/deserialization error
+ #[error("serialization/deserialization error")]
+ Serialize(#[from] SerializeError),
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+ /// PRNG error.
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+}
+
+/// Auxiliary memory for constructing a
+/// [`VerificationMessage`](struct.VerificationMessage.html)
+#[derive(Debug)]
+pub struct ValidationMemory<F> {
+ points_f: Vec<F>,
+ points_g: Vec<F>,
+ points_h: Vec<F>,
+ poly_mem: PolyAuxMemory<F>,
+}
+
+impl<F: FieldElement> ValidationMemory<F> {
+ /// Construct a new ValidationMemory object for validating proof shares of
+ /// length `dimension`.
+ pub fn new(dimension: usize) -> Self {
+ let n: usize = (dimension + 1).next_power_of_two();
+ ValidationMemory {
+ points_f: vec![F::zero(); n],
+ points_g: vec![F::zero(); n],
+ points_h: vec![F::zero(); 2 * n],
+ poly_mem: PolyAuxMemory::new(n),
+ }
+ }
+}
+
+/// Main workhorse of the server.
+#[derive(Debug)]
+pub struct Server<F> {
+ prng: Prng<F, SeedStreamAes128>,
+ dimension: usize,
+ is_first_server: bool,
+ accumulator: Vec<F>,
+ validation_mem: ValidationMemory<F>,
+ private_key: PrivateKey,
+}
+
+impl<F: FieldElement> Server<F> {
+ /// Construct a new server instance
+ ///
+ /// Params:
+ /// * `dimension`: the number of elements in the aggregation vector.
+ /// * `is_first_server`: only one of the servers should have this true.
+ /// * `private_key`: the private key for decrypting the share of the proof.
+ pub fn new(
+ dimension: usize,
+ is_first_server: bool,
+ private_key: PrivateKey,
+ ) -> Result<Server<F>, ServerError> {
+ Ok(Server {
+ prng: Prng::new()?,
+ dimension,
+ is_first_server,
+ accumulator: vec![F::zero(); dimension],
+ validation_mem: ValidationMemory::new(dimension),
+ private_key,
+ })
+ }
+
+ /// Decrypt and deserialize
+ fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> {
+ let len = proof_length(self.dimension);
+ let share = decrypt_share(encrypted_share, &self.private_key)?;
+ Ok(if self.is_first_server {
+ F::byte_slice_into_vec(&share)?
+ } else {
+ if share.len() != 32 {
+ return Err(ServerError::ShareLength);
+ }
+
+ Prng::from_prio2_seed(&share.try_into().unwrap())
+ .take(len)
+ .collect()
+ })
+ }
+
+ /// Generate verification message from an encrypted share
+ ///
+ /// This decrypts the share of the proof and constructs the
+ /// [`VerificationMessage`](struct.VerificationMessage.html).
+ /// The `eval_at` field should be generate by
+ /// [choose_eval_at](#method.choose_eval_at).
+ pub fn generate_verification_message(
+ &mut self,
+ eval_at: F,
+ share: &[u8],
+ ) -> Result<VerificationMessage<F>, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ generate_verification_message(
+ self.dimension,
+ eval_at,
+ &share_field,
+ self.is_first_server,
+ &mut self.validation_mem,
+ )
+ }
+
+ /// Add the content of the encrypted share into the accumulator
+ ///
+ /// This only changes the accumulator if the verification messages `v1` and
+ /// `v2` indicate that the share passed validation.
+ pub fn aggregate(
+ &mut self,
+ share: &[u8],
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+ ) -> Result<bool, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ let is_valid = is_valid_share(v1, v2);
+ if is_valid {
+ // Add to the accumulator. share_field also includes the proof
+ // encoding, so we slice off the first dimension fields, which are
+ // the actual data share.
+ merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
+ }
+
+ Ok(is_valid)
+ }
+
+ /// Return the current accumulated shares.
+ ///
+ /// These can be merged together using
+ /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html).
+ pub fn total_shares(&self) -> &[F] {
+ &self.accumulator
+ }
+
+ /// Merge shares from another server.
+ ///
+ /// This modifies the current accumulator.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if `other_total_shares.len()` is not equal to this
+ //// server's `dimension`.
+ pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> {
+ Ok(merge_vector(&mut self.accumulator, other_total_shares)?)
+ }
+
+ /// Choose a random point for polynomial evaluation
+ ///
+ /// The point returned is not one of the roots used for polynomial
+ /// evaluation.
+ pub fn choose_eval_at(&mut self) -> F {
+ loop {
+ let eval_at = self.prng.get();
+ if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) {
+ break eval_at;
+ }
+ }
+ }
+}
+
+/// Verification message for proof validation
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct VerificationMessage<F> {
+ /// f evaluated at random point
+ pub f_r: F,
+ /// g evaluated at random point
+ pub g_r: F,
+ /// h evaluated at random point
+ pub h_r: F,
+}
+
+/// Given a proof and evaluation point, this constructs the verification
+/// message.
+pub fn generate_verification_message<F: FieldElement>(
+ dimension: usize,
+ eval_at: F,
+ proof: &[F],
+ is_first_server: bool,
+ mem: &mut ValidationMemory<F>,
+) -> Result<VerificationMessage<F>, ServerError> {
+ let unpacked = unpack_proof(proof, dimension)?;
+ let proof_length = 2 * (dimension + 1).next_power_of_two();
+
+ // set zero terms
+ mem.points_f[0] = *unpacked.f0;
+ mem.points_g[0] = *unpacked.g0;
+ mem.points_h[0] = *unpacked.h0;
+
+ // set points_f and points_g
+ for (i, x) in unpacked.data.iter().enumerate() {
+ mem.points_f[i + 1] = *x;
+
+ if is_first_server {
+ // only one server needs to subtract one for point_g
+ mem.points_g[i + 1] = *x - F::one();
+ } else {
+ mem.points_g[i + 1] = *x;
+ }
+ }
+
+ // set points_h, skipping over elements that should be zero
+ let mut i = 1;
+ let mut j = 0;
+ while i < proof_length {
+ mem.points_h[i] = unpacked.points_h_packed[j];
+ j += 1;
+ i += 2;
+ }
+
+ // evaluate polynomials at random point
+ let f_r = poly_interpret_eval(
+ &mem.points_f,
+ &mem.poly_mem.roots_n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+ let g_r = poly_interpret_eval(
+ &mem.points_g,
+ &mem.poly_mem.roots_n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+ let h_r = poly_interpret_eval(
+ &mem.points_h,
+ &mem.poly_mem.roots_2n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+
+ Ok(VerificationMessage { f_r, g_r, h_r })
+}
+
+/// Decides if the distributed proof is valid
+pub fn is_valid_share<F: FieldElement>(
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+) -> bool {
+ // reconstruct f_r, g_r, h_r
+ let f_r = v1.f_r + v2.f_r;
+ let g_r = v1.g_r + v2.g_r;
+ let h_r = v1.h_r + v2.h_r;
+ // validity check
+ f_r * g_r == h_r
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ encrypt::{encrypt_share, PublicKey},
+ field::{Field32, FieldPrio2},
+ test_vector::Priov2TestVector,
+ util::{self, unpack_proof_mut},
+ };
+ use serde_json;
+
+ #[test]
+ fn test_validation() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
+ let share2 = util::tests::secret_share(&mut proof);
+ let eval_at = Field32::from(12313);
+
+ let mut validation_mem = ValidationMemory::new(dim);
+
+ let v1 =
+ generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
+ .unwrap();
+ assert!(is_valid_share(&v1, &v2));
+ }
+
+ #[test]
+ fn test_verification_message_serde() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
+ let share2 = util::tests::secret_share(&mut proof);
+ let eval_at = Field32::from(12313);
+
+ let mut validation_mem = ValidationMemory::new(dim);
+
+ let v1 =
+ generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
+ .unwrap();
+
+ // serialize and deserialize the first verification message
+ let serialized = serde_json::to_string(&v1).unwrap();
+ let deserialized: VerificationMessage<Field32> = serde_json::from_str(&serialized).unwrap();
+
+ assert!(is_valid_share(&deserialized, &v2));
+ }
+
+ #[derive(Debug, Clone, Copy, PartialEq)]
+ enum Tweak {
+ None,
+ WrongInput,
+ DataPartOfShare,
+ ZeroTermF,
+ ZeroTermG,
+ ZeroTermH,
+ PointsH,
+ VerificationF,
+ VerificationG,
+ VerificationH,
+ }
+
+ fn tweaks(tweak: Tweak) {
+ let dim = 123;
+
+ // We generate a test vector just to get a `Client` and `Server`s with
+ // encryption keys but construct and tweak inputs below.
+ let test_vector = Priov2TestVector::new(dim, 0).unwrap();
+ let mut server1 = test_vector.server_1().unwrap();
+ let mut server2 = test_vector.server_2().unwrap();
+ let mut client = test_vector.client().unwrap();
+
+ // all zero data
+ let mut data = vec![FieldPrio2::zero(); dim];
+
+ if let Tweak::WrongInput = tweak {
+ data[0] = FieldPrio2::from(2);
+ }
+
+ let (share1_original, share2) = client.encode_simple(&data).unwrap();
+
+ let decrypted_share1 = decrypt_share(&share1_original, &server1.private_key).unwrap();
+ let mut share1_field = FieldPrio2::byte_slice_into_vec(&decrypted_share1).unwrap();
+ let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap();
+
+ let one = FieldPrio2::from(1);
+
+ match tweak {
+ Tweak::DataPartOfShare => unpacked_share1.data[0] += one,
+ Tweak::ZeroTermF => *unpacked_share1.f0 += one,
+ Tweak::ZeroTermG => *unpacked_share1.g0 += one,
+ Tweak::ZeroTermH => *unpacked_share1.h0 += one,
+ Tweak::PointsH => unpacked_share1.points_h_packed[0] += one,
+ _ => (),
+ };
+
+ // reserialize altered share1
+ let share1_modified = encrypt_share(
+ &FieldPrio2::slice_into_byte_vec(&share1_field),
+ &PublicKey::from(&server1.private_key),
+ )
+ .unwrap();
+
+ let eval_at = server1.choose_eval_at();
+
+ let mut v1 = server1
+ .generate_verification_message(eval_at, &share1_modified)
+ .unwrap();
+ let v2 = server2
+ .generate_verification_message(eval_at, &share2)
+ .unwrap();
+
+ match tweak {
+ Tweak::VerificationF => v1.f_r += one,
+ Tweak::VerificationG => v1.g_r += one,
+ Tweak::VerificationH => v1.h_r += one,
+ _ => (),
+ }
+
+ let should_be_valid = matches!(tweak, Tweak::None);
+ assert_eq!(
+ server1.aggregate(&share1_modified, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ assert_eq!(
+ server2.aggregate(&share2, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ }
+
+ #[test]
+ fn tweak_none() {
+ tweaks(Tweak::None);
+ }
+
+ #[test]
+ fn tweak_input() {
+ tweaks(Tweak::WrongInput);
+ }
+
+ #[test]
+ fn tweak_data() {
+ tweaks(Tweak::DataPartOfShare);
+ }
+
+ #[test]
+ fn tweak_f_zero() {
+ tweaks(Tweak::ZeroTermF);
+ }
+
+ #[test]
+ fn tweak_g_zero() {
+ tweaks(Tweak::ZeroTermG);
+ }
+
+ #[test]
+ fn tweak_h_zero() {
+ tweaks(Tweak::ZeroTermH);
+ }
+
+ #[test]
+ fn tweak_h_points() {
+ tweaks(Tweak::PointsH);
+ }
+
+ #[test]
+ fn tweak_f_verif() {
+ tweaks(Tweak::VerificationF);
+ }
+
+ #[test]
+ fn tweak_g_verif() {
+ tweaks(Tweak::VerificationG);
+ }
+
+ #[test]
+ fn tweak_h_verif() {
+ tweaks(Tweak::VerificationH);
+ }
+}
diff --git a/third_party/rust/prio/src/test_vector.rs b/third_party/rust/prio/src/test_vector.rs
new file mode 100644
index 0000000000..1306bc26f9
--- /dev/null
+++ b/third_party/rust/prio/src/test_vector.rs
@@ -0,0 +1,244 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Module `test_vector` generates test vectors of serialized Prio inputs and
+//! support for working with test vectors, enabling backward compatibility
+//! testing.
+
+use crate::{
+ client::{Client, ClientError},
+ encrypt::{PrivateKey, PublicKey},
+ field::{FieldElement, FieldPrio2},
+ server::{Server, ServerError},
+};
+use rand::Rng;
+use serde::{Deserialize, Serialize};
+use std::fmt::Debug;
+
+/// Errors propagated by functions in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum TestVectorError {
+ /// Error from Prio client
+ #[error("Prio client error {0}")]
+ Client(#[from] ClientError),
+ /// Error from Prio server
+ #[error("Prio server error {0}")]
+ Server(#[from] ServerError),
+ /// Error while converting primitive to FieldElement associated integer type
+ #[error("Integer conversion error {0}")]
+ IntegerConversion(String),
+}
+
+const SERVER_1_PRIVATE_KEY: &str =
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBH\
+ fNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
+const SERVER_2_PRIVATE_KEY: &str =
+ "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rD\
+ ULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";
+
+/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
+/// test vectors and later to decrypt them.
+fn server_1_private_key() -> PrivateKey {
+ PrivateKey::from_base64(SERVER_1_PRIVATE_KEY).unwrap()
+}
+
+/// The public portion of [`server_1_private_key`].
+fn server_1_public_key() -> PublicKey {
+ PublicKey::from(&server_1_private_key())
+}
+
+/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
+/// test vectors and later to decrypt them.
+fn server_2_private_key() -> PrivateKey {
+ PrivateKey::from_base64(SERVER_2_PRIVATE_KEY).unwrap()
+}
+
+/// The public portion of [`server_2_private_key`].
+fn server_2_public_key() -> PublicKey {
+ PublicKey::from(&server_2_private_key())
+}
+
+/// A test vector of Prio inputs, serialized and encrypted in the Priov2 format,
+/// along with a reference sum. The field is always [`FieldPrio2`].
+#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub struct Priov2TestVector {
+ /// Base64 encoded private key for the "first" a.k.a. "PHA" server, which
+ /// may be used to decrypt `server_1_shares`.
+ pub server_1_private_key: String,
+ /// Base64 encoded private key for the non-"first" a.k.a. "facilitator"
+ /// server, which may be used to decrypt `server_2_shares`.
+ pub server_2_private_key: String,
+ /// Dimension (number of buckets) of the inputs
+ pub dimension: usize,
+ /// Encrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA"
+ /// server. The inner `Vec`s are encrypted bytes.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_1_shares: Vec<Vec<u8>>,
+ /// Encrypted share of Priov2 format inputs for the non-"first" a.k.a.
+ /// "facilitator" server.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_2_shares: Vec<Vec<u8>>,
+ /// The sum over the inputs.
+ #[serde(
+ serialize_with = "base64::serialize_field",
+ deserialize_with = "base64::deserialize_field"
+ )]
+ pub reference_sum: Vec<FieldPrio2>,
+ /// The version of the crate that generated this test vector
+ pub prio_crate_version: String,
+}
+
+impl Priov2TestVector {
+ /// Construct a test vector of `number_of_clients` inputs, each of which is a
+ /// `dimension`-dimension vector of random Boolean values encoded as
+ /// [`FieldPrio2`].
+ pub fn new(dimension: usize, number_of_clients: usize) -> Result<Self, TestVectorError> {
+ let mut client: Client<FieldPrio2> =
+ Client::new(dimension, server_1_public_key(), server_2_public_key())?;
+
+ let mut reference_sum = vec![FieldPrio2::zero(); dimension];
+ let mut server_1_shares = Vec::with_capacity(number_of_clients);
+ let mut server_2_shares = Vec::with_capacity(number_of_clients);
+
+ let mut rng = rand::thread_rng();
+
+ for _ in 0..number_of_clients {
+ // Generate a random vector of booleans
+ let data: Vec<FieldPrio2> = (0..dimension)
+ .map(|_| FieldPrio2::from(rng.gen_range(0..2)))
+ .collect();
+
+ // Update reference sum
+ for (r, d) in reference_sum.iter_mut().zip(&data) {
+ *r += *d;
+ }
+
+ let (server_1_share, server_2_share) = client.encode_simple(&data)?;
+
+ server_1_shares.push(server_1_share);
+ server_2_shares.push(server_2_share);
+ }
+
+ Ok(Self {
+ server_1_private_key: SERVER_1_PRIVATE_KEY.to_owned(),
+ server_2_private_key: SERVER_2_PRIVATE_KEY.to_owned(),
+ dimension,
+ server_1_shares,
+ server_2_shares,
+ reference_sum,
+ prio_crate_version: env!("CARGO_PKG_VERSION").to_owned(),
+ })
+ }
+
+ /// Construct a [`Client`] that can encrypt input shares to this test
+ /// vector's servers.
+ pub fn client(&self) -> Result<Client<FieldPrio2>, TestVectorError> {
+ Ok(Client::new(
+ self.dimension,
+ PublicKey::from(&PrivateKey::from_base64(&self.server_1_private_key).unwrap()),
+ PublicKey::from(&PrivateKey::from_base64(&self.server_2_private_key).unwrap()),
+ )?)
+ }
+
+ /// Construct a [`Server`] that can decrypt `server_1_shares`.
+ pub fn server_1(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
+ Ok(Server::new(
+ self.dimension,
+ true,
+ PrivateKey::from_base64(&self.server_1_private_key).unwrap(),
+ )?)
+ }
+
+ /// Construct a [`Server`] that can decrypt `server_2_shares`.
+ pub fn server_2(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
+ Ok(Server::new(
+ self.dimension,
+ false,
+ PrivateKey::from_base64(&self.server_2_private_key).unwrap(),
+ )?)
+ }
+}
+
+mod base64 {
+ //! Custom serialization module used for some members of struct
+ //! `Priov2TestVector` so that byte slices are serialized as base64 strings
+ //! instead of an array of an array of integers when serializing to JSON.
+ //
+ // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
+ use crate::field::{FieldElement, FieldPrio2};
+ use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
+
+ pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
+ let base64_vec = v.iter().map(base64::encode).collect();
+ <Vec<String>>::serialize(&base64_vec, s)
+ }
+
+ pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
+ <Vec<String>>::deserialize(d)?
+ .iter()
+ .map(|s| base64::decode(s.as_bytes()).map_err(Error::custom))
+ .collect()
+ }
+
+ pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
+ String::serialize(&base64::encode(FieldPrio2::slice_into_byte_vec(v)), s)
+ }
+
+ pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
+ let bytes = base64::decode(String::deserialize(d)?.as_bytes()).map_err(Error::custom)?;
+ FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::util::reconstruct_shares;
+
+ #[test]
+ fn roundtrip_test_vector_serialization() {
+ let test_vector = Priov2TestVector::new(123, 100).unwrap();
+ let serialized = serde_json::to_vec(&test_vector).unwrap();
+ let test_vector_again: Priov2TestVector = serde_json::from_slice(&serialized).unwrap();
+
+ assert_eq!(test_vector, test_vector_again);
+ }
+
+ #[test]
+ fn accumulation_field_priov2() {
+ let dimension = 123;
+ let test_vector = Priov2TestVector::new(dimension, 100).unwrap();
+
+ let mut server1 = test_vector.server_1().unwrap();
+ let mut server2 = test_vector.server_2().unwrap();
+
+ for (server_1_share, server_2_share) in test_vector
+ .server_1_shares
+ .iter()
+ .zip(&test_vector.server_2_shares)
+ {
+ let eval_at = server1.choose_eval_at();
+
+ let v1 = server1
+ .generate_verification_message(eval_at, server_1_share)
+ .unwrap();
+ let v2 = server2
+ .generate_verification_message(eval_at, server_2_share)
+ .unwrap();
+
+ assert!(server1.aggregate(server_1_share, &v1, &v2).unwrap());
+ assert!(server2.aggregate(server_2_share, &v1, &v2).unwrap());
+ }
+
+ let total1 = server1.total_shares();
+ let total2 = server2.total_shares();
+
+ let reconstructed = reconstruct_shares(total1, total2).unwrap();
+ assert_eq!(reconstructed, test_vector.reference_sum);
+ }
+}
diff --git a/third_party/rust/prio/src/util.rs b/third_party/rust/prio/src/util.rs
new file mode 100644
index 0000000000..0518112d83
--- /dev/null
+++ b/third_party/rust/prio/src/util.rs
@@ -0,0 +1,201 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Utility functions for handling Prio stuff.
+
+use crate::field::{FieldElement, FieldError};
+
+/// Serialization errors
+#[derive(Debug, thiserror::Error)]
+pub enum SerializeError {
+ /// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length
+ #[error("serialized input has wrong length")]
+ UnpackInputSizeMismatch,
+ /// Finite field operation error.
+ #[error("finite field operation error")]
+ Field(#[from] FieldError),
+}
+
+/// Returns the number of field elements in the proof for given dimension of
+/// data elements
+///
+/// Proof is a vector, where the first `dimension` elements are the data
+/// elements, the next 3 elements are the zero terms for polynomials f, g and h
+/// and the remaining elements are non-zero points of h(x).
+pub fn proof_length(dimension: usize) -> usize {
+ // number of data items + number of zero terms + N
+ dimension + 3 + (dimension + 1).next_power_of_two()
+}
+
+/// Unpacked proof with subcomponents
+#[derive(Debug)]
+pub struct UnpackedProof<'a, F: FieldElement> {
+ /// Data
+ pub data: &'a [F],
+ /// Zeroth coefficient of polynomial f
+ pub f0: &'a F,
+ /// Zeroth coefficient of polynomial g
+ pub g0: &'a F,
+ /// Zeroth coefficient of polynomial h
+ pub h0: &'a F,
+ /// Non-zero points of polynomial h
+ pub points_h_packed: &'a [F],
+}
+
+/// Unpacked proof with mutable subcomponents
+#[derive(Debug)]
+pub struct UnpackedProofMut<'a, F: FieldElement> {
+ /// Data
+ pub data: &'a mut [F],
+ /// Zeroth coefficient of polynomial f
+ pub f0: &'a mut F,
+ /// Zeroth coefficient of polynomial g
+ pub g0: &'a mut F,
+ /// Zeroth coefficient of polynomial h
+ pub h0: &'a mut F,
+ /// Non-zero points of polynomial h
+ pub points_h_packed: &'a mut [F],
+}
+
+/// Unpacks the proof vector into subcomponents
+pub(crate) fn unpack_proof<F: FieldElement>(
+ proof: &[F],
+ dimension: usize,
+) -> Result<UnpackedProof<F>, SerializeError> {
+ // check the proof length
+ if proof.len() != proof_length(dimension) {
+ return Err(SerializeError::UnpackInputSizeMismatch);
+ }
+ // split share into components
+ let (data, rest) = proof.split_at(dimension);
+ if let ([f0, g0, h0], points_h_packed) = rest.split_at(3) {
+ Ok(UnpackedProof {
+ data,
+ f0,
+ g0,
+ h0,
+ points_h_packed,
+ })
+ } else {
+ Err(SerializeError::UnpackInputSizeMismatch)
+ }
+}
+
+/// Unpacks a mutable proof vector into mutable subcomponents
+// TODO(timg): This is public because it is used by tests/tweaks.rs. We should
+// refactor that test so it doesn't require the crate to expose this function or
+// UnpackedProofMut.
+pub fn unpack_proof_mut<F: FieldElement>(
+ proof: &mut [F],
+ dimension: usize,
+) -> Result<UnpackedProofMut<F>, SerializeError> {
+ // check the share length
+ if proof.len() != proof_length(dimension) {
+ return Err(SerializeError::UnpackInputSizeMismatch);
+ }
+ // split share into components
+ let (data, rest) = proof.split_at_mut(dimension);
+ if let ([f0, g0, h0], points_h_packed) = rest.split_at_mut(3) {
+ Ok(UnpackedProofMut {
+ data,
+ f0,
+ g0,
+ h0,
+ points_h_packed,
+ })
+ } else {
+ Err(SerializeError::UnpackInputSizeMismatch)
+ }
+}
+
+/// Add two field element arrays together elementwise.
+///
+/// Returns None, when array lengths are not equal.
+pub fn reconstruct_shares<F: FieldElement>(share1: &[F], share2: &[F]) -> Option<Vec<F>> {
+ if share1.len() != share2.len() {
+ return None;
+ }
+
+ let mut reconstructed: Vec<F> = vec![F::zero(); share1.len()];
+
+ for (r, (s1, s2)) in reconstructed
+ .iter_mut()
+ .zip(share1.iter().zip(share2.iter()))
+ {
+ *r = *s1 + *s2;
+ }
+
+ Some(reconstructed)
+}
+
+#[cfg(test)]
+pub mod tests {
+ use super::*;
+ use crate::field::{Field32, Field64};
+ use assert_matches::assert_matches;
+
+ pub fn secret_share(share: &mut [Field32]) -> Vec<Field32> {
+ use rand::Rng;
+ let mut rng = rand::thread_rng();
+ let mut random = vec![0u32; share.len()];
+ let mut share2 = vec![Field32::zero(); share.len()];
+
+ rng.fill(&mut random[..]);
+
+ for (r, f) in random.iter().zip(share2.iter_mut()) {
+ *f = Field32::from(*r);
+ }
+
+ for (f1, f2) in share.iter_mut().zip(share2.iter()) {
+ *f1 -= *f2;
+ }
+
+ share2
+ }
+
+ #[test]
+ fn test_unpack_share_mut() {
+ let dim = 15;
+ let len = proof_length(dim);
+
+ let mut share = vec![Field32::from(0); len];
+ let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
+ *unpacked.f0 = Field32::from(12);
+ assert_eq!(share[dim], 12);
+
+ let mut short_share = vec![Field32::from(0); len - 1];
+ assert_matches!(
+ unpack_proof_mut(&mut short_share, dim),
+ Err(SerializeError::UnpackInputSizeMismatch)
+ );
+ }
+
+ #[test]
+ fn test_unpack_share() {
+ let dim = 15;
+ let len = proof_length(dim);
+
+ let share = vec![Field64::from(0); len];
+ unpack_proof(&share, dim).unwrap();
+
+ let short_share = vec![Field64::from(0); len - 1];
+ assert_matches!(
+ unpack_proof(&short_share, dim),
+ Err(SerializeError::UnpackInputSizeMismatch)
+ );
+ }
+
+ #[test]
+ fn secret_sharing() {
+ let mut share1 = vec![Field32::zero(); 10];
+ share1[3] = 21.into();
+ share1[8] = 123.into();
+
+ let original_data = share1.clone();
+
+ let share2 = secret_share(&mut share1);
+
+ let reconstructed = reconstruct_shares(&share1, &share2).unwrap();
+ assert_eq!(reconstructed, original_data);
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf.rs b/third_party/rust/prio/src/vdaf.rs
new file mode 100644
index 0000000000..f75a2c488b
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf.rs
@@ -0,0 +1,562 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Verifiable Distributed Aggregation Functions (VDAFs) as described in
+//! [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
+use crate::field::{FieldElement, FieldError};
+use crate::flp::FlpError;
+use crate::prng::PrngError;
+use crate::vdaf::prg::Seed;
+use serde::{Deserialize, Serialize};
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::io::Cursor;
+
+/// A component of the domain-separation tag, used to bind the VDAF operations to the document
+/// version. This will be revised with each draft with breaking changes.
+const VERSION: &[u8] = b"vdaf-03";
+/// Length of the domain-separation tag, including document version and algorithm ID.
+const DST_LEN: usize = VERSION.len() + 4;
+
+/// Errors emitted by this module.
+#[derive(Debug, thiserror::Error)]
+pub enum VdafError {
+ /// An error occurred.
+ #[error("vdaf error: {0}")]
+ Uncategorized(String),
+
+ /// Field error.
+ #[error("field error: {0}")]
+ Field(#[from] FieldError),
+
+ /// An error occured while parsing a message.
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+
+ /// FLP error.
+ #[error("flp error: {0}")]
+ Flp(#[from] FlpError),
+
+ /// PRNG error.
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+
+ /// failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+}
+
+/// An additive share of a vector of field elements.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum Share<F, const L: usize> {
+ /// An uncompressed share, typically sent to the leader.
+ Leader(Vec<F>),
+
+ /// A compressed share, typically sent to the helper.
+ Helper(Seed<L>),
+}
+
+impl<F: Clone, const L: usize> Share<F, L> {
+ /// Truncate the Leader's share to the given length. If this is the Helper's share, then this
+ /// method clones the input without modifying it.
+ #[cfg(feature = "prio2")]
+ pub(crate) fn truncated(&self, len: usize) -> Self {
+ match self {
+ Self::Leader(ref data) => Self::Leader(data[..len].to_vec()),
+ Self::Helper(ref seed) => Self::Helper(seed.clone()),
+ }
+ }
+}
+
+/// Parameters needed to decode a [`Share`]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub(crate) enum ShareDecodingParameter<const L: usize> {
+ Leader(usize),
+ Helper,
+}
+
+impl<F: FieldElement, const L: usize> ParameterizedDecode<ShareDecodingParameter<L>>
+ for Share<F, L>
+{
+ fn decode_with_param(
+ decoding_parameter: &ShareDecodingParameter<L>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match decoding_parameter {
+ ShareDecodingParameter::Leader(share_length) => {
+ let mut data = Vec::with_capacity(*share_length);
+ for _ in 0..*share_length {
+ data.push(F::decode(bytes)?)
+ }
+ Ok(Self::Leader(data))
+ }
+ ShareDecodingParameter::Helper => {
+ let seed = Seed::decode(bytes)?;
+ Ok(Self::Helper(seed))
+ }
+ }
+ }
+}
+
+impl<F: FieldElement, const L: usize> Encode for Share<F, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self {
+ Share::Leader(share_data) => {
+ for x in share_data {
+ x.encode(bytes);
+ }
+ }
+ Share::Helper(share_seed) => {
+ share_seed.encode(bytes);
+ }
+ }
+ }
+}
+
+/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`],
+/// and [`Collector`], which define the roles of the various parties involved in the execution of
+/// the VDAF.
+// TODO(brandon): once GATs are stabilized [https://github.com/rust-lang/rust/issues/44265],
+// state the "&AggregateShare must implement Into<Vec<u8>>" constraint in terms of a where clause
+// on the associated type instead of a where clause on the trait.
+pub trait Vdaf: Clone + Debug
+where
+ for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
+{
+ /// Algorithm identifier for this VDAF.
+ const ID: u32;
+
+ /// The type of Client measurement to be aggregated.
+ type Measurement: Clone + Debug;
+
+ /// The aggregate result of the VDAF execution.
+ type AggregateResult: Clone + Debug;
+
+ /// The aggregation parameter, used by the Aggregators to map their input shares to output
+ /// shares.
+ type AggregationParam: Clone + Debug + Decode + Encode;
+
+ /// A public share sent by a Client.
+ type PublicShare: Clone + Debug + for<'a> ParameterizedDecode<&'a Self> + Encode;
+
+ /// An input share sent by a Client.
+ type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode;
+
+ /// An output share recovered from an input share by an Aggregator.
+ type OutputShare: Clone + Debug;
+
+ /// An Aggregator's share of the aggregate result.
+ type AggregateShare: Aggregatable<OutputShare = Self::OutputShare> + for<'a> TryFrom<&'a [u8]>;
+
+ /// The number of Aggregators. The Client generates as many input shares as there are
+ /// Aggregators.
+ fn num_aggregators(&self) -> usize;
+}
+
+/// The Client's role in the execution of a VDAF.
+pub trait Client: Vdaf
+where
+ for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
+{
+ /// Shards a measurement into a public share and a sequence of input shares, one for each
+ /// Aggregator.
+ fn shard(
+ &self,
+ measurement: &Self::Measurement,
+ ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>;
+}
+
+/// The Aggregator's role in the execution of a VDAF.
+pub trait Aggregator<const L: usize>: Vdaf
+where
+ for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
+{
+ /// State of the Aggregator during the Prepare process.
+ type PrepareState: Clone + Debug;
+
+ /// The type of messages broadcast by each aggregator at each round of the Prepare Process.
+ ///
+ /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
+ /// associated with any aggregator involved in the execution of the VDAF.
+ type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;
+
+ /// Result of preprocessing a round of preparation shares.
+ ///
+ /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
+ /// associated with any aggregator involved in the execution of the VDAF.
+ type PrepareMessage: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;
+
+ /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned
+ /// is passed to [`Aggregator::prepare_step`] to get this aggregator's first-round prepare
+ /// message.
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; L],
+ agg_id: usize,
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;
+
+ /// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`].
+ fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>(
+ &self,
+ inputs: M,
+ ) -> Result<Self::PrepareMessage, VdafError>;
+
+ /// Compute the next state transition from the current state and the previous round of input
+ /// messages. If this returns [`PrepareTransition::Continue`], then the returned
+ /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from
+ /// this round and passed into another call to this method. This continues until this method
+ /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
+ /// aggregated. If the method returns an error, the aggregator should consider its input share
+ /// invalid and not attempt to process it any further.
+ fn prepare_step(
+ &self,
+ state: Self::PrepareState,
+ input: Self::PrepareMessage,
+ ) -> Result<PrepareTransition<Self, L>, VdafError>;
+
+ /// Aggregates a sequence of output shares into an aggregate share.
+ fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ output_shares: M,
+ ) -> Result<Self::AggregateShare, VdafError>;
+}
+
+/// The Collector's role in the execution of a VDAF.
+pub trait Collector: Vdaf
+where
+ for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
+{
+ /// Combines aggregate shares into the aggregate result.
+ fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ agg_shares: M,
+ num_measurements: usize,
+ ) -> Result<Self::AggregateResult, VdafError>;
+}
+
+/// A state transition of an Aggregator during the Prepare process.
+#[derive(Debug)]
+pub enum PrepareTransition<V: Aggregator<L>, const L: usize>
+where
+ for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
+{
+ /// Continue processing.
+ Continue(V::PrepareState, V::PrepareShare),
+
+ /// Finish processing and return the output share.
+ Finish(V::OutputShare),
+}
+
+/// An aggregate share resulting from aggregating output shares together that
+/// can merged with aggregate shares of the same type.
+pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
+ /// Type of output shares that can be accumulated into an aggregate share.
+ type OutputShare;
+
+ /// Update an aggregate share by merging it with another (`agg_share`).
+ fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>;
+
+ /// Update an aggregate share by adding `output share`
+ fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>;
+}
+
+/// An output share comprised of a vector of `F` elements.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct OutputShare<F>(Vec<F>);
+
+impl<F> AsRef<[F]> for OutputShare<F> {
+ fn as_ref(&self) -> &[F] {
+ &self.0
+ }
+}
+
+impl<F> From<Vec<F>> for OutputShare<F> {
+ fn from(other: Vec<F>) -> Self {
+ Self(other)
+ }
+}
+
+impl<F: FieldElement> TryFrom<&[u8]> for OutputShare<F> {
+ type Error = FieldError;
+
+ fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+ fieldvec_try_from_bytes(bytes)
+ }
+}
+
+impl<F: FieldElement> From<&OutputShare<F>> for Vec<u8> {
+ fn from(output_share: &OutputShare<F>) -> Self {
+ fieldvec_to_vec(&output_share.0)
+ }
+}
+
+/// An aggregate share suitable for VDAFs whose output shares and aggregate
+/// shares are vectors of `F` elements, and an output share needs no special
+/// transformation to be merged into an aggregate share.
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub struct AggregateShare<F>(Vec<F>);
+
+impl<F> AsRef<[F]> for AggregateShare<F> {
+ fn as_ref(&self) -> &[F] {
+ &self.0
+ }
+}
+
+impl<F> From<OutputShare<F>> for AggregateShare<F> {
+ fn from(other: OutputShare<F>) -> Self {
+ Self(other.0)
+ }
+}
+
+impl<F> From<Vec<F>> for AggregateShare<F> {
+ fn from(other: Vec<F>) -> Self {
+ Self(other)
+ }
+}
+
+impl<F: FieldElement> Aggregatable for AggregateShare<F> {
+ type OutputShare = OutputShare<F>;
+
+ fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
+ self.sum(agg_share.as_ref())
+ }
+
+ fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> {
+ // For prio3 and poplar1, no conversion is needed between output shares and aggregation
+ // shares.
+ self.sum(output_share.as_ref())
+ }
+}
+
+impl<F: FieldElement> AggregateShare<F> {
+ fn sum(&mut self, other: &[F]) -> Result<(), VdafError> {
+ if self.0.len() != other.len() {
+ return Err(VdafError::Uncategorized(format!(
+ "cannot sum shares of different lengths (left = {}, right = {}",
+ self.0.len(),
+ other.len()
+ )));
+ }
+
+ for (x, y) in self.0.iter_mut().zip(other) {
+ *x += *y;
+ }
+
+ Ok(())
+ }
+}
+
+impl<F: FieldElement> TryFrom<&[u8]> for AggregateShare<F> {
+ type Error = FieldError;
+
+ fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+ fieldvec_try_from_bytes(bytes)
+ }
+}
+
+impl<F: FieldElement> From<&AggregateShare<F>> for Vec<u8> {
+ fn from(aggregate_share: &AggregateShare<F>) -> Self {
+ fieldvec_to_vec(&aggregate_share.0)
+ }
+}
+
+/// fieldvec_try_from_bytes converts a slice of bytes to a type that is equivalent to a vector of
+/// field elements.
+#[inline(always)]
+fn fieldvec_try_from_bytes<F: FieldElement, T: From<Vec<F>>>(
+ bytes: &[u8],
+) -> Result<T, FieldError> {
+ F::byte_slice_into_vec(bytes).map(T::from)
+}
+
+/// fieldvec_to_vec converts a type that is equivalent to a vector of field elements into a vector
+/// of bytes.
+#[inline(always)]
+fn fieldvec_to_vec<F: FieldElement, T: AsRef<[F]>>(val: T) -> Vec<u8> {
+ F::slice_into_byte_vec(val.as_ref())
+}
+
+#[cfg(test)]
+pub(crate) fn run_vdaf<V, M, const L: usize>(
+ vdaf: &V,
+ agg_param: &V::AggregationParam,
+ measurements: M,
+) -> Result<V::AggregateResult, VdafError>
+where
+ V: Client + Aggregator<L> + Collector,
+ for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
+ M: IntoIterator<Item = V::Measurement>,
+{
+ use rand::prelude::*;
+ let mut verify_key = [0; L];
+ thread_rng().fill(&mut verify_key[..]);
+
+ // NOTE Here we use the same nonce for each measurement for testing purposes. However, this is
+ // not secure. In use, the Aggregators MUST ensure that nonces are unique for each measurement.
+ let nonce = b"this is a nonce";
+
+ let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
+ let mut num_measurements: usize = 0;
+ for measurement in measurements.into_iter() {
+ num_measurements += 1;
+ let (public_share, input_shares) = vdaf.shard(&measurement)?;
+ let out_shares = run_vdaf_prepare(
+ vdaf,
+ &verify_key,
+ agg_param,
+ nonce,
+ public_share,
+ input_shares,
+ )?;
+ for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
+ if let Some(ref mut inner) = agg_share {
+ inner.merge(&out_share.into())?;
+ } else {
+ *agg_share = Some(out_share.into());
+ }
+ }
+ }
+
+ let res = vdaf.unshard(
+ agg_param,
+ agg_shares.into_iter().map(|option| option.unwrap()),
+ num_measurements,
+ )?;
+ Ok(res)
+}
+
+#[cfg(test)]
+pub(crate) fn run_vdaf_prepare<V, M, const L: usize>(
+ vdaf: &V,
+ verify_key: &[u8; L],
+ agg_param: &V::AggregationParam,
+ nonce: &[u8],
+ public_share: V::PublicShare,
+ input_shares: M,
+) -> Result<Vec<V::OutputShare>, VdafError>
+where
+ V: Client + Aggregator<L> + Collector,
+ for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
+ M: IntoIterator<Item = V::InputShare>,
+{
+ let input_shares = input_shares
+ .into_iter()
+ .map(|input_share| input_share.get_encoded());
+
+ let mut states = Vec::new();
+ let mut outbound = Vec::new();
+ for (agg_id, input_share) in input_shares.enumerate() {
+ let (state, msg) = vdaf.prepare_init(
+ verify_key,
+ agg_id,
+ agg_param,
+ nonce,
+ &public_share,
+ &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
+ .expect("failed to decode input share"),
+ )?;
+ states.push(state);
+ outbound.push(msg.get_encoded());
+ }
+
+ let mut inbound = vdaf
+ .prepare_preprocess(outbound.iter().map(|encoded| {
+ V::PrepareShare::get_decoded_with_param(&states[0], encoded)
+ .expect("failed to decode prep share")
+ }))?
+ .get_encoded();
+
+ let mut out_shares = Vec::new();
+ loop {
+ let mut outbound = Vec::new();
+ for state in states.iter_mut() {
+ match vdaf.prepare_step(
+ state.clone(),
+ V::PrepareMessage::get_decoded_with_param(state, &inbound)
+ .expect("failed to decode prep message"),
+ )? {
+ PrepareTransition::Continue(new_state, msg) => {
+ outbound.push(msg.get_encoded());
+ *state = new_state
+ }
+ PrepareTransition::Finish(out_share) => {
+ out_shares.push(out_share);
+ }
+ }
+ }
+
+ if outbound.len() == vdaf.num_aggregators() {
+ // Another round is required before output shares are computed.
+ inbound = vdaf
+ .prepare_preprocess(outbound.iter().map(|encoded| {
+ V::PrepareShare::get_decoded_with_param(&states[0], encoded)
+ .expect("failed to decode prep share")
+ }))?
+ .get_encoded();
+ } else if outbound.is_empty() {
+ // Each Aggregator recovered an output share.
+ break;
+ } else {
+ panic!("Aggregators did not finish the prepare phase at the same time");
+ }
+ }
+
+ Ok(out_shares)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{AggregateShare, OutputShare};
+ use crate::field::{Field128, Field64, FieldElement};
+ use itertools::iterate;
+ use std::convert::TryFrom;
+ use std::fmt::Debug;
+
+ fn fieldvec_roundtrip_test<F, T>()
+ where
+ F: FieldElement,
+ for<'a> T: Debug + PartialEq + From<Vec<F>> + TryFrom<&'a [u8]>,
+ for<'a> <T as TryFrom<&'a [u8]>>::Error: Debug,
+ for<'a> Vec<u8>: From<&'a T>,
+ {
+ // Generate a value based on an arbitrary vector of field elements.
+ let g = F::generator();
+ let want_value = T::from(iterate(F::one(), |&v| g * v).take(10).collect());
+
+ // Round-trip the value through a byte-vector.
+ let buf: Vec<u8> = (&want_value).into();
+ let got_value = T::try_from(&buf).unwrap();
+
+ assert_eq!(want_value, got_value);
+ }
+
+ #[test]
+ fn roundtrip_output_share() {
+ fieldvec_roundtrip_test::<Field64, OutputShare<Field64>>();
+ fieldvec_roundtrip_test::<Field128, OutputShare<Field128>>();
+ }
+
+ #[test]
+ fn roundtrip_aggregate_share() {
+ fieldvec_roundtrip_test::<Field64, AggregateShare<Field64>>();
+ fieldvec_roundtrip_test::<Field128, AggregateShare<Field128>>();
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+pub mod poplar1;
+pub mod prg;
+#[cfg(feature = "prio2")]
+pub mod prio2;
+pub mod prio3;
+#[cfg(test)]
+mod prio3_test;
diff --git a/third_party/rust/prio/src/vdaf/poplar1.rs b/third_party/rust/prio/src/vdaf/poplar1.rs
new file mode 100644
index 0000000000..f6ab110ebb
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/poplar1.rs
@@ -0,0 +1,933 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! **(NOTE: This module is experimental. Applications should not use it yet.)** This module
+//! partially implements the core component of the Poplar protocol [[BBCG+21]]. Named for the
+//! Poplar1 section of [[draft-irtf-cfrg-vdaf-03]], the specification of this VDAF is under active
+//! development. Thus this code should be regarded as experimental and not compliant with any
+//! existing speciication.
+//!
+//! TODO Make the input shares stateful so that applications can efficiently evaluate the IDPF over
+//! multiple rounds. Question: Will this require API changes to [`crate::vdaf::Vdaf`]?
+//!
+//! TODO Update trait [`Idpf`] so that the IDPF can have a different field type at the leaves than
+//! at the inner nodes.
+//!
+//! TODO Implement the efficient IDPF of [[BBCG+21]]. [`ToyIdpf`] is not space efficient and is
+//! merely intended as a proof-of-concept.
+//!
+//! [BBCG+21]: https://eprint.iacr.org/2021/017
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+use std::cmp::Ordering;
+use std::collections::{BTreeMap, BTreeSet};
+use std::convert::{TryFrom, TryInto};
+use std::fmt::Debug;
+use std::io::Cursor;
+use std::iter::FromIterator;
+use std::marker::PhantomData;
+
+use crate::codec::{
+ decode_u16_items, decode_u24_items, encode_u16_items, encode_u24_items, CodecError, Decode,
+ Encode, ParameterizedDecode,
+};
+use crate::field::{split_vector, FieldElement};
+use crate::fp::log2;
+use crate::prng::Prng;
+use crate::vdaf::prg::{Prg, Seed};
+use crate::vdaf::{
+ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
+ Share, ShareDecodingParameter, Vdaf, VdafError,
+};
+
+/// An input for an IDPF ([`Idpf`]).
+///
+/// TODO Make this an associated type of `Idpf`.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub struct IdpfInput {
+ index: usize,
+ level: usize,
+}
+
+impl IdpfInput {
+ /// Constructs an IDPF input using the first `level` bits of `data`.
+ pub fn new(data: &[u8], level: usize) -> Result<Self, VdafError> {
+ if level > data.len() << 3 {
+ return Err(VdafError::Uncategorized(format!(
+ "desired bit length ({} bits) exceeds data length ({} bytes)",
+ level,
+ data.len()
+ )));
+ }
+
+ let mut index = 0;
+ let mut i = 0;
+ for byte in data {
+ for j in 0..8 {
+ let bit = (byte >> j) & 1;
+ if i < level {
+ index |= (bit as usize) << i;
+ }
+ i += 1;
+ }
+ }
+
+ Ok(Self { index, level })
+ }
+
+ /// Construct a new input that is a prefix of `self`. Bounds checking is performed by the
+ /// caller.
+ fn prefix(&self, level: usize) -> Self {
+ let index = self.index & ((1 << level) - 1);
+ Self { index, level }
+ }
+
+ /// Return the position of `self` in the look-up table of `ToyIdpf`.
+ fn data_index(&self) -> usize {
+ self.index | (1 << self.level)
+ }
+}
+
+impl Ord for IdpfInput {
+ fn cmp(&self, other: &Self) -> Ordering {
+ match self.level.cmp(&other.level) {
+ Ordering::Equal => self.index.cmp(&other.index),
+ ord => ord,
+ }
+ }
+}
+
+impl PartialOrd for IdpfInput {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Encode for IdpfInput {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ (self.index as u64).encode(bytes);
+ (self.level as u64).encode(bytes);
+ }
+}
+
+impl Decode for IdpfInput {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let index = u64::decode(bytes)? as usize;
+ let level = u64::decode(bytes)? as usize;
+
+ Ok(Self { index, level })
+ }
+}
+
+/// An Incremental Distributed Point Function (IDPF), as defined by [[BBCG+21]].
+///
+/// [BBCG+21]: https://eprint.iacr.org/2021/017
+//
+// NOTE(cjpatton) The real IDPF API probably needs to be stateful.
+pub trait Idpf<const KEY_LEN: usize, const OUT_LEN: usize>:
+ Sized + Clone + Debug + Encode + Decode
+{
+ /// The finite field over which the IDPF is defined.
+ //
+ // NOTE(cjpatton) The IDPF of [BBCG+21] might use different fields for different levels of the
+ // prefix tree.
+ type Field: FieldElement;
+
+ /// Generate and return a sequence of IDPF shares for `input`. Parameter `output` is an
+ /// iterator that is invoked to get the output value for each successive level of the prefix
+ /// tree.
+ fn gen<M: IntoIterator<Item = [Self::Field; OUT_LEN]>>(
+ input: &IdpfInput,
+ values: M,
+ ) -> Result<[Self; KEY_LEN], VdafError>;
+
+ /// Evaluate an IDPF share on `prefix`.
+ fn eval(&self, prefix: &IdpfInput) -> Result<[Self::Field; OUT_LEN], VdafError>;
+}
+
+/// A "toy" IDPF used for demonstration purposes. The space consumed by each share is `O(2^n)`,
+/// where `n` is the length of the input. The size of each share is restricted to 1MB, so this IDPF
+/// is only suitable for very short inputs.
+//
+// NOTE(cjpatton) It would be straight-forward to generalize this construction to any `KEY_LEN` and
+// `OUT_LEN`.
+#[derive(Debug, Clone)]
+pub struct ToyIdpf<F> {
+ data0: Vec<F>,
+ data1: Vec<F>,
+ level: usize,
+}
+
+impl<F: FieldElement> Idpf<2, 2> for ToyIdpf<F> {
+ type Field = F;
+
+ fn gen<M: IntoIterator<Item = [Self::Field; 2]>>(
+ input: &IdpfInput,
+ values: M,
+ ) -> Result<[Self; 2], VdafError> {
+ const MAX_DATA_BYTES: usize = 1024 * 1024; // 1MB
+
+ let max_input_len =
+ usize::try_from(log2((MAX_DATA_BYTES / F::ENCODED_SIZE) as u128)).unwrap();
+ if input.level > max_input_len {
+ return Err(VdafError::Uncategorized(format!(
+ "input length ({}) exceeds maximum of ({})",
+ input.level, max_input_len
+ )));
+ }
+
+ let data_len = 1 << (input.level + 1);
+ let mut data0 = vec![F::zero(); data_len];
+ let mut data1 = vec![F::zero(); data_len];
+ let mut values = values.into_iter();
+ for level in 0..input.level + 1 {
+ let value = values.next().unwrap();
+ let index = input.prefix(level).data_index();
+ data0[index] = value[0];
+ data1[index] = value[1];
+ }
+
+ let mut data0 = split_vector(&data0, 2)?.into_iter();
+ let mut data1 = split_vector(&data1, 2)?.into_iter();
+ Ok([
+ ToyIdpf {
+ data0: data0.next().unwrap(),
+ data1: data1.next().unwrap(),
+ level: input.level,
+ },
+ ToyIdpf {
+ data0: data0.next().unwrap(),
+ data1: data1.next().unwrap(),
+ level: input.level,
+ },
+ ])
+ }
+
+ fn eval(&self, prefix: &IdpfInput) -> Result<[F; 2], VdafError> {
+ if prefix.level > self.level {
+ return Err(VdafError::Uncategorized(format!(
+ "prefix length ({}) exceeds input length ({})",
+ prefix.level, self.level
+ )));
+ }
+
+ let index = prefix.data_index();
+ Ok([self.data0[index], self.data1[index]])
+ }
+}
+
+impl<F: FieldElement> Encode for ToyIdpf<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ encode_u24_items(bytes, &(), &self.data0);
+ encode_u24_items(bytes, &(), &self.data1);
+ (self.level as u64).encode(bytes);
+ }
+}
+
+impl<F: FieldElement> Decode for ToyIdpf<F> {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let data0 = decode_u24_items(&(), bytes)?;
+ let data1 = decode_u24_items(&(), bytes)?;
+ let level = u64::decode(bytes)? as usize;
+
+ Ok(Self {
+ data0,
+ data1,
+ level,
+ })
+ }
+}
+
+impl Encode for BTreeSet<IdpfInput> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Encodes the aggregation parameter as a variable length vector of
+ // [`IdpfInput`], because the size of the aggregation parameter is not
+ // determined by the VDAF.
+ let items: Vec<IdpfInput> = self.iter().map(IdpfInput::clone).collect();
+ encode_u24_items(bytes, &(), &items);
+ }
+}
+
+impl Decode for BTreeSet<IdpfInput> {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let inputs = decode_u24_items(&(), bytes)?;
+ Ok(Self::from_iter(inputs.into_iter()))
+ }
+}
+
+/// An input share for the `poplar1` VDAF.
+#[derive(Debug, Clone)]
+pub struct Poplar1InputShare<I: Idpf<2, 2>, const L: usize> {
+ /// IDPF share of input
+ idpf: I,
+
+ /// PRNG seed used to generate the aggregator's share of the randomness used in the first part
+ /// of the sketching protocol.
+ sketch_start_seed: Seed<L>,
+
+ /// Aggregator's share of the randomness used in the second part of the sketching protocol.
+ sketch_next: Share<I::Field, L>,
+}
+
+impl<I: Idpf<2, 2>, const L: usize> Encode for Poplar1InputShare<I, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.idpf.encode(bytes);
+ self.sketch_start_seed.encode(bytes);
+ self.sketch_next.encode(bytes);
+ }
+}
+
+impl<'a, I, P, const L: usize> ParameterizedDecode<(&'a Poplar1<I, P, L>, usize)>
+ for Poplar1InputShare<I, L>
+where
+ I: Idpf<2, 2>,
+{
+ fn decode_with_param(
+ (poplar1, agg_id): &(&'a Poplar1<I, P, L>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let idpf = I::decode(bytes)?;
+ let sketch_start_seed = Seed::decode(bytes)?;
+ let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
+
+ let share_decoding_parameter = if is_leader {
+ // The sketch is two field elements for every bit of input, plus two more, corresponding
+ // to construction of shares in `Poplar1::shard`.
+ ShareDecodingParameter::Leader((poplar1.input_length + 1) * 2)
+ } else {
+ ShareDecodingParameter::Helper
+ };
+
+ let sketch_next =
+ <Share<I::Field, L>>::decode_with_param(&share_decoding_parameter, bytes)?;
+
+ Ok(Self {
+ idpf,
+ sketch_start_seed,
+ sketch_next,
+ })
+ }
+}
+
+/// The poplar1 VDAF.
+#[derive(Debug)]
+pub struct Poplar1<I, P, const L: usize> {
+ input_length: usize,
+ phantom: PhantomData<(I, P)>,
+}
+
+impl<I, P, const L: usize> Poplar1<I, P, L> {
+ /// Create an instance of the poplar1 VDAF. The caller provides a cipher suite `suite` used for
+ /// deriving pseudorandom sequences of field elements, and a input length in bits, corresponding
+ /// to `BITS` as defined in the [VDAF specification][1].
+ ///
+ /// [1]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+ pub fn new(bits: usize) -> Self {
+ Self {
+ input_length: bits,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<I, P, const L: usize> Clone for Poplar1<I, P, L> {
+ fn clone(&self) -> Self {
+ Self::new(self.input_length)
+ }
+}
+impl<I, P, const L: usize> Vdaf for Poplar1<I, P, L>
+where
+ I: Idpf<2, 2>,
+ P: Prg<L>,
+{
+ // TODO: This currently uses a codepoint reserved for testing purposes. Replace it with
+ // 0x00001000 once the implementation is updated to match draft-irtf-cfrg-vdaf-03.
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = IdpfInput;
+ type AggregateResult = BTreeMap<IdpfInput, u64>;
+ type AggregationParam = BTreeSet<IdpfInput>;
+ type PublicShare = (); // TODO: Replace this when the IDPF from [BBCGGI21] is implemented.
+ type InputShare = Poplar1InputShare<I, L>;
+ type OutputShare = OutputShare<I::Field>;
+ type AggregateShare = AggregateShare<I::Field>;
+
+ fn num_aggregators(&self) -> usize {
+ 2
+ }
+}
+
+impl<I, P, const L: usize> Client for Poplar1<I, P, L>
+where
+ I: Idpf<2, 2>,
+ P: Prg<L>,
+{
+ #[allow(clippy::many_single_char_names)]
+ fn shard(&self, input: &IdpfInput) -> Result<((), Vec<Poplar1InputShare<I, L>>), VdafError> {
+ let idpf_values: Vec<[I::Field; 2]> = Prng::new()?
+ .take(input.level + 1)
+ .map(|k| [I::Field::one(), k])
+ .collect();
+
+ // For each level of the prefix tree, generate correlated randomness that the aggregators use
+ // to validate the output. See [BBCG+21, Appendix C.4].
+ let leader_sketch_start_seed = Seed::generate()?;
+ let helper_sketch_start_seed = Seed::generate()?;
+ let helper_sketch_next_seed = Seed::generate()?;
+ let mut leader_sketch_start_prng: Prng<I::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&leader_sketch_start_seed, b""));
+ let mut helper_sketch_start_prng: Prng<I::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper_sketch_start_seed, b""));
+ let mut helper_sketch_next_prng: Prng<I::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper_sketch_next_seed, b""));
+ let mut leader_sketch_next: Vec<I::Field> = Vec::with_capacity(2 * idpf_values.len());
+ for value in idpf_values.iter() {
+ let k = value[1];
+
+ // [BBCG+21, Appendix C.4]
+ //
+ // $(a, b, c)$
+ let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
+ let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
+ let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
+
+ // $A = -2a + k$
+ // $B = a^2 + b + -ak + c$
+ let d = k - (a + a);
+ let e = (a * a) + b - (a * k) + c;
+ leader_sketch_next.push(d - helper_sketch_next_prng.get());
+ leader_sketch_next.push(e - helper_sketch_next_prng.get());
+ }
+
+ // Generate IDPF shares of the data and authentication vectors.
+ let idpf_shares = I::gen(input, idpf_values)?;
+
+ Ok((
+ (),
+ vec![
+ Poplar1InputShare {
+ idpf: idpf_shares[0].clone(),
+ sketch_start_seed: leader_sketch_start_seed,
+ sketch_next: Share::Leader(leader_sketch_next),
+ },
+ Poplar1InputShare {
+ idpf: idpf_shares[1].clone(),
+ sketch_start_seed: helper_sketch_start_seed,
+ sketch_next: Share::Helper(helper_sketch_next_seed),
+ },
+ ],
+ ))
+ }
+}
+
+fn get_level(agg_param: &BTreeSet<IdpfInput>) -> Result<usize, VdafError> {
+ let mut level = None;
+ for prefix in agg_param {
+ if let Some(l) = level {
+ if prefix.level != l {
+ return Err(VdafError::Uncategorized(
+ "prefixes must all have the same length".to_string(),
+ ));
+ }
+ } else {
+ level = Some(prefix.level);
+ }
+ }
+
+ match level {
+ Some(level) => Ok(level),
+ None => Err(VdafError::Uncategorized("prefix set is empty".to_string())),
+ }
+}
+
+impl<I, P, const L: usize> Aggregator<L> for Poplar1<I, P, L>
+where
+ I: Idpf<2, 2>,
+ P: Prg<L>,
+{
+ type PrepareState = Poplar1PrepareState<I::Field>;
+ type PrepareShare = Poplar1PrepareMessage<I::Field>;
+ type PrepareMessage = Poplar1PrepareMessage<I::Field>;
+
+ #[allow(clippy::type_complexity)]
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; L],
+ agg_id: usize,
+ agg_param: &BTreeSet<IdpfInput>,
+ nonce: &[u8],
+ _public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<
+ (
+ Poplar1PrepareState<I::Field>,
+ Poplar1PrepareMessage<I::Field>,
+ ),
+ VdafError,
+ > {
+ let level = get_level(agg_param)?;
+ let is_leader = role_try_from(agg_id)?;
+
+ // Derive the verification randomness.
+ let mut p = P::init(verify_key);
+ p.update(nonce);
+ let mut verify_rand_prng: Prng<I::Field, _> = Prng::from_seed_stream(p.into_seed_stream());
+
+ // Evaluate the IDPF shares and compute the polynomial coefficients.
+ let mut z = [I::Field::zero(); 3];
+ let mut output_share = Vec::with_capacity(agg_param.len());
+ for prefix in agg_param.iter() {
+ let value = input_share.idpf.eval(prefix)?;
+ let (v, k) = (value[0], value[1]);
+ let r = verify_rand_prng.get();
+
+ // [BBCG+21, Appendix C.4]
+ //
+ // $(z_\sigma, z^*_\sigma, z^{**}_\sigma)$
+ let tmp = r * v;
+ z[0] += tmp;
+ z[1] += r * tmp;
+ z[2] += r * k;
+ output_share.push(v);
+ }
+
+ // [BBCG+21, Appendix C.4]
+ //
+ // Add blind shares $(a_\sigma b_\sigma, c_\sigma)$
+ //
+ // NOTE(cjpatton) We can make this faster by a factor of 3 by using three seed shares instead
+ // of one. On the other hand, if the input shares are made stateful, then we could store
+ // the PRNG state theire and avoid fast-forwarding.
+ let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(
+ &input_share.sketch_start_seed,
+ b"",
+ ))
+ .skip(3 * level);
+ z[0] += prng.next().unwrap();
+ z[1] += prng.next().unwrap();
+ z[2] += prng.next().unwrap();
+
+ let (d, e) = match &input_share.sketch_next {
+ Share::Leader(data) => (data[2 * level], data[2 * level + 1]),
+ Share::Helper(seed) => {
+ let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(seed, b""))
+ .skip(2 * level);
+ (prng.next().unwrap(), prng.next().unwrap())
+ }
+ };
+
+ let x = if is_leader {
+ I::Field::one()
+ } else {
+ I::Field::zero()
+ };
+
+ Ok((
+ Poplar1PrepareState {
+ sketch: SketchState::RoundOne,
+ output_share: OutputShare(output_share),
+ d,
+ e,
+ x,
+ },
+ Poplar1PrepareMessage(z.to_vec()),
+ ))
+ }
+
+ fn prepare_preprocess<M: IntoIterator<Item = Poplar1PrepareMessage<I::Field>>>(
+ &self,
+ inputs: M,
+ ) -> Result<Poplar1PrepareMessage<I::Field>, VdafError> {
+ let mut output: Option<Vec<I::Field>> = None;
+ let mut count = 0;
+ for data_share in inputs.into_iter() {
+ count += 1;
+ if let Some(ref mut data) = output {
+ if data_share.0.len() != data.len() {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message length: got {}; want {}",
+ data_share.0.len(),
+ data.len(),
+ )));
+ }
+
+ for (x, y) in data.iter_mut().zip(data_share.0.iter()) {
+ *x += *y;
+ }
+ } else {
+ output = Some(data_share.0);
+ }
+ }
+
+ if count != 2 {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message count: got {}; want 2",
+ count,
+ )));
+ }
+
+ Ok(Poplar1PrepareMessage(output.unwrap()))
+ }
+
+ fn prepare_step(
+ &self,
+ mut state: Poplar1PrepareState<I::Field>,
+ msg: Poplar1PrepareMessage<I::Field>,
+ ) -> Result<PrepareTransition<Self, L>, VdafError> {
+ match &state.sketch {
+ SketchState::RoundOne => {
+ if msg.0.len() != 3 {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message length ({:?}): got {}; want 3",
+ state.sketch,
+ msg.0.len(),
+ )));
+ }
+
+ // Compute polynomial coefficients.
+ let z: [I::Field; 3] = msg.0.try_into().unwrap();
+ let y_share =
+ vec![(state.d * z[0]) + state.e + state.x * ((z[0] * z[0]) - z[1] - z[2])];
+
+ state.sketch = SketchState::RoundTwo;
+ Ok(PrepareTransition::Continue(
+ state,
+ Poplar1PrepareMessage(y_share),
+ ))
+ }
+
+ SketchState::RoundTwo => {
+ if msg.0.len() != 1 {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message length ({:?}): got {}; want 1",
+ state.sketch,
+ msg.0.len(),
+ )));
+ }
+
+ let y = msg.0[0];
+ if y != I::Field::zero() {
+ return Err(VdafError::Uncategorized(format!(
+ "output is invalid: polynomial evaluated to {}; want {}",
+ y,
+ I::Field::zero(),
+ )));
+ }
+
+ Ok(PrepareTransition::Finish(state.output_share))
+ }
+ }
+ }
+
+ fn aggregate<M: IntoIterator<Item = OutputShare<I::Field>>>(
+ &self,
+ agg_param: &BTreeSet<IdpfInput>,
+ output_shares: M,
+ ) -> Result<AggregateShare<I::Field>, VdafError> {
+ let mut agg_share = AggregateShare(vec![I::Field::zero(); agg_param.len()]);
+ for output_share in output_shares.into_iter() {
+ agg_share.accumulate(&output_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+/// A prepare message sent exchanged between Poplar1 aggregators
+#[derive(Clone, Debug)]
+pub struct Poplar1PrepareMessage<F>(Vec<F>);
+
+impl<F> AsRef<[F]> for Poplar1PrepareMessage<F> {
+ fn as_ref(&self) -> &[F] {
+ &self.0
+ }
+}
+
+impl<F: FieldElement> Encode for Poplar1PrepareMessage<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // TODO: This is encoded as a variable length vector of F, but we may
+ // be able to make this a fixed-length vector for specific Poplar1
+ // instantations
+ encode_u16_items(bytes, &(), &self.0);
+ }
+}
+
+impl<F: FieldElement> ParameterizedDecode<Poplar1PrepareState<F>> for Poplar1PrepareMessage<F> {
+ fn decode_with_param(
+ _decoding_parameter: &Poplar1PrepareState<F>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ // TODO: This is decoded as a variable length vector of F, but we may be
+ // able to make this a fixed-length vector for specific Poplar1
+ // instantiations.
+ let items = decode_u16_items(&(), bytes)?;
+
+ Ok(Self(items))
+ }
+}
+
+/// The state of each Aggregator during the Prepare process.
+#[derive(Clone, Debug)]
+pub struct Poplar1PrepareState<F> {
+ /// State of the secure sketching protocol.
+ sketch: SketchState,
+
+ /// The output share.
+ output_share: OutputShare<F>,
+
+ /// Aggregator's share of $A = -2a + k$.
+ d: F,
+
+ /// Aggregator's share of $B = a^2 + b -ak + c$.
+ e: F,
+
+ /// Equal to 1 if this Aggregator is the "leader" and 0 otherwise.
+ x: F,
+}
+
+#[derive(Clone, Debug)]
+enum SketchState {
+ RoundOne,
+ RoundTwo,
+}
+
+impl<I, P, const L: usize> Collector for Poplar1<I, P, L>
+where
+ I: Idpf<2, 2>,
+ P: Prg<L>,
+{
+ fn unshard<M: IntoIterator<Item = AggregateShare<I::Field>>>(
+ &self,
+ agg_param: &BTreeSet<IdpfInput>,
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<BTreeMap<IdpfInput, u64>, VdafError> {
+ let mut agg_data = AggregateShare(vec![I::Field::zero(); agg_param.len()]);
+ for agg_share in agg_shares.into_iter() {
+ agg_data.merge(&agg_share)?;
+ }
+
+ let mut agg = BTreeMap::new();
+ for (prefix, count) in agg_param.iter().zip(agg_data.as_ref()) {
+ let count = <I::Field as FieldElement>::Integer::from(*count);
+ let count: u64 = count
+ .try_into()
+ .map_err(|_| VdafError::Uncategorized("aggregate overflow".to_string()))?;
+ agg.insert(*prefix, count);
+ }
+ Ok(agg)
+ }
+}
+
+fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
+ match agg_id {
+ 0 => Ok(true),
+ 1 => Ok(false),
+ _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::field::Field128;
+ use crate::vdaf::prg::PrgAes128;
+ use crate::vdaf::{run_vdaf, run_vdaf_prepare};
+ use rand::prelude::*;
+
+ #[test]
+ fn test_idpf() {
+ // IDPF input equality tests.
+ assert_eq!(
+ IdpfInput::new(b"hello", 40).unwrap(),
+ IdpfInput::new(b"hello", 40).unwrap()
+ );
+ assert_eq!(
+ IdpfInput::new(b"hi", 9).unwrap(),
+ IdpfInput::new(b"ha", 9).unwrap(),
+ );
+ assert_eq!(
+ IdpfInput::new(b"hello", 25).unwrap(),
+ IdpfInput::new(b"help", 25).unwrap()
+ );
+ assert_ne!(
+ IdpfInput::new(b"hello", 40).unwrap(),
+ IdpfInput::new(b"hello", 39).unwrap()
+ );
+ assert_ne!(
+ IdpfInput::new(b"hello", 40).unwrap(),
+ IdpfInput::new(b"hell-", 40).unwrap()
+ );
+
+ // IDPF uniqueness tests
+ let mut unique = BTreeSet::new();
+ assert!(unique.insert(IdpfInput::new(b"hello", 40).unwrap()));
+ assert!(!unique.insert(IdpfInput::new(b"hello", 40).unwrap()));
+ assert!(unique.insert(IdpfInput::new(b"hello", 39).unwrap()));
+ assert!(unique.insert(IdpfInput::new(b"bye", 20).unwrap()));
+
+ // Generate IDPF keys.
+ let input = IdpfInput::new(b"hi", 16).unwrap();
+ let keys = ToyIdpf::<Field128>::gen(
+ &input,
+ std::iter::repeat([Field128::one(), Field128::one()]),
+ )
+ .unwrap();
+
+ // Try evaluating the IDPF keys on all prefixes.
+ for prefix_len in 0..input.level + 1 {
+ let res = eval_idpf(
+ &keys,
+ &input.prefix(prefix_len),
+ &[Field128::one(), Field128::one()],
+ );
+ assert!(res.is_ok(), "prefix_len={} error: {:?}", prefix_len, res);
+ }
+
+ // Try evaluating the IDPF keys on incorrect prefixes.
+ eval_idpf(
+ &keys,
+ &IdpfInput::new(&[2], 2).unwrap(),
+ &[Field128::zero(), Field128::zero()],
+ )
+ .unwrap();
+
+ eval_idpf(
+ &keys,
+ &IdpfInput::new(&[23, 1], 12).unwrap(),
+ &[Field128::zero(), Field128::zero()],
+ )
+ .unwrap();
+ }
+
+ fn eval_idpf<I, const KEY_LEN: usize, const OUT_LEN: usize>(
+ keys: &[I; KEY_LEN],
+ input: &IdpfInput,
+ expected_output: &[I::Field; OUT_LEN],
+ ) -> Result<(), VdafError>
+ where
+ I: Idpf<KEY_LEN, OUT_LEN>,
+ {
+ let mut output = [I::Field::zero(); OUT_LEN];
+ for key in keys {
+ let output_share = key.eval(input)?;
+ for (x, y) in output.iter_mut().zip(output_share) {
+ *x += y;
+ }
+ }
+
+ if expected_output != &output {
+ return Err(VdafError::Uncategorized(format!(
+ "eval_idpf(): unexpected output: got {:?}; want {:?}",
+ output, expected_output
+ )));
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_poplar1() {
+ const INPUT_LEN: usize = 8;
+
+ let vdaf: Poplar1<ToyIdpf<Field128>, PrgAes128, 16> = Poplar1::new(INPUT_LEN);
+ assert_eq!(vdaf.num_aggregators(), 2);
+
+ // Run the VDAF input-distribution algorithm.
+ let input = vec![IdpfInput::new(&[0b0110_1000], INPUT_LEN).unwrap()];
+
+ let mut agg_param = BTreeSet::new();
+ agg_param.insert(input[0]);
+ check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]);
+
+ // Try evaluating the VDAF on each prefix of the input.
+ for prefix_len in 0..input[0].level + 1 {
+ let mut agg_param = BTreeSet::new();
+ agg_param.insert(input[0].prefix(prefix_len));
+ check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]);
+ }
+
+ // Try various prefixes.
+ let prefix_len = 4;
+ let mut agg_param = BTreeSet::new();
+ // At length 4, the next two prefixes are equal. Neither one matches the input.
+ agg_param.insert(IdpfInput::new(&[0b0000_0000], prefix_len).unwrap());
+ agg_param.insert(IdpfInput::new(&[0b0001_0000], prefix_len).unwrap());
+ agg_param.insert(IdpfInput::new(&[0b0000_0001], prefix_len).unwrap());
+ agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap());
+ // At length 4, the next two prefixes are equal. Both match the input.
+ agg_param.insert(IdpfInput::new(&[0b0111_1101], prefix_len).unwrap());
+ agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap());
+ let aggregate = run_vdaf(&vdaf, &agg_param, input.clone()).unwrap();
+ assert_eq!(aggregate.len(), agg_param.len());
+ check_btree(
+ &aggregate,
+ // We put six prefixes in the aggregation parameter, but the vector we get back is only
+ // 4 elements because at the given prefix length, some of the prefixes are equal.
+ &[0, 0, 0, 1],
+ );
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = b"this is a nonce";
+
+ // Try evaluating the VDAF with an invalid aggregation parameter. (It's an error to have a
+ // mixture of prefix lengths.)
+ let mut agg_param = BTreeSet::new();
+ agg_param.insert(IdpfInput::new(&[0b0000_0111], 6).unwrap());
+ agg_param.insert(IdpfInput::new(&[0b0000_1000], 7).unwrap());
+ let (public_share, input_shares) = vdaf.shard(&input[0]).unwrap();
+ run_vdaf_prepare(
+ &vdaf,
+ &verify_key,
+ &agg_param,
+ nonce,
+ public_share,
+ input_shares,
+ )
+ .unwrap_err();
+
+ // Try evaluating the VDAF with malformed inputs.
+ //
+ // This IDPF key pair evaluates to 1 everywhere, which is illegal.
+ let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap();
+ for (i, x) in input_shares[0].idpf.data0.iter_mut().enumerate() {
+ if i != input[0].index {
+ *x += Field128::one();
+ }
+ }
+ let mut agg_param = BTreeSet::new();
+ agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap());
+ run_vdaf_prepare(
+ &vdaf,
+ &verify_key,
+ &agg_param,
+ nonce,
+ public_share,
+ input_shares,
+ )
+ .unwrap_err();
+
+ // This IDPF key pair has a garbled authentication vector.
+ let (public_share, mut input_shares) = vdaf.shard(&input[0]).unwrap();
+ for x in input_shares[0].idpf.data1.iter_mut() {
+ *x = Field128::zero();
+ }
+ let mut agg_param = BTreeSet::new();
+ agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap());
+ run_vdaf_prepare(
+ &vdaf,
+ &verify_key,
+ &agg_param,
+ nonce,
+ public_share,
+ input_shares,
+ )
+ .unwrap_err();
+ }
+
+ fn check_btree(btree: &BTreeMap<IdpfInput, u64>, counts: &[u64]) {
+ for (got, want) in btree.values().zip(counts.iter()) {
+ assert_eq!(got, want, "got {:?} want {:?}", btree.values(), counts);
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prg.rs b/third_party/rust/prio/src/vdaf/prg.rs
new file mode 100644
index 0000000000..a5930f1283
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prg.rs
@@ -0,0 +1,239 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementations of PRGs specified in [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+use crate::vdaf::{CodecError, Decode, Encode};
+#[cfg(feature = "crypto-dependencies")]
+use aes::{
+ cipher::{KeyIvInit, StreamCipher},
+ Aes128,
+};
+#[cfg(feature = "crypto-dependencies")]
+use cmac::{Cmac, Mac};
+#[cfg(feature = "crypto-dependencies")]
+use ctr::Ctr64BE;
+#[cfg(feature = "crypto-dependencies")]
+use std::fmt::Formatter;
+use std::{
+ fmt::Debug,
+ io::{Cursor, Read},
+};
+
+/// Function pointer to fill a buffer with random bytes. Under normal operation,
+/// `getrandom::getrandom()` will be used, but other implementations can be used to control
+/// randomness when generating or verifying test vectors.
+pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>;
+
+/// Input of [`Prg`].
+#[derive(Clone, Debug, Eq)]
+pub struct Seed<const L: usize>(pub(crate) [u8; L]);
+
+impl<const L: usize> Seed<L> {
+ /// Generate a uniform random seed.
+ pub fn generate() -> Result<Self, getrandom::Error> {
+ Self::from_rand_source(getrandom::getrandom)
+ }
+
+ pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> {
+ let mut seed = [0; L];
+ rand_source(&mut seed)?;
+ Ok(Self(seed))
+ }
+}
+
+impl<const L: usize> AsRef<[u8; L]> for Seed<L> {
+ fn as_ref(&self) -> &[u8; L] {
+ &self.0
+ }
+}
+
+impl<const L: usize> PartialEq for Seed<L> {
+ fn eq(&self, other: &Self) -> bool {
+ // Do constant-time compare.
+ let mut r = 0;
+ for (x, y) in self.0[..].iter().zip(&other.0[..]) {
+ r |= x ^ y;
+ }
+ r == 0
+ }
+}
+
+impl<const L: usize> Encode for Seed<L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&self.0[..]);
+ }
+}
+
+impl<const L: usize> Decode for Seed<L> {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut seed = [0; L];
+ bytes.read_exact(&mut seed)?;
+ Ok(Seed(seed))
+ }
+}
+
+/// A stream of pseudorandom bytes derived from a seed.
+pub trait SeedStream {
+ /// Fill `buf` with the next `buf.len()` bytes of output.
+ fn fill(&mut self, buf: &mut [u8]);
+}
+
+/// A pseudorandom generator (PRG) with the interface specified in [[draft-irtf-cfrg-vdaf-03]].
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+pub trait Prg<const L: usize>: Clone + Debug {
+ /// The type of stream produced by this PRG.
+ type SeedStream: SeedStream;
+
+ /// Construct an instance of [`Prg`] with the given seed.
+ fn init(seed_bytes: &[u8; L]) -> Self;
+
+ /// Update the PRG state by passing in the next fragment of the info string. The final info
+ /// string is assembled from the concatenation of sequence of fragments passed to this method.
+ fn update(&mut self, data: &[u8]);
+
+ /// Finalize the PRG state, producing a seed stream.
+ fn into_seed_stream(self) -> Self::SeedStream;
+
+ /// Finalize the PRG state, producing a seed.
+ fn into_seed(self) -> Seed<L> {
+ let mut new_seed = [0; L];
+ let mut seed_stream = self.into_seed_stream();
+ seed_stream.fill(&mut new_seed);
+ Seed(new_seed)
+ }
+
+ /// Construct a seed stream from the given seed and info string.
+ fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream {
+ let mut prg = Self::init(seed.as_ref());
+ prg.update(info);
+ prg.into_seed_stream()
+ }
+}
+
+/// The PRG based on AES128 as specified in [[draft-irtf-cfrg-vdaf-03]].
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+#[derive(Clone, Debug)]
+#[cfg(feature = "crypto-dependencies")]
+pub struct PrgAes128(Cmac<Aes128>);
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prg<16> for PrgAes128 {
+ type SeedStream = SeedStreamAes128;
+
+ fn init(seed_bytes: &[u8; 16]) -> Self {
+ Self(Cmac::new_from_slice(seed_bytes).unwrap())
+ }
+
+ fn update(&mut self, data: &[u8]) {
+ self.0.update(data);
+ }
+
+ fn into_seed_stream(self) -> SeedStreamAes128 {
+ let key = self.0.finalize().into_bytes();
+ SeedStreamAes128::new(&key, &[0; 16])
+ }
+}
+
+/// The key stream produced by AES128 in CTR-mode.
+#[cfg(feature = "crypto-dependencies")]
+pub struct SeedStreamAes128(Ctr64BE<Aes128>);
+
+#[cfg(feature = "crypto-dependencies")]
+impl SeedStreamAes128 {
+ pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self {
+ SeedStreamAes128(Ctr64BE::<Aes128>::new(key.into(), iv.into()))
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl SeedStream for SeedStreamAes128 {
+ fn fill(&mut self, buf: &mut [u8]) {
+ buf.fill(0);
+ self.0.apply_keystream(buf);
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl Debug for SeedStreamAes128 {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that
+ // with [`cipher::StreamCipherCoreWrapper::get_core`][2].
+ //
+ // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html
+ // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html
+ self.0.get_core().fmt(f)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{field::Field128, prng::Prng};
+ use serde::{Deserialize, Serialize};
+ use std::convert::TryInto;
+
+ #[derive(Deserialize, Serialize)]
+ struct PrgTestVector {
+ #[serde(with = "hex")]
+ seed: Vec<u8>,
+ #[serde(with = "hex")]
+ info: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex")]
+ derived_seed: Vec<u8>,
+ #[serde(with = "hex")]
+ expanded_vec_field128: Vec<u8>,
+ }
+
+ // Test correctness of dervied methods.
+ fn test_prg<P, const L: usize>()
+ where
+ P: Prg<L>,
+ {
+ let seed = Seed::generate().unwrap();
+ let info = b"info string";
+
+ let mut prg = P::init(seed.as_ref());
+ prg.update(info);
+
+ let mut want = Seed([0; L]);
+ prg.clone().into_seed_stream().fill(&mut want.0[..]);
+ let got = prg.clone().into_seed();
+ assert_eq!(got, want);
+
+ let mut want = [0; 45];
+ prg.clone().into_seed_stream().fill(&mut want);
+ let mut got = [0; 45];
+ P::seed_stream(&seed, info).fill(&mut got);
+ assert_eq!(got, want);
+ }
+
+ #[test]
+ fn prg_aes128() {
+ let t: PrgTestVector =
+ serde_json::from_str(include_str!("test_vec/03/PrgAes128.json")).unwrap();
+ let mut prg = PrgAes128::init(&t.seed.try_into().unwrap());
+ prg.update(&t.info);
+
+ assert_eq!(
+ prg.clone().into_seed(),
+ Seed(t.derived_seed.try_into().unwrap())
+ );
+
+ let mut bytes = std::io::Cursor::new(t.expanded_vec_field128.as_slice());
+ let mut want = Vec::with_capacity(t.length);
+ while (bytes.position() as usize) < t.expanded_vec_field128.len() {
+ want.push(Field128::decode(&mut bytes).unwrap())
+ }
+ let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream())
+ .take(t.length)
+ .collect();
+ assert_eq!(got, want);
+
+ test_prg::<PrgAes128, 16>();
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs
new file mode 100644
index 0000000000..47fc076790
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2.rs
@@ -0,0 +1,425 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with
+//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server).
+
+use crate::{
+ client as v2_client,
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{FieldElement, FieldPrio2},
+ prng::Prng,
+ server as v2_server,
+ util::proof_length,
+ vdaf::{
+ prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
+ PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
+ },
+};
+use ring::hmac;
+use std::{
+ convert::{TryFrom, TryInto},
+ io::Cursor,
+};
+
+/// The Prio2 VDAF. It supports the same measurement type as
+/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system
+/// and finite field deployed in ENPA.
+#[derive(Clone, Debug)]
+pub struct Prio2 {
+ input_len: usize,
+}
+
+impl Prio2 {
+ /// Returns an instance of the VDAF for the given input length.
+ pub fn new(input_len: usize) -> Result<Self, VdafError> {
+ let n = (input_len + 1).next_power_of_two();
+ if let Ok(size) = u32::try_from(2 * n) {
+ if size > FieldPrio2::generator_order() {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+ } else {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds memory capacity".into(),
+ ));
+ }
+
+ Ok(Prio2 { input_len })
+ }
+
+ /// Prepare an input share for aggregation using the given field element `query_rand` to
+ /// compute the verifier share.
+ ///
+ /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the
+ /// query randomness is computed jointly by the Aggregators. This method is designed to be used
+ /// in applications, like ENPA, in which the query randomness is instead chosen by a
+ /// third-party.
+ pub fn prepare_init_with_query_rand(
+ &self,
+ query_rand: FieldPrio2,
+ input_share: &Share<FieldPrio2, 32>,
+ is_leader: bool,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let expanded_data: Option<Vec<FieldPrio2>> = match input_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ Some(prng.take(proof_length(self.input_len)).collect())
+ }
+ };
+ let data = match input_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_data.as_ref().unwrap(),
+ };
+
+ let mut mem = v2_server::ValidationMemory::new(self.input_len);
+ let verifier_share = v2_server::generate_verification_message(
+ self.input_len,
+ query_rand,
+ data, // Combined input and proof shares
+ is_leader,
+ &mut mem,
+ )
+ .map_err(|e| VdafError::Uncategorized(e.to_string()))?;
+
+ Ok((
+ Prio2PrepareState(input_share.truncated(self.input_len)),
+ Prio2PrepareShare(verifier_share),
+ ))
+ }
+}
+
+impl Vdaf for Prio2 {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<u32>;
+ type AggregateResult = Vec<u32>;
+ type AggregationParam = ();
+ type PublicShare = ();
+ type InputShare = Share<FieldPrio2, 32>;
+ type OutputShare = OutputShare<FieldPrio2>;
+ type AggregateShare = AggregateShare<FieldPrio2>;
+
+ fn num_aggregators(&self) -> usize {
+ // Prio2 can easily be extended to support more than two Aggregators.
+ 2
+ }
+}
+
+impl Client for Prio2 {
+ fn shard(&self, measurement: &Vec<u32>) -> Result<((), Vec<Share<FieldPrio2, 32>>), VdafError> {
+ if measurement.len() != self.input_len {
+ return Err(VdafError::Uncategorized("incorrect input length".into()));
+ }
+ let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len());
+ for int in measurement {
+ input.push((*int).into());
+ }
+
+ let mut mem = v2_client::ClientMemory::new(self.input_len)?;
+ let copy_data = |share_data: &mut [FieldPrio2]| {
+ share_data[..].clone_from_slice(&input);
+ };
+ let mut leader_data = mem.prove_with(self.input_len, copy_data);
+
+ let helper_seed = Seed::generate()?;
+ let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref());
+ for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) {
+ *s1 -= d;
+ }
+
+ Ok((
+ (),
+ vec![Share::Leader(leader_data), Share::Helper(helper_seed)],
+ ))
+ }
+}
+
+/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio2PrepareState(Share<FieldPrio2, 32>);
+
+impl Encode for Prio2PrepareState {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes);
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let share_decoder = if *agg_id == 0 {
+ ShareDecodingParameter::Leader(prio2.input_len)
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let out_share = Share::decode_with_param(&share_decoder, bytes)?;
+ Ok(Self(out_share))
+ }
+}
+
+/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug)]
+pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);
+
+impl Encode for Prio2PrepareShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.f_r.encode(bytes);
+ self.0.g_r.encode(bytes);
+ self.0.h_r.encode(bytes);
+ }
+}
+
+impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
+ fn decode_with_param(
+ _state: &Prio2PrepareState,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Ok(Self(v2_server::VerificationMessage {
+ f_r: FieldPrio2::decode(bytes)?,
+ g_r: FieldPrio2::decode(bytes)?,
+ h_r: FieldPrio2::decode(bytes)?,
+ }))
+ }
+}
+
+impl Aggregator<32> for Prio2 {
+ type PrepareState = Prio2PrepareState;
+ type PrepareShare = Prio2PrepareShare;
+ type PrepareMessage = ();
+
+ fn prepare_init(
+ &self,
+ agg_key: &[u8; 32],
+ agg_id: usize,
+ _agg_param: &(),
+ nonce: &[u8],
+ _public_share: &Self::PublicShare,
+ input_share: &Share<FieldPrio2, 32>,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let is_leader = role_try_from(agg_id)?;
+
+ // In the ENPA Prio system, the query randomness is generated by a third party and
+ // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
+ // randomness is derived from a nonce selected by the client. For Prio2 we compute the
+ // query using HMAC-SHA256 evaluated over the nonce.
+ let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key);
+ let hmac_tag = hmac::sign(&hmac_key, nonce);
+ let query_rand = Prng::from_prio2_seed(hmac_tag.as_ref().try_into().unwrap())
+ .next()
+ .unwrap();
+
+ self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
+ }
+
+ fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>(
+ &self,
+ inputs: M,
+ ) -> Result<(), VdafError> {
+ let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
+ inputs.into_iter().map(|msg| msg.0).collect();
+ if verifier_shares.len() != 2 {
+ return Err(VdafError::Uncategorized(
+ "wrong number of verifier shares".into(),
+ ));
+ }
+
+ if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn prepare_step(
+ &self,
+ state: Prio2PrepareState,
+ _input: (),
+ ) -> Result<PrepareTransition<Self, 32>, VdafError> {
+ let data = match state.0 {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ prng.take(self.input_len).collect()
+ }
+ };
+ Ok(PrepareTransition::Finish(OutputShare::from(data)))
+ }
+
+ fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &(),
+ out_shares: M,
+ ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
+ let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for out_share in out_shares.into_iter() {
+ agg_share.accumulate(&out_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+impl Collector for Prio2 {
+ fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &(),
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<Vec<u32>, VdafError> {
+ let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(agg.0.into_iter().map(u32::from).collect())
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
+ let decoder = if is_leader {
+ ShareDecodingParameter::Leader(proof_length(prio2.input_len))
+ } else {
+ ShareDecodingParameter::Helper
+ };
+
+ Share::decode_with_param(&decoder, bytes)
+ }
+}
+
+fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
+ match agg_id {
+ 0 => Ok(true),
+ 1 => Ok(false),
+ _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ client::encode_simple,
+ encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey},
+ field::random_vector,
+ server::Server,
+ vdaf::{run_vdaf, run_vdaf_prepare},
+ };
+ use rand::prelude::*;
+
+ const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
+ const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";
+
+ #[test]
+ fn run_prio2() {
+ let prio2 = Prio2::new(6).unwrap();
+
+ assert_eq!(
+ run_vdaf(
+ &prio2,
+ &(),
+ [
+ vec![0, 0, 0, 0, 1, 0],
+ vec![0, 1, 0, 0, 0, 0],
+ vec![0, 1, 1, 0, 0, 0],
+ vec![1, 1, 1, 0, 0, 0],
+ vec![0, 0, 0, 0, 1, 1],
+ ]
+ )
+ .unwrap(),
+ vec![1, 3, 2, 0, 2, 1],
+ );
+ }
+
+ #[test]
+ fn enpa_client_interop() {
+ let mut rng = thread_rng();
+ let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
+ let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
+ let pub_key1 = PublicKey::from(&priv_key1);
+ let pub_key2 = PublicKey::from(&priv_key2);
+
+ let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0]
+ .iter()
+ .map(|x| FieldPrio2::from(*x))
+ .collect();
+ let (encrypted_input_share1, encrypted_input_share2) =
+ encode_simple(&data, pub_key1, pub_key2).unwrap();
+
+ let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap();
+ let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap();
+
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let input_shares = vec![
+ Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(),
+ Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(),
+ ];
+
+ let verify_key = rng.gen();
+ let mut nonce = [0; 16];
+ rng.fill(&mut nonce);
+ run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, (), input_shares).unwrap();
+ }
+
+ #[test]
+ fn enpa_server_interop() {
+ let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
+ let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
+ let pub_key1 = PublicKey::from(&priv_key1);
+ let pub_key2 = PublicKey::from(&priv_key2);
+
+ let data = vec![0, 0, 1, 1, 0];
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let (_public_share, input_shares) = prio2.shard(&data).unwrap();
+
+ let encrypted_input_share1 =
+ encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap();
+ let encrypted_input_share2 =
+ encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap();
+
+ let mut server1 = Server::new(data.len(), true, priv_key1).unwrap();
+ let mut server2 = Server::new(data.len(), false, priv_key2).unwrap();
+
+ let eval_at: FieldPrio2 = random_vector(1).unwrap()[0];
+ let verifier1 = server1
+ .generate_verification_message(eval_at, &encrypted_input_share1)
+ .unwrap();
+ let verifier2 = server2
+ .generate_verification_message(eval_at, &encrypted_input_share2)
+ .unwrap();
+
+ server1
+ .aggregate(&encrypted_input_share1, &verifier1, &verifier2)
+ .unwrap();
+ server2
+ .aggregate(&encrypted_input_share2, &verifier1, &verifier2)
+ .unwrap();
+ }
+
+ #[test]
+ fn prepare_state_serialization() {
+ let mut verify_key = [0; 32];
+ thread_rng().fill(&mut verify_key[..]);
+ let data = vec![0, 0, 1, 1, 0];
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let (public_share, input_shares) = prio2.shard(&data).unwrap();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (want, _msg) = prio2
+ .prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)
+ .unwrap();
+ let got =
+ Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded())
+ .expect("failed to decode prepare step");
+ assert_eq!(got, want);
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs
new file mode 100644
index 0000000000..31853f15ab
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3.rs
@@ -0,0 +1,1168 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! **WARNING:** Neither this code nor the cryptographic construction it implements has undergone
+//! significant security analysis. Use at your own risk.
+//!
+//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented
+//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
+//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
+//! cost.
+//!
+//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-03]] into
+//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
+//! which are instantiated here:
+//!
+//! - [`Prio3Aes128Count`] for aggregating a counter (*)
+//! - [`Prio3Aes128CountVec`] for aggregating a vector of counters
+//! - [`Prio3Aes128Sum`] for copmputing the sum of integers (*)
+//! - [`Prio3Aes128Histogram`] for estimating a distribution via a histogram (*)
+//!
+//! Additional types can be constructed from [`Prio3`] as needed.
+//!
+//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! [BBCG+19]: https://ia.cr/2019/188
+//! [CGB17]: https://crypto.stanford.edu/prio/
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+#[cfg(feature = "crypto-dependencies")]
+use super::prg::PrgAes128;
+use super::{DST_LEN, VERSION};
+use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
+use crate::field::FieldElement;
+#[cfg(feature = "crypto-dependencies")]
+use crate::field::{Field128, Field64};
+#[cfg(feature = "multithreaded")]
+use crate::flp::gadgets::ParallelSumMultithreaded;
+#[cfg(feature = "crypto-dependencies")]
+use crate::flp::gadgets::{BlindPolyEval, ParallelSum};
+#[cfg(feature = "crypto-dependencies")]
+use crate::flp::types::{Average, Count, CountVec, Histogram, Sum};
+use crate::flp::Type;
+use crate::prng::Prng;
+use crate::vdaf::prg::{Prg, RandSource, Seed};
+use crate::vdaf::{
+ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
+ Share, ShareDecodingParameter, Vdaf, VdafError,
+};
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::io::Cursor;
+use std::iter::IntoIterator;
+use std::marker::PhantomData;
+
+/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Count = Prio3<Count<Field64>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Count {
+ /// Construct an instance of Prio3Aes128Count with the given number of aggregators.
+ pub fn new_aes128_count(num_aggregators: u8) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, Count::new())
+ }
+}
+
+/// The count-vector type. Each measurement is a vector of integers in `[0,2)` and the aggregate is
+/// the element-wise sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128CountVec =
+ Prio3<CountVec<Field128, ParallelSum<Field128, BlindPolyEval<Field128>>>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128CountVec {
+ /// Construct an instance of Prio3Aes1238CountVec with the given number of aggregators. `len`
+ /// defines the length of each measurement.
+ pub fn new_aes128_count_vec(num_aggregators: u8, len: usize) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, CountVec::new(len))
+ }
+}
+
+/// Like [`Prio3Aes128CountVec`] except this type uses multithreading to improve sharding and
+/// preparation time. Note that the improvement is only noticeable for very large input lengths,
+/// e.g., 201 and up. (Your system's mileage may vary.)
+#[cfg(feature = "multithreaded")]
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+pub type Prio3Aes128CountVecMultithreaded = Prio3<
+ CountVec<Field128, ParallelSumMultithreaded<Field128, BlindPolyEval<Field128>>>,
+ PrgAes128,
+ 16,
+>;
+
+#[cfg(feature = "multithreaded")]
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+impl Prio3Aes128CountVecMultithreaded {
+ /// Construct an instance of Prio3Aes1238CountVecMultithreaded with the given number of
+ /// aggregators. `len` defines the length of each measurement.
+ pub fn new_aes128_count_vec_multithreaded(
+ num_aggregators: u8,
+ len: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, CountVec::new(len))
+ }
+}
+
+/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the sum.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Sum = Prio3<Sum<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Sum {
+ /// Construct an instance of Prio3Aes128Sum with the given number of aggregators and required
+ /// bit length. The bit length must not exceed 64.
+ pub fn new_aes128_sum(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({}) exceeds limit for aggregate type (64)",
+ bits
+ )));
+ }
+
+ Prio3::new(num_aggregators, Sum::new(bits as usize)?)
+ }
+}
+
+/// The histogram type. Each measurement is an unsigned integer and the result is a histogram
+/// representation of the distribution. The bucket boundaries are fixed in advance.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Histogram = Prio3<Histogram<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Histogram {
+ /// Constructs an instance of Prio3Aes128Histogram with the given number of aggregators and
+ /// desired histogram bucket boundaries.
+ pub fn new_aes128_histogram(num_aggregators: u8, buckets: &[u64]) -> Result<Self, VdafError> {
+ let buckets = buckets.iter().map(|bucket| *bucket as u128).collect();
+
+ Prio3::new(num_aggregators, Histogram::new(buckets)?)
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and
+/// the aggregate is the arithmetic average.
+#[cfg(feature = "crypto-dependencies")]
+pub type Prio3Aes128Average = Prio3<Average<Field128>, PrgAes128, 16>;
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prio3Aes128Average {
+ /// Construct an instance of Prio3Aes128Average with the given number of aggregators and
+ /// required bit length. The bit length must not exceed 64.
+ pub fn new_aes128_average(num_aggregators: u8, bits: u32) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({}) exceeds limit for aggregate type (64)",
+ bits
+ )));
+ }
+
+ Ok(Prio3 {
+ num_aggregators,
+ typ: Average::new(bits as usize)?,
+ phantom: PhantomData,
+ })
+ }
+}
+
+/// The base type for Prio3.
+///
+/// An instance of Prio3 is determined by:
+///
+/// - a [`Type`](crate::flp::Type) that defines the set of valid input measurements; and
+/// - a [`Prg`](crate::vdaf::prg::Prg) for deriving vectors of field elements from seeds.
+///
+/// New instances can be defined by aliasing the base type. For example, [`Prio3Aes128Count`] is an
+/// alias for `Prio3<Count<Field64>, PrgAes128, 16>`.
+///
+/// ```
+/// use prio::vdaf::{
+/// Aggregator, Client, Collector, PrepareTransition,
+/// prio3::Prio3,
+/// };
+/// use rand::prelude::*;
+///
+/// let num_shares = 2;
+/// let vdaf = Prio3::new_aes128_count(num_shares).unwrap();
+///
+/// let mut out_shares = vec![vec![]; num_shares.into()];
+/// let mut rng = thread_rng();
+/// let verify_key = rng.gen();
+/// let measurements = [0, 1, 1, 1, 0];
+/// for measurement in measurements {
+/// // Shard
+/// let (public_share, input_shares) = vdaf.shard(&measurement).unwrap();
+/// let mut nonce = [0; 16];
+/// rng.fill(&mut nonce);
+///
+/// // Prepare
+/// let mut prep_states = vec![];
+/// let mut prep_shares = vec![];
+/// for (agg_id, input_share) in input_shares.iter().enumerate() {
+/// let (state, share) = vdaf.prepare_init(
+/// &verify_key,
+/// agg_id,
+/// &(),
+/// &nonce,
+/// &public_share,
+/// input_share
+/// ).unwrap();
+/// prep_states.push(state);
+/// prep_shares.push(share);
+/// }
+/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap();
+///
+/// for (agg_id, state) in prep_states.into_iter().enumerate() {
+/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() {
+/// PrepareTransition::Finish(out_share) => out_share,
+/// _ => panic!("unexpected transition"),
+/// };
+/// out_shares[agg_id].push(out_share);
+/// }
+/// }
+///
+/// // Aggregate
+/// let agg_shares = out_shares.into_iter()
+/// .map(|o| vdaf.aggregate(&(), o).unwrap());
+///
+/// // Unshard
+/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap();
+/// assert_eq!(agg_res, 3);
+/// ```
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+#[derive(Clone, Debug)]
+pub struct Prio3<T, P, const L: usize>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ num_aggregators: u8,
+ typ: T,
+ phantom: PhantomData<P>,
+}
+
+impl<T, P, const L: usize> Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the
+ /// underlying type.
+ pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+ Ok(Self {
+ num_aggregators,
+ typ,
+ phantom: PhantomData,
+ })
+ }
+
+ /// The output length of the underlying FLP.
+ pub fn output_len(&self) -> usize {
+ self.typ.output_len()
+ }
+
+ /// The verifier length of the underlying FLP.
+ pub fn verifier_len(&self) -> usize {
+ self.typ.verifier_len()
+ }
+
+ fn derive_joint_randomness<'a>(parts: impl Iterator<Item = &'a Seed<L>>) -> Seed<L> {
+ let mut info = [0; VERSION.len() + 5];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..VERSION.len() + 4].copy_from_slice(&Self::ID.to_be_bytes());
+ info[VERSION.len() + 4] = 255;
+ let mut deriver = P::init(&[0; L]);
+ deriver.update(&info);
+ for part in parts {
+ deriver.update(part.as_ref());
+ }
+ deriver.into_seed()
+ }
+
+ fn shard_with_rand_source(
+ &self,
+ measurement: &T::Measurement,
+ rand_source: RandSource,
+ ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+
+ let num_aggregators = self.num_aggregators;
+ let input = self.typ.encode_measurement(measurement)?;
+
+ // Generate the input shares and compute the joint randomness.
+ let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
+ let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
+ Some(Vec::with_capacity(num_aggregators as usize - 1))
+ } else {
+ None
+ };
+ let mut leader_input_share = input.clone();
+ for agg_id in 1..num_aggregators {
+ let helper = HelperShare::from_rand_source(rand_source)?;
+
+ let mut deriver = P::init(helper.joint_rand_param.blind.as_ref());
+ info[DST_LEN] = agg_id;
+ deriver.update(&info);
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper.input_share, &info));
+ for (x, y) in leader_input_share
+ .iter_mut()
+ .zip(prng)
+ .take(self.typ.input_len())
+ {
+ *x -= y;
+ deriver.update(&y.into());
+ }
+
+ if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() {
+ helper_joint_rand_parts.push(deriver.into_seed());
+ }
+ helper_shares.push(helper);
+ }
+
+ let leader_blind = Seed::from_rand_source(rand_source)?;
+
+ info[DST_LEN] = 0; // ID of the leader
+ let mut deriver = P::init(leader_blind.as_ref());
+ deriver.update(&info);
+ for x in leader_input_share.iter() {
+ deriver.update(&(*x).into());
+ }
+
+ let leader_joint_rand_seed_part = deriver.into_seed();
+
+ // Compute the joint randomness seed.
+ let joint_rand_seed = helper_joint_rand_parts.as_ref().map(|parts| {
+ Self::derive_joint_randomness(
+ std::iter::once(&leader_joint_rand_seed_part).chain(parts.iter()),
+ )
+ });
+
+ // Run the proof-generation algorithm.
+ let domain_separation_tag = &info[..DST_LEN];
+ let joint_rand: Vec<T::Field> = joint_rand_seed
+ .map(|joint_rand_seed| {
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
+ prng.take(self.typ.joint_rand_len()).collect()
+ })
+ .unwrap_or_default();
+ let prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
+ &Seed::from_rand_source(rand_source)?,
+ domain_separation_tag,
+ ));
+ let prove_rand: Vec<T::Field> = prng.take(self.typ.prove_rand_len()).collect();
+ let mut leader_proof_share = self.typ.prove(&input, &prove_rand, &joint_rand)?;
+
+ // Generate the proof shares and distribute the joint randomness seed hints.
+ for (j, helper) in helper_shares.iter_mut().enumerate() {
+ info[DST_LEN] = j as u8 + 1;
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&helper.proof_share, &info));
+ for (x, y) in leader_proof_share
+ .iter_mut()
+ .zip(prng)
+ .take(self.typ.proof_len())
+ {
+ *x -= y;
+ }
+
+ if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_ref() {
+ let mut hint = Vec::with_capacity(num_aggregators as usize - 1);
+ hint.push(leader_joint_rand_seed_part.clone());
+ hint.extend(helper_joint_rand_parts[..j].iter().cloned());
+ hint.extend(helper_joint_rand_parts[j + 1..].iter().cloned());
+ helper.joint_rand_param.seed_hint = hint;
+ }
+ }
+
+ let leader_joint_rand_param = if self.typ.joint_rand_len() > 0 {
+ Some(JointRandParam {
+ seed_hint: helper_joint_rand_parts.unwrap_or_default(),
+ blind: leader_blind,
+ })
+ } else {
+ None
+ };
+
+ // Prep the output messages.
+ let mut out = Vec::with_capacity(num_aggregators as usize);
+ out.push(Prio3InputShare {
+ input_share: Share::Leader(leader_input_share),
+ proof_share: Share::Leader(leader_proof_share),
+ joint_rand_param: leader_joint_rand_param,
+ });
+
+ for helper in helper_shares.into_iter() {
+ let helper_joint_rand_param = if self.typ.joint_rand_len() > 0 {
+ Some(helper.joint_rand_param)
+ } else {
+ None
+ };
+
+ out.push(Prio3InputShare {
+ input_share: Share::Helper(helper.input_share),
+ proof_share: Share::Helper(helper.proof_share),
+ joint_rand_param: helper_joint_rand_param,
+ });
+ }
+
+ Ok(out)
+ }
+
+ /// Shard measurement with constant randomness of repeated bytes.
+ /// This method is not secure. It is used for running test vectors for Prio3.
+ #[cfg(feature = "test-util")]
+ pub fn test_vec_shard(
+ &self,
+ measurement: &T::Measurement,
+ ) -> Result<Vec<Prio3InputShare<T::Field, L>>, VdafError> {
+ self.shard_with_rand_source(measurement, |buf| {
+ buf.fill(1);
+ Ok(())
+ })
+ }
+
+ fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
+ if agg_id >= self.num_aggregators as usize {
+ return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
+ }
+ Ok(u8::try_from(agg_id).unwrap())
+ }
+}
+
+impl<T, P, const L: usize> Vdaf for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ const ID: u32 = T::ID;
+ type Measurement = T::Measurement;
+ type AggregateResult = T::AggregateResult;
+ type AggregationParam = ();
+ type PublicShare = ();
+ type InputShare = Prio3InputShare<T::Field, L>;
+ type OutputShare = OutputShare<T::Field>;
+ type AggregateShare = AggregateShare<T::Field>;
+
+ fn num_aggregators(&self) -> usize {
+ self.num_aggregators as usize
+ }
+}
+
+/// Message sent by the [`Client`](crate::vdaf::Client) to each
+/// [`Aggregator`](crate::vdaf::Aggregator) during the Sharding phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio3InputShare<F, const L: usize> {
+ /// The input share.
+ input_share: Share<F, L>,
+
+ /// The proof share.
+ proof_share: Share<F, L>,
+
+ /// Parameters used by the Aggregator to compute the joint randomness. This field is optional
+ /// because not every [`Type`](`crate::flp::Type`) requires joint randomness.
+ joint_rand_param: Option<JointRandParam<L>>,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3InputShare<F, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if matches!(
+ (&self.input_share, &self.proof_share),
+ (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_))
+ ) {
+ panic!("tried to encode input share with ambiguous encoding")
+ }
+
+ self.input_share.encode(bytes);
+ self.proof_share.encode(bytes);
+ if let Some(ref param) = self.joint_rand_param {
+ param.blind.encode(bytes);
+ for part in param.seed_hint.iter() {
+ part.encode(bytes);
+ }
+ }
+ }
+}
+
+impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
+ for Prio3InputShare<T::Field, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+ let (input_decoder, proof_decoder) = if agg_id == 0 {
+ (
+ ShareDecodingParameter::Leader(prio3.typ.input_len()),
+ ShareDecodingParameter::Leader(prio3.typ.proof_len()),
+ )
+ } else {
+ (
+ ShareDecodingParameter::Helper,
+ ShareDecodingParameter::Helper,
+ )
+ };
+
+ let input_share = Share::decode_with_param(&input_decoder, bytes)?;
+ let proof_share = Share::decode_with_param(&proof_decoder, bytes)?;
+ let joint_rand_param = if prio3.typ.joint_rand_len() > 0 {
+ let num_aggregators = prio3.num_aggregators();
+ let blind = Seed::decode(bytes)?;
+ let seed_hint = std::iter::repeat_with(|| Seed::decode(bytes))
+ .take(num_aggregators - 1)
+ .collect::<Result<Vec<_>, _>>()?;
+ Some(JointRandParam { blind, seed_hint })
+ } else {
+ None
+ };
+
+ Ok(Prio3InputShare {
+ input_share,
+ proof_share,
+ joint_rand_param,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+/// Message broadcast by each [`Aggregator`](crate::vdaf::Aggregator) in each round of the
+/// Preparation phase.
+pub struct Prio3PrepareShare<F, const L: usize> {
+ /// A share of the FLP verifier message. (See [`Type`](crate::flp::Type).)
+ verifier: Vec<F>,
+
+ /// A part of the joint randomness seed.
+ joint_rand_part: Option<Seed<L>>,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3PrepareShare<F, L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ for x in &self.verifier {
+ x.encode(bytes);
+ }
+ if let Some(ref seed) = self.joint_rand_part {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
+ for Prio3PrepareShare<F, L>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, L>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len);
+ for _ in 0..decoding_parameter.verifier_len {
+ verifier.push(F::decode(bytes)?);
+ }
+
+ let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareShare {
+ verifier,
+ joint_rand_part,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+/// Result of combining a round of [`Prio3PrepareShare`] messages.
+pub struct Prio3PrepareMessage<const L: usize> {
+ /// The joint randomness seed computed by the Aggregators.
+ joint_rand_seed: Option<Seed<L>>,
+}
+
+impl<const L: usize> Encode for Prio3PrepareMessage<L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<F: FieldElement, const L: usize> ParameterizedDecode<Prio3PrepareState<F, L>>
+ for Prio3PrepareMessage<L>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, L>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+}
+
+impl<T, P, const L: usize> Client for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ #[allow(clippy::type_complexity)]
+ fn shard(
+ &self,
+ measurement: &T::Measurement,
+ ) -> Result<((), Vec<Prio3InputShare<T::Field, L>>), VdafError> {
+ self.shard_with_rand_source(measurement, getrandom::getrandom)
+ .map(|input_shares| ((), input_shares))
+ }
+}
+
+/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Prio3PrepareState<F, const L: usize> {
+ input_share: Share<F, L>,
+ joint_rand_seed: Option<Seed<L>>,
+ agg_id: u8,
+ verifier_len: usize,
+}
+
+impl<F: FieldElement, const L: usize> Encode for Prio3PrepareState<F, L> {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.input_share.encode(bytes);
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+}
+
+impl<'a, T, P, const L: usize> ParameterizedDecode<(&'a Prio3<T, P, L>, usize)>
+ for Prio3PrepareState<T::Field, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, L>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+
+ let share_decoder = if agg_id == 0 {
+ ShareDecodingParameter::Leader(prio3.typ.input_len())
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let input_share = Share::decode_with_param(&share_decoder, bytes)?;
+
+ let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Self {
+ input_share,
+ joint_rand_seed,
+ agg_id,
+ verifier_len: prio3.typ.verifier_len(),
+ })
+ }
+}
+
+impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ type PrepareState = Prio3PrepareState<T::Field, L>;
+ type PrepareShare = Prio3PrepareShare<T::Field, L>;
+ type PrepareMessage = Prio3PrepareMessage<L>;
+
+ /// Begins the Prep process with the other aggregators. The result of this process is
+ /// the aggregator's output share.
+ #[allow(clippy::type_complexity)]
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; L],
+ agg_id: usize,
+ _agg_param: &(),
+ nonce: &[u8],
+ _public_share: &(),
+ msg: &Prio3InputShare<T::Field, L>,
+ ) -> Result<
+ (
+ Prio3PrepareState<T::Field, L>,
+ Prio3PrepareShare<T::Field, L>,
+ ),
+ VdafError,
+ > {
+ let agg_id = self.role_try_from(agg_id)?;
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+ info[DST_LEN] = agg_id;
+ let domain_separation_tag = &info[..DST_LEN];
+
+ let mut deriver = P::init(verify_key);
+ deriver.update(domain_separation_tag);
+ deriver.update(&[255]);
+ deriver.update(nonce);
+ let query_rand_prng = Prng::from_seed_stream(deriver.into_seed_stream());
+
+ // Create a reference to the (expanded) input share.
+ let expanded_input_share: Option<Vec<T::Field>> = match msg.input_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
+ Some(prng.take(self.typ.input_len()).collect())
+ }
+ };
+ let input_share = match msg.input_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_input_share.as_ref().unwrap(),
+ };
+
+ // Create a reference to the (expanded) proof share.
+ let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_seed_stream(P::seed_stream(seed, &info));
+ Some(prng.take(self.typ.proof_len()).collect())
+ }
+ };
+ let proof_share = match msg.proof_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_proof_share.as_ref().unwrap(),
+ };
+
+ // Compute the joint randomness.
+ let (joint_rand_seed, joint_rand_seed_part, joint_rand) = if self.typ.joint_rand_len() > 0 {
+ let mut deriver = P::init(msg.joint_rand_param.as_ref().unwrap().blind.as_ref());
+ deriver.update(&info);
+ for x in input_share {
+ deriver.update(&(*x).into());
+ }
+ let joint_rand_seed_part = deriver.into_seed();
+
+ let hints = &msg.joint_rand_param.as_ref().unwrap().seed_hint;
+ let joint_rand_seed = Self::derive_joint_randomness(
+ hints[..agg_id as usize]
+ .iter()
+ .chain(std::iter::once(&joint_rand_seed_part))
+ .chain(hints[agg_id as usize..].iter()),
+ );
+
+ let prng: Prng<T::Field, _> =
+ Prng::from_seed_stream(P::seed_stream(&joint_rand_seed, domain_separation_tag));
+ (
+ Some(joint_rand_seed),
+ Some(joint_rand_seed_part),
+ prng.take(self.typ.joint_rand_len()).collect(),
+ )
+ } else {
+ (None, None, Vec::new())
+ };
+
+ // Compute the query randomness.
+ let query_rand: Vec<T::Field> = query_rand_prng.take(self.typ.query_rand_len()).collect();
+
+ // Run the query-generation algorithm.
+ let verifier_share = self.typ.query(
+ input_share,
+ proof_share,
+ &query_rand,
+ &joint_rand,
+ self.num_aggregators as usize,
+ )?;
+
+ Ok((
+ Prio3PrepareState {
+ input_share: msg.input_share.clone(),
+ joint_rand_seed,
+ agg_id,
+ verifier_len: verifier_share.len(),
+ },
+ Prio3PrepareShare {
+ verifier: verifier_share,
+ joint_rand_part: joint_rand_seed_part,
+ },
+ ))
+ }
+
+ fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, L>>>(
+ &self,
+ inputs: M,
+ ) -> Result<Prio3PrepareMessage<L>, VdafError> {
+ let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
+ let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
+ let mut count = 0;
+ for share in inputs.into_iter() {
+ count += 1;
+
+ if share.verifier.len() != verifier.len() {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected verifier share length: got {}; want {}",
+ share.verifier.len(),
+ verifier.len(),
+ )));
+ }
+
+ if self.typ.joint_rand_len() > 0 {
+ let joint_rand_seed_part = share.joint_rand_part.unwrap();
+ joint_rand_parts.push(joint_rand_seed_part);
+ }
+
+ for (x, y) in verifier.iter_mut().zip(share.verifier) {
+ *x += y;
+ }
+ }
+
+ if count != self.num_aggregators {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message count: got {}; want {}",
+ count, self.num_aggregators,
+ )));
+ }
+
+ // Check the proof verifier.
+ match self.typ.decide(&verifier) {
+ Ok(true) => (),
+ Ok(false) => {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ))
+ }
+ Err(err) => return Err(VdafError::from(err)),
+ };
+
+ let joint_rand_seed = if self.typ.joint_rand_len() > 0 {
+ Some(Self::derive_joint_randomness(joint_rand_parts.iter()))
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+
+ fn prepare_step(
+ &self,
+ step: Prio3PrepareState<T::Field, L>,
+ msg: Prio3PrepareMessage<L>,
+ ) -> Result<PrepareTransition<Self, L>, VdafError> {
+ if self.typ.joint_rand_len() > 0 {
+ // Check that the joint randomness was correct.
+ if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() {
+ return Err(VdafError::Uncategorized(
+ "joint randomness mismatch".to_string(),
+ ));
+ }
+ }
+
+ // Compute the output share.
+ let input_share = match step.input_share {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let mut info = [0; DST_LEN + 1];
+ info[..VERSION.len()].copy_from_slice(VERSION);
+ info[VERSION.len()..DST_LEN].copy_from_slice(&Self::ID.to_be_bytes());
+ info[DST_LEN] = step.agg_id;
+ let prng = Prng::from_seed_stream(P::seed_stream(&seed, &info));
+ prng.take(self.typ.input_len()).collect()
+ }
+ };
+
+ let output_share = match self.typ.truncate(input_share) {
+ Ok(data) => OutputShare(data),
+ Err(err) => {
+ return Err(VdafError::from(err));
+ }
+ };
+
+ Ok(PrepareTransition::Finish(output_share))
+ }
+
+ /// Aggregates a sequence of output shares into an aggregate share.
+ fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
+ &self,
+ _agg_param: &(),
+ output_shares: It,
+ ) -> Result<AggregateShare<T::Field>, VdafError> {
+ let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for output_share in output_shares.into_iter() {
+ agg_share.accumulate(&output_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+impl<T, P, const L: usize> Collector for Prio3<T, P, L>
+where
+ T: Type,
+ P: Prg<L>,
+{
+ /// Combines aggregate shares into the aggregate result.
+ fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
+ &self,
+ _agg_param: &(),
+ agg_shares: It,
+ num_measurements: usize,
+ ) -> Result<T::AggregateResult, VdafError> {
+ let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(self.typ.decode_result(&agg.0, num_measurements)?)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct JointRandParam<const L: usize> {
+ /// The joint randomness seed parts corresponding to the other Aggregators' shares.
+ seed_hint: Vec<Seed<L>>,
+
+ /// The blinding factor, used to derive the aggregator's joint randomness seed part.
+ blind: Seed<L>,
+}
+
+#[derive(Clone)]
+struct HelperShare<const L: usize> {
+ input_share: Seed<L>,
+ proof_share: Seed<L>,
+ joint_rand_param: JointRandParam<L>,
+}
+
+impl<const L: usize> HelperShare<L> {
+ fn from_rand_source(rand_source: RandSource) -> Result<Self, VdafError> {
+ Ok(HelperShare {
+ input_share: Seed::from_rand_source(rand_source)?,
+ proof_share: Seed::from_rand_source(rand_source)?,
+ joint_rand_param: JointRandParam {
+ seed_hint: Vec::new(),
+ blind: Seed::from_rand_source(rand_source)?,
+ },
+ })
+ }
+}
+
+fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
+ if num_aggregators == 0 {
+ return Err(VdafError::Uncategorized(format!(
+ "at least one aggregator is required; got {}",
+ num_aggregators
+ )));
+ } else if num_aggregators > 254 {
+ return Err(VdafError::Uncategorized(format!(
+ "number of aggregators must not exceed 254; got {}",
+ num_aggregators
+ )));
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::vdaf::{run_vdaf, run_vdaf_prepare};
+ use assert_matches::assert_matches;
+ use rand::prelude::*;
+
+ #[test]
+ fn test_prio3_count() {
+ let prio3 = Prio3::new_aes128_count(2).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3);
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = b"This is a good nonce.";
+
+ let (public_share, input_shares) = prio3.shard(&0).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
+
+ let (public_share, input_shares) = prio3.shard(&1).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares).unwrap();
+
+ test_prepare_state_serialization(&prio3, &1).unwrap();
+
+ let prio3_extra_helper = Prio3::new_aes128_count(3).unwrap();
+ assert_eq!(
+ run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(),
+ 3,
+ );
+ }
+
+ #[test]
+ fn test_prio3_sum() {
+ let prio3 = Prio3::new_aes128_sum(3, 16).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
+ (1 << 16) + 1
+ );
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = b"This is a good nonce.";
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ input_shares[0].joint_rand_param.as_mut().unwrap().blind.0[0] ^= 255;
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ input_shares[0].joint_rand_param.as_mut().unwrap().seed_hint[0].0[0] ^= 255;
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ assert_matches!(input_shares[0].input_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1).unwrap();
+ assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ test_prepare_state_serialization(&prio3, &1).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_countvec() {
+ let prio3 = Prio3::new_aes128_count_vec(2, 20).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [vec![
+ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
+ ]]
+ )
+ .unwrap(),
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
+ );
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_prio3_countvec_multithreaded() {
+ let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, 20).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [vec![
+ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
+ ]]
+ )
+ .unwrap(),
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,]
+ );
+ }
+
+ #[test]
+ fn test_prio3_histogram() {
+ let prio3 = Prio3::new_aes128_histogram(2, &[0, 10, 20]).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, 10, 20, 9999]).unwrap(),
+ vec![1, 1, 1, 1]
+ );
+ assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [5]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [10]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [15]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [20]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [25]).unwrap(), vec![0, 0, 0, 1]);
+ test_prepare_state_serialization(&prio3, &23).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_average() {
+ let prio3 = Prio3::new_aes128_average(2, 64).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
+ assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
+ assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
+ assert_eq!(
+ run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
+ 207.5f64
+ );
+ }
+
+ #[test]
+ fn test_prio3_input_share() {
+ let prio3 = Prio3::new_aes128_sum(5, 16).unwrap();
+ let (_public_share, input_shares) = prio3.shard(&1).unwrap();
+
+ // Check that seed shares are distinct.
+ for (i, x) in input_shares.iter().enumerate() {
+ for (j, y) in input_shares.iter().enumerate() {
+ if i != j {
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.input_share, &y.input_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.proof_share, &y.proof_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ assert_ne!(x.joint_rand_param, y.joint_rand_param);
+ }
+ }
+ }
+ }
+
+ fn test_prepare_state_serialization<T, P, const L: usize>(
+ prio3: &Prio3<T, P, L>,
+ measurement: &T::Measurement,
+ ) -> Result<(), VdafError>
+ where
+ T: Type,
+ P: Prg<L>,
+ {
+ let mut verify_key = [0; L];
+ thread_rng().fill(&mut verify_key[..]);
+ let (public_share, input_shares) = prio3.shard(measurement)?;
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (want, _msg) =
+ prio3.prepare_init(&verify_key, agg_id, &(), &[], &public_share, input_share)?;
+ let got =
+ Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &want.get_encoded())
+ .expect("failed to decode prepare step");
+ assert_eq!(got, want);
+ }
+ Ok(())
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs
new file mode 100644
index 0000000000..d4c9151ce0
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3_test.rs
@@ -0,0 +1,162 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::{
+ codec::{Encode, ParameterizedDecode},
+ flp::Type,
+ vdaf::{
+ prg::Prg,
+ prio3::{Prio3, Prio3InputShare, Prio3PrepareShare},
+ Aggregator, PrepareTransition,
+ },
+};
+use serde::{Deserialize, Serialize};
+use std::{convert::TryInto, fmt::Debug};
+
+#[derive(Debug, Deserialize, Serialize)]
+struct TEncoded(#[serde(with = "hex")] Vec<u8>);
+
+impl AsRef<[u8]> for TEncoded {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3Prep<M> {
+ measurement: M,
+ #[serde(with = "hex")]
+ nonce: Vec<u8>,
+ input_shares: Vec<TEncoded>,
+ prep_shares: Vec<Vec<TEncoded>>,
+ prep_messages: Vec<TEncoded>,
+ out_shares: Vec<Vec<M>>,
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3<M> {
+ verify_key: TEncoded,
+ prep: Vec<TPrio3Prep<M>>,
+}
+
+macro_rules! err {
+ (
+ $test_num:ident,
+ $error:expr,
+ $msg:expr
+ ) => {
+ panic!("test #{} failed: {} err: {}", $test_num, $msg, $error)
+ };
+}
+
+// TODO Generalize this method to work with any VDAF. To do so we would need to add
+// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.)
+fn check_prep_test_vec<M, T, P, const L: usize>(
+ prio3: &Prio3<T, P, L>,
+ verify_key: &[u8; L],
+ test_num: usize,
+ t: &TPrio3Prep<M>,
+) where
+ T: Type<Measurement = M>,
+ P: Prg<L>,
+ M: From<<T as Type>::Field> + Debug + PartialEq,
+{
+ let input_shares = prio3
+ .test_vec_shard(&t.measurement)
+ .expect("failed to generate input shares");
+
+ assert_eq!(2, t.input_shares.len(), "#{}", test_num);
+ for (agg_id, want) in t.input_shares.iter().enumerate() {
+ assert_eq!(
+ input_shares[agg_id],
+ Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
+ "#{}",
+ test_num
+ );
+ assert_eq!(
+ input_shares[agg_id].get_encoded(),
+ want.as_ref(),
+ "#{}",
+ test_num
+ )
+ }
+
+ let mut states = Vec::new();
+ let mut prep_shares = Vec::new();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (state, prep_share) = prio3
+ .prepare_init(verify_key, agg_id, &(), &t.nonce, &(), input_share)
+ .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
+ states.push(state);
+ prep_shares.push(prep_share);
+ }
+
+ assert_eq!(1, t.prep_shares.len(), "#{}", test_num);
+ for (i, want) in t.prep_shares[0].iter().enumerate() {
+ assert_eq!(
+ prep_shares[i],
+ Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
+ "#{}",
+ test_num
+ );
+ assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{}", test_num);
+ }
+
+ let inbound = prio3
+ .prepare_preprocess(prep_shares)
+ .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
+ assert_eq!(t.prep_messages.len(), 1);
+ assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());
+
+ let mut out_shares = Vec::new();
+ for state in states.iter_mut() {
+ match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() {
+ PrepareTransition::Finish(out_share) => {
+ out_shares.push(out_share);
+ }
+ _ => panic!("unexpected transition"),
+ }
+ }
+
+ for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
+ let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect();
+ assert_eq!(&got, want);
+ }
+}
+
+#[test]
+fn test_vec_prio3_count() {
+ let t: TPrio3<u64> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Count_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_count(2).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}
+
+#[test]
+fn test_vec_prio3_sum() {
+ let t: TPrio3<u128> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Sum_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_sum(2, 8).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}
+
+#[test]
+fn test_vec_prio3_histogram() {
+ let t: TPrio3<u128> =
+ serde_json::from_str(include_str!("test_vec/03/Prio3Aes128Histogram_0.json")).unwrap();
+ let prio3 = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap();
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ for (test_num, p) in t.prep.iter().enumerate() {
+ check_prep_test_vec(&prio3, &verify_key, test_num, p);
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json
new file mode 100644
index 0000000000..e450665173
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/03/PrgAes128.json
@@ -0,0 +1,7 @@
+{
+ "derived_seed": "ccf3be704c982182ad2961e9795a88aa",
+ "expanded_vec_field128": "ccf3be704c982182ad2961e9795a88aa8df71c0b5ea5c13bcf3173c3f3626505e1bf4738874d5405805082cc38c55d1f04f85fbb88b8cf8592ffed8a4ac7f76991c58d850a15e8deb34fb289ab6fab584554ffef16c683228db2b76e792ca4f3c15760044d0703b438c2aefd7975c5dd4b9992ee6f87f20e570572dea18fa580ee17204903c1234f1332d47a442ea636580518ce7aa5943c415117460a049bc19cc81edbb0114d71890cbdbe4ea2664cd038e57b88fb7fd3557830ad363c20b9840d35fd6bee6c3c8424f026ee7fbca3daf3c396a4d6736d7bd3b65b2c228d22a40f4404e47c61b26ac3c88bebf2f268fa972f8831f18bee374a22af0f8bb94d9331a1584bdf8cf3e8a5318b546efee8acd28f6cba8b21b9d52acbae8e726500340da98d643d0a5f1270ecb94c574130cea61224b0bc6d438b2f4f74152e72b37e6a9541c9dc5515f8f98fd0d1bce8743f033ab3e8574180ffc3363f3a0490f6f9583bf73a87b9bb4b51bfd0ef260637a4288c37a491c6cbdc46b6a86cd26edf611793236e912e7227bfb85b560308b06238bbd978f72ed4a58583cf0c6e134066eb6b399ad2f26fa01d69a62d8a2d04b4b8acf82299b07a834d4c2f48fee23a24c20307f9cabcd34b6d69f1969588ebde777e46e9522e866e6dd1e14119a1cb4c0709fa9ea347d9f872e76a39313e7d49bfbf3e5ce807183f43271ba2b5c6aaeaef22da301327c1fd9fedde7c5a68d9b97fa6eb687ec8ca692cb0f631f46e6699a211a1254026c9a0a43eceb450dc97cfa923321baf1f4b6f233260d46182b844dccec153aaddd20f920e9e13ff11434bcd2aa632bf4f544f41b5ddced962939676476f70e0b8640c3471fc7af62d80053781295b070388f7b7f1fa66220cb3",
+ "info": "696e666f20737472696e67",
+ "length": 40,
+ "seed": "01010101010101010101010101010101"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json
new file mode 100644
index 0000000000..9e79888745
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Count_0.json
@@ -0,0 +1,37 @@
+{
+ "agg_param": null,
+ "agg_result": 1,
+ "agg_shares": [
+ "ad8bb894e3222b47",
+ "5274476a1cddd4bb"
+ ],
+ "prep": [
+ {
+ "input_shares": [
+ "ad8bb894e3222b47b70eb67d4f70cb78644826d67d31129e422b910cf0aab70c0b78fa57b4a7b3aaafae57bd1012e813",
+ "0101010101010101010101010101010101010101010101010101010101010101"
+ ],
+ "measurement": 1,
+ "nonce": "01010101010101010101010101010101",
+ "out_shares": [
+ [
+ 12505291739929652039
+ ],
+ [
+ 5941452329484932283
+ ]
+ ],
+ "prep_messages": [
+ ""
+ ],
+ "prep_shares": [
+ [
+ "38d535dd68f3c02ed6681f7ff24239d46fde93c8402d24ebbafa25c77ca3535d",
+ "c72aca21970c3fd35274476a1cddd4bb4efa24ee0d71473e4a0a23713a347d78"
+ ]
+ ],
+ "public_share": ""
+ }
+ ],
+ "verify_key": "01010101010101010101010101010101"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json
new file mode 100644
index 0000000000..f5476455fa
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Histogram_0.json
@@ -0,0 +1,53 @@
+{
+ "agg_param": null,
+ "agg_result": [
+ 0,
+ 0,
+ 1,
+ 0
+ ],
+ "agg_shares": [
+ "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7",
+ "11ef893e143d2b59aa858efce43622a5632a16ee7f444ac4b53c996b9434056e46f786d42ea2bfb4b53725d9b1db5df39ba1097b8de7f38b6b4538fb51f98a2a"
+ ],
+ "buckets": [
+ 1,
+ 10,
+ 100
+ ],
+ "prep": [
+ {
+ "input_shares": [
+ "ee1076c1ebc2d48a557a71031bc9dd5c9cd5e91180bbb51f4ac366946bcbfa93b908792bd15d402f4ac8da264e24a20f645ef68472180c5894bac704ae0675d7f16776df4f93852a40b514593a73be51ad64d8c28322a47af92c92223dd489998a3c6687861cdc2e4d834885d03d8d3273af0bf742c47985ae8fec6d16c31216792bb0cdca0d1d1fa2287414cd069f8caa42dc08f78dd43e14c4095e2ef9d9609937caebcd534e813136e79a4233e873397a6c7fd164928d43673b32e061139dc6650152d8433e2342f595149418929b74c9e23f1469ed1eebdaa57d0b5c62f90cb5a53dc68c8e030448bb2d9c07aeed50d82c93e1afe8febd68918933ed9b2dd36b9d8a35fd6c57cd76707011fca77526437aeb8392a2013f829c1e395f7f8ddef030f5bc869833f528ae2137a2e667aa648d8643f6c13e8d76e8832ab9ef7d0101010101010101010101010101010194c3f0f1061c8f440b51f806ad822510",
+ "0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101016195ec204fd5d65c14fac36b73723cde"
+ ],
+ "measurement": 50,
+ "nonce": "01010101010101010101010101010101",
+ "out_shares": [
+ [
+ 316441748434879643753815489063091297628,
+ 208470253761472213750543248431791209107,
+ 245951175238245845331446316072865931791,
+ 133415875449384174923011884997795018199
+ ],
+ [
+ 23840618486058819193050284304809468581,
+ 131812113159466249196322524936109557102,
+ 94331191682692617615419457295034834419,
+ 206866491471554288023853888370105748010
+ ]
+ ],
+ "prep_messages": [
+ "7912f1157c2ce3a4dca6456224aeaeea"
+ ],
+ "prep_shares": [
+ [
+ "f2dc9e823b867d760b2169644633804eabec10e5869fe8f3030c5da6dc0fce03a433572cb8aaa7ca3559959f7bad68306195ec204fd5d65c14fac36b73723cde",
+ "0d23617dc479826df4de969bb9cc7fb3f5e934542e987db0271aee33551b28a4c16f7ad00127c43df9c433a1c224594d94c3f0f1061c8f440b51f806ad822510"
+ ]
+ ],
+ "public_share": ""
+ }
+ ],
+ "verify_key": "01010101010101010101010101010101"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json
new file mode 100644
index 0000000000..55d2a000db
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/03/Prio3Aes128Sum_0.json
@@ -0,0 +1,38 @@
+{
+ "agg_param": null,
+ "agg_result": 100,
+ "agg_shares": [
+ "ab3bcc5ef693737a7a3e76cd9face3e0",
+ "54c433a1096c8c6985c1893260531c85"
+ ],
+ "bits": 8,
+ "prep": [
+ {
+ "input_shares": [
+ "fc1e42a024e3d8b45f63f485ebe2dc8a356c5e5780a0bd751184e6a02a96c0767f518e87282ebdc039590aef02e40e5492c9eb69dd22b6b4f1d630e7ca8612b7a7e090b39460bc4036345f5ef537d691fd585bc05a2ea580c7e354680afd0fd49f3d083d5e383b97755a842cf5e69870a970b14a10595c0c639ad2e7bda42c7146c4b69fd79e7403d89dac5816d0dc6f2bb987fccca4c4aee64444b7f46431433c59c6e7f2839fe2b7ad9316d31a52dcc0df07f1da14aa38e0cd88de380fda29b33704e8c3439376762739aa5b5cff9e925939773d24ca0e75bcf87149c9bcc2f8462afa6b50513ab003ac00c9ae3685ea52bdee3c814ffd5afc8357d93454b3ffaf0b5e9fd351730f0d55aed54a9cfa86f9119601ce9857ee0af3f579251bcc7ffe51b8393adc36ab6142eb0e0d07c9b2d5ab71d8d5639f32c61f7d59b45a95129cbc76d7e30c02a1329454f843553413d4e84bcab2c3ba1a0150292026dfa37488da5dd639c53edd51bf4eb5aa54d5b165fcd55d10f3496008f4e3b6d3eb200c19c5b9c42ad4f12977a857d02f787b14ced27fc5eefb05722b372a7d48c1891d30a32d84ec8d1f9a783a38bfac2793f0da6796cff90521e1d73f497f7d2c910b7fbbea2ba4b906d437a53bcbed16986f5646fd238e736f1c3e9d3a910218ce7f48dea3e9a1a848c580a1c506a80edb0c0a973a269667475ce88f4424674b14a3a8f2b71ef529d2ca96a3c5e4da384545749a55188d4de0074ad601695e934c9fe71d27c139b7678ead7f904cd2ae2a3aafa96d8211579e391507df96bf42c383f2ac71d7a558ebf1e3d5ab086b65422415bd24be9c979ca5b4f381d51b06ec4f6740b1a084999cd95fe63fec4a019f635640ba18d42312de7d1994947502b9010101010101010101010101010101015f0721f50826593dc3908dad39353846",
+ "010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101094240ceae2d63ba1bdda997fa0bcbd8"
+ ],
+ "measurement": 100,
+ "nonce": "01010101010101010101010101010101",
+ "out_shares": [
+ [
+ 227608477929192160221239678567201956832
+ ],
+ [
+ 112673888991746302725626094800698809477
+ ]
+ ],
+ "prep_messages": [
+ "60af733578d766f2305c1d53c840b4b5"
+ ],
+ "prep_shares": [
+ [
+ "0a85b5e51cacf514ee9e9bbe5d3ac023795e910b765411e5cea8ff187973640694bd740cc15bc9cad60bc85785206062094240ceae2d63ba1bdda997fa0bcbd8",
+ "f57a4a1ae3530acf11616441a2c53fde804d262dc42e15e556ee02c588c3ca9d924eefa735a95f6e420f2c5161706e025f0721f50826593dc3908dad39353846"
+ ]
+ ],
+ "public_share": ""
+ }
+ ],
+ "verify_key": "01010101010101010101010101010101"
+}