summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/prio/src
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prio/src')
-rw-r--r--third_party/rust/prio/src/benchmarked.rs50
-rw-r--r--third_party/rust/prio/src/codec.rs734
-rw-r--r--third_party/rust/prio/src/dp.rs127
-rw-r--r--third_party/rust/prio/src/dp/distributions.rs607
-rw-r--r--third_party/rust/prio/src/fft.rs222
-rw-r--r--third_party/rust/prio/src/field.rs1190
-rw-r--r--third_party/rust/prio/src/field/field255.rs543
-rw-r--r--third_party/rust/prio/src/flp.rs1059
-rw-r--r--third_party/rust/prio/src/flp/gadgets.rs591
-rw-r--r--third_party/rust/prio/src/flp/types.rs1415
-rw-r--r--third_party/rust/prio/src/flp/types/fixedpoint_l2.rs899
-rw-r--r--third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs93
-rw-r--r--third_party/rust/prio/src/fp.rs533
-rw-r--r--third_party/rust/prio/src/idpf.rs2200
-rw-r--r--third_party/rust/prio/src/lib.rs34
-rw-r--r--third_party/rust/prio/src/polynomial.rs383
-rw-r--r--third_party/rust/prio/src/prng.rs278
-rw-r--r--third_party/rust/prio/src/topology/mod.rs7
-rw-r--r--third_party/rust/prio/src/topology/ping_pong.rs968
-rw-r--r--third_party/rust/prio/src/vdaf.rs757
-rw-r--r--third_party/rust/prio/src/vdaf/dummy.rs316
-rw-r--r--third_party/rust/prio/src/vdaf/poplar1.rs2465
-rw-r--r--third_party/rust/prio/src/vdaf/prio2.rs543
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/client.rs306
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/server.rs386
-rw-r--r--third_party/rust/prio/src/vdaf/prio2/test_vector.rs83
-rw-r--r--third_party/rust/prio/src/vdaf/prio3.rs2127
-rw-r--r--third_party/rust/prio/src/vdaf/prio3_test.rs251
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json52
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json56
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json64
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json64
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json76
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json39
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json45
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json52
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json89
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json194
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json146
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json40
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json46
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json8
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json8
-rw-r--r--third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json28
-rw-r--r--third_party/rust/prio/src/vdaf/xof.rs574
45 files changed, 20748 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/benchmarked.rs b/third_party/rust/prio/src/benchmarked.rs
new file mode 100644
index 0000000000..1882de91e7
--- /dev/null
+++ b/third_party/rust/prio/src/benchmarked.rs
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: MPL-2.0
+
+#![doc(hidden)]
+
+//! This module provides wrappers around internal components of this crate that we want to
+//! benchmark, but which we don't want to expose in the public API.
+
+use crate::fft::discrete_fourier_transform;
+use crate::field::FftFriendlyFieldElement;
+use crate::flp::gadgets::Mul;
+use crate::flp::FlpError;
+use crate::polynomial::{poly_fft, PolyAuxMemory};
+
+/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
+pub fn benchmarked_iterative_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
+ discrete_fourier_transform(outp, inp, inp.len()).unwrap();
+}
+
+/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
+pub fn benchmarked_recursive_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
+ let mut mem = PolyAuxMemory::new(inp.len() / 2);
+ poly_fft(
+ outp,
+ inp,
+ &mem.roots_2n,
+ inp.len(),
+ false,
+ &mut mem.fft_memory,
+ )
+}
+
+/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
+/// uses FFT for multiplication.
+pub fn benchmarked_gadget_mul_call_poly_fft<F: FftFriendlyFieldElement>(
+ g: &mut Mul<F>,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError> {
+ g.call_poly_fft(outp, inp)
+}
+
+/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
+/// does the multiplication directly.
+pub fn benchmarked_gadget_mul_call_poly_direct<F: FftFriendlyFieldElement>(
+ g: &mut Mul<F>,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError> {
+ g.call_poly_direct(outp, inp)
+}
diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs
new file mode 100644
index 0000000000..71f4f8ce5f
--- /dev/null
+++ b/third_party/rust/prio/src/codec.rs
@@ -0,0 +1,734 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Support for encoding and decoding messages to or from the TLS wire encoding, as specified in
+//! [RFC 8446, Section 3][1].
+//!
+//! The [`Encode`], [`Decode`], [`ParameterizedEncode`] and [`ParameterizedDecode`] traits can be
+//! implemented on values that need to be encoded or decoded. Utility functions are provided to
+//! encode or decode sequences of values.
+//!
+//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3
+
+use byteorder::{BigEndian, ReadBytesExt};
+use std::{
+ convert::TryInto,
+ error::Error,
+ io::{Cursor, Read},
+ mem::size_of,
+ num::TryFromIntError,
+};
+
+/// An error that occurred during decoding.
+#[derive(Debug, thiserror::Error)]
+pub enum CodecError {
+ /// An I/O error.
+ #[error("I/O error")]
+ Io(#[from] std::io::Error),
+
+ /// Extra data remained in the input after decoding a value.
+ #[error("{0} bytes left in buffer after decoding value")]
+ BytesLeftOver(usize),
+
+ /// The length prefix of an encoded vector exceeds the amount of remaining input.
+ #[error("length prefix of encoded vector overflows buffer: {0}")]
+ LengthPrefixTooBig(usize),
+
+ /// Custom errors from [`Decode`] implementations.
+ #[error("other error: {0}")]
+ Other(#[source] Box<dyn Error + 'static + Send + Sync>),
+
+ /// An invalid value was decoded.
+ #[error("unexpected value")]
+ UnexpectedValue,
+}
+
+/// Describes how to decode an object from a byte sequence.
+pub trait Decode: Sized {
+ /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned
+ /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned
+ /// and no further attempt to read from `bytes` should be made.
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;
+
+ /// Convenience method to get a decoded value. Returns an error if [`Self::decode`] fails, or if
+ /// there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
+ Self::get_decoded_with_param(&(), bytes)
+ }
+}
+
+/// Describes how to decode an object from a byte sequence and a decoding parameter that provides
+/// additional context.
+pub trait ParameterizedDecode<P>: Sized {
+ /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the
+ /// wire encoding such as lengths of different portions of the message. On success, the decoded
+ /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an
+ /// error is returned and no further attempt to read from `bytes` should be made.
+ fn decode_with_param(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError>;
+
+ /// Convenience method to get a decoded value. Returns an error if [`Self::decode_with_param`]
+ /// fails, or if there are any bytes left in `bytes` after decoding a value.
+ fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
+ let mut cursor = Cursor::new(bytes);
+ let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
+ if cursor.position() as usize != bytes.len() {
+ return Err(CodecError::BytesLeftOver(
+ bytes.len() - cursor.position() as usize,
+ ));
+ }
+
+ Ok(decoded)
+ }
+}
+
+/// Provide a blanket implementation so that any [`Decode`] can be used as a
+/// `ParameterizedDecode<T>` for any `T`.
+impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D {
+ fn decode_with_param(
+ _decoding_parameter: &T,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Self::decode(bytes)
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait Encode {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>);
+
+ /// Convenience method to encode a value into a new `Vec<u8>`.
+ fn get_encoded(&self) -> Vec<u8> {
+ self.get_encoded_with_param(&())
+ }
+
+ /// Returns an optional hint indicating how many bytes will be required to encode this value, or
+ /// `None` by default.
+ fn encoded_len(&self) -> Option<usize> {
+ None
+ }
+}
+
+/// Describes how to encode objects into a byte sequence.
+pub trait ParameterizedEncode<P> {
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ /// `encoding_parameter` provides details of the wire encoding, used to control how the value
+ /// is encoded.
+ fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>);
+
+ /// Convenience method to encode a value into a new `Vec<u8>`.
+ fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> {
+ let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) {
+ Vec::with_capacity(length)
+ } else {
+ Vec::new()
+ };
+ self.encode_with_param(encoding_parameter, &mut ret);
+ ret
+ }
+
+ /// Returns an optional hint indicating how many bytes will be required to encode this value, or
+ /// `None` by default.
+ fn encoded_len_with_param(&self, _encoding_parameter: &P) -> Option<usize> {
+ None
+ }
+}
+
+/// Provide a blanket implementation so that any [`Encode`] can be used as a
+/// `ParameterizedEncode<T>` for any `T`.
+impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
+ fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) {
+ self.encode(bytes)
+ }
+
+ fn encoded_len_with_param(&self, _encoding_parameter: &T) -> Option<usize> {
+ <Self as Encode>::encoded_len(self)
+ }
+}
+
+impl Decode for () {
+ fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(())
+ }
+}
+
+impl Encode for () {
+ fn encode(&self, _bytes: &mut Vec<u8>) {}
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(0)
+ }
+}
+
+impl Decode for u8 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; size_of::<u8>()];
+ bytes.read_exact(&mut value)?;
+ Ok(value[0])
+ }
+}
+
+impl Encode for u8 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.push(*self);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(1)
+ }
+}
+
+impl Decode for u16 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u16::<BigEndian>()?)
+ }
+}
+
+impl Encode for u16 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u16::to_be_bytes(*self));
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(2)
+ }
+}
+
+/// 24 bit integer, per
+/// [RFC 8443, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3)
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+struct U24(pub u32);
+
+impl Decode for U24 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(U24(bytes.read_u24::<BigEndian>()?))
+ }
+}
+
+impl Encode for U24 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Encode lower three bytes of the u32 as u24
+ bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(3)
+ }
+}
+
+impl Decode for u32 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u32::<BigEndian>()?)
+ }
+}
+
+impl Encode for u32 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u32::to_be_bytes(*self));
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(4)
+ }
+}
+
+impl Decode for u64 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(bytes.read_u64::<BigEndian>()?)
+ }
+}
+
+impl Encode for u64 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&u64::to_be_bytes(*self));
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(8)
+ }
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ bytes.push(0);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 1;
+ assert!(len <= usize::from(u8::MAX));
+ bytes[len_offset] = len as u8;
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read one byte to get length of opaque byte vector
+ let length = usize::from(u8::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u16.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 2;
+ assert!(len <= usize::from(u16::MAX));
+ for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read two bytes to get length of opaque byte vector
+ let length = usize::from(u16::decode(bytes)?);
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ U24(0).encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 3;
+ assert!(len <= 0xffffff);
+ for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read three bytes to get length of opaque byte vector
+ let length = U24::decode(bytes)?.0 as usize;
+
+ decode_items(length, decoding_parameter, bytes)
+}
+
+/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
+/// `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
+ bytes: &mut Vec<u8>,
+ encoding_parameter: &P,
+ items: &[E],
+) {
+ // Reserve space to later write length
+ let len_offset = bytes.len();
+ 0u32.encode(bytes);
+
+ for item in items {
+ item.encode_with_param(encoding_parameter, bytes);
+ }
+
+ let len = bytes.len() - len_offset - 4;
+ let len: u32 = len.try_into().expect("Length too large");
+ for (offset, byte) in len.to_be_bytes().iter().enumerate() {
+ bytes[len_offset + offset] = *byte;
+ }
+}
+
+/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
+/// maximum length `0xffffffff`.
+///
+/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
+pub fn decode_u32_items<P, D: ParameterizedDecode<P>>(
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ // Read four bytes to get length of opaque byte vector.
+ let len: usize = u32::decode(bytes)?
+ .try_into()
+ .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
+
+ decode_items(len, decoding_parameter, bytes)
+}
+
+/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible.
+fn decode_items<P, D: ParameterizedDecode<P>>(
+ length: usize,
+ decoding_parameter: &P,
+ bytes: &mut Cursor<&[u8]>,
+) -> Result<Vec<D>, CodecError> {
+ let mut decoded = Vec::new();
+ let initial_position = bytes.position() as usize;
+
+ // Create cursor over specified portion of provided cursor to ensure we can't read past length.
+ let inner = bytes.get_ref();
+
+ // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
+ let (items_end, overflowed) = initial_position.overflowing_add(length);
+ if overflowed || items_end > inner.len() {
+ return Err(CodecError::LengthPrefixTooBig(length));
+ }
+
+ let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
+
+ while sub.position() < length as u64 {
+ decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
+ }
+
+ // Advance outer cursor by the amount read in the inner cursor
+ bytes.set_position(initial_position as u64 + sub.position());
+
+ Ok(decoded)
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn encode_nothing() {
+ let mut bytes = vec![];
+ ().encode(&mut bytes);
+ assert_eq!(bytes.len(), 0);
+ }
+
+ #[test]
+ fn roundtrip_u8() {
+ let value = 100u8;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 1);
+
+ let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u16() {
+ let value = 1000u16;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 2);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![3, 232]);
+
+ let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u24() {
+ let value = U24(1_000_000u32);
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 3);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![15, 66, 64]);
+
+ let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u32() {
+ let value = 134_217_728u32;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 4);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![8, 0, 0, 0]);
+
+ let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[test]
+ fn roundtrip_u64() {
+ let value = 137_438_953_472u64;
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), 8);
+ // Check endianness of encoding
+ assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);
+
+ let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct TestMessage {
+ field_u8: u8,
+ field_u16: u16,
+ field_u24: U24,
+ field_u32: u32,
+ field_u64: u64,
+ }
+
+ impl Encode for TestMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.field_u8.encode(bytes);
+ self.field_u16.encode(bytes);
+ self.field_u24.encode(bytes);
+ self.field_u32.encode(bytes);
+ self.field_u64.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(
+ self.field_u8.encoded_len()?
+ + self.field_u16.encoded_len()?
+ + self.field_u24.encoded_len()?
+ + self.field_u32.encoded_len()?
+ + self.field_u64.encoded_len()?,
+ )
+ }
+ }
+
+ impl Decode for TestMessage {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let field_u8 = u8::decode(bytes)?;
+ let field_u16 = u16::decode(bytes)?;
+ let field_u24 = U24::decode(bytes)?;
+ let field_u32 = u32::decode(bytes)?;
+ let field_u64 = u64::decode(bytes)?;
+
+ Ok(TestMessage {
+ field_u8,
+ field_u16,
+ field_u24,
+ field_u32,
+ field_u64,
+ })
+ }
+ }
+
+ impl TestMessage {
+ fn encoded_length() -> usize {
+ // u8 field
+ 1 +
+ // u16 field
+ 2 +
+ // u24 field
+ 3 +
+ // u32 field
+ 4 +
+ // u64 field
+ 8
+ }
+ }
+
+ #[test]
+ fn roundtrip_message() {
+ let value = TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ };
+
+ let mut bytes = vec![];
+ value.encode(&mut bytes);
+ assert_eq!(bytes.len(), TestMessage::encoded_length());
+ assert_eq!(value.encoded_len().unwrap(), TestMessage::encoded_length());
+
+ let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(value, decoded);
+ }
+
+ fn messages_vec() -> Vec<TestMessage> {
+ vec![
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ TestMessage {
+ field_u8: 0,
+ field_u16: 300,
+ field_u24: U24(1_000_000),
+ field_u32: 134_217_728,
+ field_u64: 137_438_953_472,
+ },
+ ]
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u8() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u8_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 1 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u16() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u16_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 2 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u24() {
+ let values = messages_vec();
+ let mut bytes = vec![];
+ encode_u24_items(&mut bytes, &(), &values);
+
+ assert_eq!(
+ bytes.len(),
+ // Length of opaque vector
+ 3 +
+ // 3 TestMessage values
+ 3 * TestMessage::encoded_length()
+ );
+
+ // Check endianness of encoded length
+ assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]);
+
+ let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn roundtrip_variable_length_u32() {
+ let values = messages_vec();
+ let mut bytes = Vec::new();
+ encode_u32_items(&mut bytes, &(), &values);
+
+ assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length());
+
+ // Check endianness of encoded length.
+ assert_eq!(
+ bytes[0..4],
+ [0, 0, 0, 3 * TestMessage::encoded_length() as u8]
+ );
+
+ let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap();
+ assert_eq!(values, decoded);
+ }
+
+ #[test]
+ fn decode_items_overflow() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(usize::MAX)
+ );
+ }
+
+ #[test]
+ fn decode_items_too_big() {
+ let encoded = vec![1u8];
+
+ let mut cursor = Cursor::new(encoded.as_slice());
+ cursor.set_position(1);
+
+ assert_matches!(
+ decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(),
+ CodecError::LengthPrefixTooBig(2)
+ );
+ }
+
+ #[test]
+ fn length_hint_correctness() {
+ assert_eq!(().encoded_len().unwrap(), ().get_encoded().len());
+ assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().len());
+ assert_eq!(0u16.encoded_len().unwrap(), 0u16.get_encoded().len());
+ assert_eq!(U24(0).encoded_len().unwrap(), U24(0).get_encoded().len());
+ assert_eq!(0u32.encoded_len().unwrap(), 0u32.get_encoded().len());
+ assert_eq!(0u64.encoded_len().unwrap(), 0u64.get_encoded().len());
+ }
+}
diff --git a/third_party/rust/prio/src/dp.rs b/third_party/rust/prio/src/dp.rs
new file mode 100644
index 0000000000..506676dbb9
--- /dev/null
+++ b/third_party/rust/prio/src/dp.rs
@@ -0,0 +1,127 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Differential privacy (DP) primitives.
+//!
+//! There are three main traits defined in this module:
+//!
+//! - `DifferentialPrivacyBudget`: Implementors should be types of DP-budgets,
+//! i.e., methods to measure the amount of privacy provided by DP-mechanisms.
+//! Examples: zCDP, ApproximateDP (Epsilon-Delta), PureDP
+//!
+//! - `DifferentialPrivacyDistribution`: Distribution from which noise is sampled.
+//! Examples: DiscreteGaussian, DiscreteLaplace
+//!
+//! - `DifferentialPrivacyStrategy`: This is a combination of choices for budget and distribution.
+//! Examples: zCDP-DiscreteGaussian, EpsilonDelta-DiscreteGaussian
+//!
+use num_bigint::{BigInt, BigUint, TryFromBigIntError};
+use num_rational::{BigRational, Ratio};
+use serde::{Deserialize, Serialize};
+
+/// Errors propagated by methods in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum DpError {
+ /// Tried to use an invalid float as privacy parameter.
+ #[error(
+ "DP error: input value was not a valid privacy parameter. \
+ It should to be a non-negative, finite float."
+ )]
+ InvalidFloat,
+
+ /// Tried to construct a rational number with zero denominator.
+ #[error("DP error: input denominator was zero.")]
+ ZeroDenominator,
+
+ /// Tried to convert BigInt into something incompatible.
+ #[error("DP error: {0}")]
+ BigIntConversion(#[from] TryFromBigIntError<BigInt>),
+}
+
+/// Positive arbitrary precision rational number to represent DP and noise distribution parameters in
+/// protocol messages and manipulate them without rounding errors.
+#[derive(Clone, Debug)]
+pub struct Rational(Ratio<BigUint>);
+
+impl Rational {
+ /// Construct a [`Rational`] number from numerator `n` and denominator `d`. Errors if denominator is zero.
+ pub fn from_unsigned<T>(n: T, d: T) -> Result<Self, DpError>
+ where
+ T: Into<u128>,
+ {
+ // we don't want to expose BigUint in the public api, hence the Into<u128> bound
+ let d = d.into();
+ if d == 0 {
+ Err(DpError::ZeroDenominator)
+ } else {
+ Ok(Rational(Ratio::<BigUint>::new(n.into().into(), d.into())))
+ }
+ }
+}
+
+impl TryFrom<f32> for Rational {
+ type Error = DpError;
+ /// Constructs a `Rational` from a given `f32` value.
+ ///
+ /// The special float values (NaN, positive and negative infinity) result in
+ /// an error. All other values are represented exactly, without rounding errors.
+ fn try_from(value: f32) -> Result<Self, DpError> {
+ match BigRational::from_float(value) {
+ Some(y) => Ok(Rational(Ratio::<BigUint>::new(
+ y.numer().clone().try_into()?,
+ y.denom().clone().try_into()?,
+ ))),
+ None => Err(DpError::InvalidFloat)?,
+ }
+ }
+}
+
+/// Marker trait for differential privacy budgets (regardless of the specific accounting method).
+pub trait DifferentialPrivacyBudget {}
+
+/// Marker trait for differential privacy scalar noise distributions.
+pub trait DifferentialPrivacyDistribution {}
+
+/// Zero-concentrated differential privacy (ZCDP) budget as defined in [[BS16]].
+///
+/// [BS16]: https://arxiv.org/pdf/1605.02065.pdf
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
+pub struct ZCdpBudget {
+ epsilon: Ratio<BigUint>,
+}
+
+impl ZCdpBudget {
+ /// Create a budget for parameter `epsilon`, using the notation from [[CKS20]] where `rho = (epsilon**2)/2`
+ /// for a `rho`-ZCDP budget.
+ ///
+ /// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+ pub fn new(epsilon: Rational) -> Self {
+ Self { epsilon: epsilon.0 }
+ }
+}
+
+impl DifferentialPrivacyBudget for ZCdpBudget {}
+
+/// Strategy to make aggregate results differentially private, e.g. by adding noise from a specific
+/// type of distribution instantiated with a given DP budget.
+pub trait DifferentialPrivacyStrategy {
+ /// The type of the DP budget, i.e. the variant of differential privacy that can be obtained
+ /// by using this strategy.
+ type Budget: DifferentialPrivacyBudget;
+
+ /// The distribution type this strategy will use to generate the noise.
+ type Distribution: DifferentialPrivacyDistribution;
+
+ /// The type the sensitivity used for privacy analysis has.
+ type Sensitivity;
+
+ /// Create a strategy from a differential privacy budget. The distribution created with
+ /// `create_distribution` should provide the amount of privacy specified here.
+ fn from_budget(b: Self::Budget) -> Self;
+
+ /// Create a new distribution parametrized s.t. adding samples to the result of a function
+ /// with sensitivity `s` will yield differential privacy of the DP variant given in the
+ /// `Budget` type. Can error upon invalid parameters.
+ fn create_distribution(&self, s: Self::Sensitivity) -> Result<Self::Distribution, DpError>;
+}
+
+pub mod distributions;
diff --git a/third_party/rust/prio/src/dp/distributions.rs b/third_party/rust/prio/src/dp/distributions.rs
new file mode 100644
index 0000000000..ba0270df9c
--- /dev/null
+++ b/third_party/rust/prio/src/dp/distributions.rs
@@ -0,0 +1,607 @@
+// Copyright (c) 2023 ISRG
+// SPDX-License-Identifier: MPL-2.0
+//
+// This file contains code covered by the following copyright and permission notice
+// and has been modified by ISRG and collaborators.
+//
+// Copyright (c) 2022 President and Fellows of Harvard College
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+// This file incorporates work covered by the following copyright and
+// permission notice:
+//
+// Copyright 2020 Thomas Steinke
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// The following code is adapted from the opendp implementation to reduce dependencies:
+// https://github.com/opendp/opendp/blob/main/rust/src/traits/samplers/cks20
+
+//! Implementation of a sampler from the Discrete Gaussian Distribution.
+//!
+//! Follows
+//! Clément Canonne, Gautam Kamath, Thomas Steinke. The Discrete Gaussian for Differential Privacy. 2020.
+//! <https://arxiv.org/pdf/2004.00010.pdf>
+
+use num_bigint::{BigInt, BigUint, UniformBigUint};
+use num_integer::Integer;
+use num_iter::range_inclusive;
+use num_rational::Ratio;
+use num_traits::{One, Zero};
+use rand::{distributions::uniform::UniformSampler, distributions::Distribution, Rng};
+use serde::{Deserialize, Serialize};
+
+use super::{
+ DifferentialPrivacyBudget, DifferentialPrivacyDistribution, DifferentialPrivacyStrategy,
+ DpError, ZCdpBudget,
+};
+
+/// Sample from the Bernoulli(gamma) distribution, where $gamma /leq 1$.
+///
+/// `sample_bernoulli(gamma, rng)` returns numbers distributed as $Bernoulli(gamma)$.
+/// using the given random number generator for base randomness. The procedure is as described
+/// on page 30 of [[CKS20]].
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_bernoulli<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool {
+ let d = gamma.denom();
+ assert!(!d.is_zero());
+ assert!(gamma <= &Ratio::<BigUint>::one());
+
+ // sample uniform biguint in {1,...,d}
+ // uses the implementation of rand::Uniform for num_bigint::BigUint
+ let s = UniformBigUint::sample_single_inclusive(BigUint::one(), d, rng);
+
+ s <= *gamma.numer()
+}
+
+/// Sample from the Bernoulli(exp(-gamma)) distribution where `gamma` is in `[0,1]`.
+///
+/// `sample_bernoulli_exp1(gamma, rng)` returns numbers distributed as $Bernoulli(exp(-gamma))$,
+/// using the given random number generator for base randomness. Follows Algorithm 1 of [[CKS20]],
+/// splitting the branches into two non-recursive functions. This is the `gamma in [0,1]` branch.
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_bernoulli_exp1<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool {
+ assert!(!gamma.denom().is_zero());
+ assert!(gamma <= &Ratio::<BigUint>::one());
+
+ let mut k = BigUint::one();
+ loop {
+ if sample_bernoulli(&(gamma / k.clone()), rng) {
+ k += 1u8;
+ } else {
+ return k.is_odd();
+ }
+ }
+}
+
+/// Sample from the Bernoulli(exp(-gamma)) distribution.
+///
+/// `sample_bernoulli_exp(gamma, rng)` returns numbers distributed as $Bernoulli(exp(-gamma))$,
+/// using the given random number generator for base randomness. Follows Algorithm 1 of [[CKS20]],
+/// splitting the branches into two non-recursive functions. This is the `gamma > 1` branch.
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_bernoulli_exp<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool {
+ assert!(!gamma.denom().is_zero());
+ for _ in range_inclusive(BigUint::one(), gamma.floor().to_integer()) {
+ if !sample_bernoulli_exp1(&Ratio::<BigUint>::one(), rng) {
+ return false;
+ }
+ }
+ sample_bernoulli_exp1(&(gamma - gamma.floor()), rng)
+}
+
+/// Sample from the geometric distribution with parameter 1 - exp(-gamma).
+///
+/// `sample_geometric_exp(gamma, rng)` returns numbers distributed according to
+/// $Geometric(1 - exp(-gamma))$, using the given random number generator for base randomness.
+/// The code follows all but the last three lines of Algorithm 2 in [[CKS20]].
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_geometric_exp<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> BigUint {
+ let (s, t) = (gamma.numer(), gamma.denom());
+ assert!(!t.is_zero());
+ if gamma.is_zero() {
+ return BigUint::zero();
+ }
+
+ // sampler for uniform biguint in {0...t-1}
+ // uses the implementation of rand::Uniform for num_bigint::BigUint
+ let usampler = UniformBigUint::new(BigUint::zero(), t);
+ let mut u = usampler.sample(rng);
+
+ while !sample_bernoulli_exp1(&Ratio::<BigUint>::new(u.clone(), t.clone()), rng) {
+ u = usampler.sample(rng);
+ }
+
+ let mut v = BigUint::zero();
+ loop {
+ if sample_bernoulli_exp1(&Ratio::<BigUint>::one(), rng) {
+ v += 1u8;
+ } else {
+ break;
+ }
+ }
+
+ // we do integer division, so the following term equals floor((u + t*v)/s)
+ (u + t * v) / s
+}
+
+/// Sample from the discrete Laplace distribution.
+///
+/// `sample_discrete_laplace(scale, rng)` returns numbers distributed according to
+/// $\mathcal{L}_\mathbb{Z}(0, scale)$, using the given random number generator for base randomness.
+/// This follows Algorithm 2 of [[CKS20]], using a subfunction for geometric sampling.
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_discrete_laplace<R: Rng + ?Sized>(scale: &Ratio<BigUint>, rng: &mut R) -> BigInt {
+ let (s, t) = (scale.numer(), scale.denom());
+ assert!(!t.is_zero());
+ if s.is_zero() {
+ return BigInt::zero();
+ }
+
+ loop {
+ let negative = sample_bernoulli(&Ratio::<BigUint>::new(BigUint::one(), 2u8.into()), rng);
+ let y: BigInt = sample_geometric_exp(&scale.recip(), rng).into();
+ if negative && y.is_zero() {
+ continue;
+ } else {
+ return if negative { -y } else { y };
+ }
+ }
+}
+
+/// Sample from the discrete Gaussian distribution.
+///
+/// `sample_discrete_gaussian(sigma, rng)` returns `BigInt` numbers distributed as
+/// $\mathcal{N}_\mathbb{Z}(0, sigma^2)$, using the given random number generator for base
+/// randomness. Follows Algorithm 3 from [[CKS20]].
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+fn sample_discrete_gaussian<R: Rng + ?Sized>(sigma: &Ratio<BigUint>, rng: &mut R) -> BigInt {
+ assert!(!sigma.denom().is_zero());
+ if sigma.is_zero() {
+ return 0.into();
+ }
+ let t = sigma.floor() + BigUint::one();
+
+ // no need to compute these parts of the probability term every iteration
+ let summand = sigma.pow(2) / t.clone();
+ // compute probability of accepting the laplace sample y
+ let prob = |term: Ratio<BigUint>| term.pow(2) * (sigma.pow(2) * BigUint::from(2u8)).recip();
+
+ loop {
+ let y = sample_discrete_laplace(&t, rng);
+
+ // absolute value without type conversion
+ let y_abs: Ratio<BigUint> = BigUint::new(y.to_u32_digits().1).into();
+
+ // unsigned subtraction-followed-by-square
+ let prob: Ratio<BigUint> = if y_abs < summand {
+ prob(summand.clone() - y_abs)
+ } else {
+ prob(y_abs - summand.clone())
+ };
+
+ if sample_bernoulli_exp(&prob, rng) {
+ return y;
+ }
+ }
+}
+
+/// Samples `BigInt` numbers according to the discrete Gaussian distribution with mean zero.
+/// The distribution is defined over the integers, represented by arbitrary-precision integers.
+/// The sampling procedure follows [[CKS20]].
+///
+/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+#[derive(Clone, Debug)]
+pub struct DiscreteGaussian {
+ /// The standard deviation of the distribution.
+ std: Ratio<BigUint>,
+}
+
+impl DiscreteGaussian {
+ /// Create a new sampler from the Discrete Gaussian Distribution with the given
+ /// standard deviation and mean zero. Errors if the input has denominator zero.
+ pub fn new(std: Ratio<BigUint>) -> Result<DiscreteGaussian, DpError> {
+ if std.denom().is_zero() {
+ return Err(DpError::ZeroDenominator);
+ }
+ Ok(DiscreteGaussian { std })
+ }
+}
+
+impl Distribution<BigInt> for DiscreteGaussian {
+ fn sample<R>(&self, rng: &mut R) -> BigInt
+ where
+ R: Rng + ?Sized,
+ {
+ sample_discrete_gaussian(&self.std, rng)
+ }
+}
+
+impl DifferentialPrivacyDistribution for DiscreteGaussian {}
+
+/// A DP strategy using the discrete gaussian distribution.
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Ord, PartialOrd)]
+pub struct DiscreteGaussianDpStrategy<B>
+where
+ B: DifferentialPrivacyBudget,
+{
+ budget: B,
+}
+
+/// A DP strategy using the discrete gaussian distribution providing zero-concentrated DP.
+pub type ZCdpDiscreteGaussian = DiscreteGaussianDpStrategy<ZCdpBudget>;
+
+impl DifferentialPrivacyStrategy for DiscreteGaussianDpStrategy<ZCdpBudget> {
+ type Budget = ZCdpBudget;
+ type Distribution = DiscreteGaussian;
+ type Sensitivity = Ratio<BigUint>;
+
+ fn from_budget(budget: ZCdpBudget) -> DiscreteGaussianDpStrategy<ZCdpBudget> {
+ DiscreteGaussianDpStrategy { budget }
+ }
+
+ /// Create a new sampler from the Discrete Gaussian Distribution with a standard
+ /// deviation calibrated to provide `1/2 epsilon^2` zero-concentrated differential
+ /// privacy when added to the result of an integer-valued function with sensitivity
+ /// `sensitivity`, following Theorem 4 from [[CKS20]]
+ ///
+ /// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+ fn create_distribution(
+ &self,
+ sensitivity: Ratio<BigUint>,
+ ) -> Result<DiscreteGaussian, DpError> {
+ DiscreteGaussian::new(sensitivity / self.budget.epsilon.clone())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use crate::dp::Rational;
+ use crate::vdaf::xof::SeedStreamSha3;
+
+ use num_bigint::{BigUint, Sign, ToBigInt, ToBigUint};
+ use num_traits::{One, Signed, ToPrimitive};
+ use rand::{distributions::Distribution, SeedableRng};
+ use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
+ use std::collections::HashMap;
+
+ #[test]
+ fn test_discrete_gaussian() {
+ let sampler =
+ DiscreteGaussian::new(Ratio::<BigUint>::from_integer(BigUint::from(5u8))).unwrap();
+
+ // check samples are consistent
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let samples: Vec<i8> = (0..10)
+ .map(|_| i8::try_from(sampler.sample(&mut rng)).unwrap())
+ .collect();
+ let samples1: Vec<i8> = (0..10)
+ .map(|_| i8::try_from(sampler.sample(&mut rng)).unwrap())
+ .collect();
+ assert_eq!(samples, vec![-3, -11, -3, 5, 1, 5, 2, 2, 1, 18]);
+ assert_eq!(samples1, vec![4, -4, -5, -2, 0, -5, -3, 1, 1, -2]);
+ }
+
+ #[test]
+ /// Make sure that the distribution created by `create_distribution`
+ /// of `ZCdpDicreteGaussian` is the same one as manually creating one
+ /// by using the constructor of `DiscreteGaussian` directly.
+ fn test_zcdp_discrete_gaussian() {
+ // sample from a manually created distribution
+ let sampler1 =
+ DiscreteGaussian::new(Ratio::<BigUint>::from_integer(BigUint::from(4u8))).unwrap();
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let samples1: Vec<i8> = (0..10)
+ .map(|_| i8::try_from(sampler1.sample(&mut rng)).unwrap())
+ .collect();
+
+ // sample from the distribution created by the `zcdp` strategy
+ let zcdp = ZCdpDiscreteGaussian {
+ budget: ZCdpBudget::new(Rational::try_from(0.25).unwrap()),
+ };
+ let sampler2 = zcdp
+ .create_distribution(Ratio::<BigUint>::from_integer(1u8.into()))
+ .unwrap();
+ let mut rng2 = SeedStreamSha3::from_seed([0u8; 16]);
+ let samples2: Vec<i8> = (0..10)
+ .map(|_| i8::try_from(sampler2.sample(&mut rng2)).unwrap())
+ .collect();
+
+ assert_eq!(samples2, samples1);
+ }
+
+ pub fn test_mean<FS: FnMut() -> BigInt>(
+ mut sampler: FS,
+ hyp_mean: f64,
+ hyp_var: f64,
+ alpha: f64,
+ n: u32,
+ ) -> bool {
+ // we test if the mean from our sampler is within the given error margin assuimng its
+ // normally distributed with mean hyp_mean and variance sqrt(hyp_var/n)
+ // this assumption is from the central limit theorem
+
+ // inverse cdf (quantile function) is F s.t. P[X<=F(p)]=p for X ~ N(0,1)
+ // (i.e. X from the standard normal distribution)
+ let probit = |p| Normal::new(0.0, 1.0).unwrap().inverse_cdf(p);
+
+ // x such that the probability of a N(0,1) variable attaining
+ // a value outside of (-x, x) is alpha
+ let z_stat = probit(alpha / 2.).abs();
+
+ // confidence interval for the mean
+ let abs_p_tol = Ratio::<BigInt>::from_float(z_stat * (hyp_var / n as f64).sqrt()).unwrap();
+
+ // take n samples from the distribution, compute empirical mean
+ let emp_mean = Ratio::<BigInt>::new((0..n).map(|_| sampler()).sum::<BigInt>(), n.into());
+
+ (emp_mean - Ratio::<BigInt>::from_float(hyp_mean).unwrap()).abs() < abs_p_tol
+ }
+
+ fn histogram(
+ d: &Vec<BigInt>,
+ bin_bounds: &[Option<(BigInt, BigInt)>],
+ smallest: BigInt,
+ largest: BigInt,
+ ) -> HashMap<Option<(BigInt, BigInt)>, u64> {
+ // a binned histogram of the samples in `d`
+ // used for chi_square test
+
+ fn insert<T>(hist: &mut HashMap<T, u64>, key: &T, val: u64)
+ where
+ T: Eq + std::hash::Hash + Clone,
+ {
+ *hist.entry(key.clone()).or_default() += val;
+ }
+
+ // regular histogram
+ let mut hist = HashMap::<BigInt, u64>::new();
+ //binned histogram
+ let mut bin_hist = HashMap::<Option<(BigInt, BigInt)>, u64>::new();
+
+ for val in d {
+ // throw outliers with bound bins
+ if val < &smallest || val > &largest {
+ insert(&mut bin_hist, &None, 1);
+ } else {
+ insert(&mut hist, val, 1);
+ }
+ }
+ // sort values into their bins
+ for (a, b) in bin_bounds.iter().flatten() {
+ for i in range_inclusive(a.clone(), b.clone()) {
+ if let Some(count) = hist.get(&i) {
+ insert(&mut bin_hist, &Some((a.clone(), b.clone())), *count);
+ }
+ }
+ }
+ bin_hist
+ }
+
+ fn discrete_gauss_cdf_approx(
+ sigma: &BigUint,
+ bin_bounds: &[Option<(BigInt, BigInt)>],
+ ) -> HashMap<Option<(BigInt, BigInt)>, f64> {
+ // approximate bin probabilties from theoretical distribution
+ // formula is eq. (1) on page 3 of [[CKS20]]
+ //
+ // [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
+ let sigma = BigInt::from_biguint(Sign::Plus, sigma.clone());
+ let exp_sum = |lower: &BigInt, upper: &BigInt| {
+ range_inclusive(lower.clone(), upper.clone())
+ .map(|x: BigInt| {
+ f64::exp(
+ Ratio::<BigInt>::new(-(x.pow(2)), 2 * sigma.pow(2))
+ .to_f64()
+ .unwrap(),
+ )
+ })
+ .sum::<f64>()
+ };
+ // denominator is approximate up to 10 times the variance
+ // outside of that probabilities should be very small
+ // so the error will be negligible for the test
+ let denom = exp_sum(&(-10i8 * sigma.pow(2)), &(10i8 * sigma.pow(2)));
+
+ // compute probabilities for each bin
+ let mut cdf = HashMap::new();
+ let mut p_outside = 1.0; // probability of not landing inside bin boundaries
+ for (a, b) in bin_bounds.iter().flatten() {
+ let entry = exp_sum(a, b) / denom;
+ assert!(!entry.is_zero() && entry.is_finite());
+ cdf.insert(Some((a.clone(), b.clone())), entry);
+ p_outside -= entry;
+ }
+ cdf.insert(None, p_outside);
+ cdf
+ }
+
+ fn chi_square(sigma: &BigUint, n_bins: usize, alpha: f64) -> bool {
+ // perform pearsons chi-squared test on the discrete gaussian sampler
+
+ let sigma_signed = BigInt::from_biguint(Sign::Plus, sigma.clone());
+
+ // cut off at 3 times the std. and collect all outliers in a seperate bin
+ let global_bound = 3u8 * sigma_signed;
+
+ // bounds of bins
+ let lower_bounds = range_inclusive(-global_bound.clone(), global_bound.clone()).step_by(
+ ((2u8 * global_bound.clone()) / BigInt::from(n_bins))
+ .try_into()
+ .unwrap(),
+ );
+ let mut bin_bounds: Vec<Option<(BigInt, BigInt)>> = std::iter::zip(
+ lower_bounds.clone().take(n_bins),
+ lower_bounds.map(|x: BigInt| x - 1u8).skip(1),
+ )
+ .map(Some)
+ .collect();
+ bin_bounds.push(None); // bin for outliers
+
+ // approximate bin probabilities
+ let cdf = discrete_gauss_cdf_approx(sigma, &bin_bounds);
+
+ // chi2 stat wants at least 5 expected entries per bin
+ // so we choose n_samples in a way that gives us that
+ let n_samples = cdf
+ .values()
+ .map(|val| f64::ceil(5.0 / *val) as u32)
+ .max()
+ .unwrap();
+
+ // collect that number of samples
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let samples: Vec<BigInt> = (1..n_samples)
+ .map(|_| {
+ sample_discrete_gaussian(&Ratio::<BigUint>::from_integer(sigma.clone()), &mut rng)
+ })
+ .collect();
+
+ // make a histogram from the samples
+ let hist = histogram(&samples, &bin_bounds, -global_bound.clone(), global_bound);
+
+ // compute pearsons chi-squared test statistic
+ let stat: f64 = bin_bounds
+ .iter()
+ .map(|key| {
+ let expected = cdf.get(&(key.clone())).unwrap() * n_samples as f64;
+ if let Some(val) = hist.get(&(key.clone())) {
+ (*val as f64 - expected).powf(2.) / expected
+ } else {
+ 0.0
+ }
+ })
+ .sum::<f64>();
+
+ let chi2 = ChiSquared::new((cdf.len() - 1) as f64).unwrap();
+ // the probability of observing X >= stat for X ~ chi-squared
+ // (the "p-value")
+ let p = 1.0 - chi2.cdf(stat);
+
+ p > alpha
+ }
+
+ #[test]
+ fn empirical_test_gauss() {
+ [100, 2000, 20000].iter().for_each(|p| {
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let sampler = || {
+ sample_discrete_gaussian(
+ &Ratio::<BigUint>::from_integer((*p).to_biguint().unwrap()),
+ &mut rng,
+ )
+ };
+ let mean = 0.0;
+ let var = (p * p) as f64;
+ assert!(
+ test_mean(sampler, mean, var, 0.00001, 1000),
+ "Empirical evaluation of discrete Gaussian({:?}) sampler mean failed.",
+ p
+ );
+ });
+ // we only do chi square for std 100 because it's expensive
+ assert!(chi_square(&(100u8.to_biguint().unwrap()), 10, 0.05));
+ }
+
+ #[test]
+ fn empirical_test_bernoulli_mean() {
+ [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| {
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let sampler = || {
+ if sample_bernoulli(
+ &Ratio::<BigUint>::new(BigUint::one(), (*p).into()),
+ &mut rng,
+ ) {
+ BigInt::one()
+ } else {
+ BigInt::zero()
+ }
+ };
+ let mean = 1. / (*p as f64);
+ let var = mean * (1. - mean);
+ assert!(
+ test_mean(sampler, mean, var, 0.00001, 1000),
+ "Empirical evaluation of the Bernoulli(1/{:?}) distribution mean failed",
+ p
+ );
+ })
+ }
+
+ #[test]
+ fn empirical_test_geometric_mean() {
+ [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| {
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let sampler = || {
+ sample_geometric_exp(
+ &Ratio::<BigUint>::new(BigUint::one(), (*p).into()),
+ &mut rng,
+ )
+ .to_bigint()
+ .unwrap()
+ };
+ let p_prob = 1. - f64::exp(-(1. / *p as f64));
+ let mean = (1. - p_prob) / p_prob;
+ let var = (1. - p_prob) / p_prob.powi(2);
+ assert!(
+ test_mean(sampler, mean, var, 0.0001, 1000),
+ "Empirical evaluation of the Geometric(1-exp(-1/{:?})) distribution mean failed",
+ p
+ );
+ })
+ }
+
+ #[test]
+ fn empirical_test_laplace_mean() {
+ [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| {
+ let mut rng = SeedStreamSha3::from_seed([0u8; 16]);
+ let sampler = || {
+ sample_discrete_laplace(
+ &Ratio::<BigUint>::new(BigUint::one(), (*p).into()),
+ &mut rng,
+ )
+ };
+ let mean = 0.0;
+ let var = (1. / *p as f64).powi(2);
+ assert!(
+ test_mean(sampler, mean, var, 0.0001, 1000),
+ "Empirical evaluation of the Laplace(0,1/{:?}) distribution mean failed",
+ p
+ );
+ })
+ }
+}
diff --git a/third_party/rust/prio/src/fft.rs b/third_party/rust/prio/src/fft.rs
new file mode 100644
index 0000000000..cac59a89ea
--- /dev/null
+++ b/third_party/rust/prio/src/fft.rs
@@ -0,0 +1,222 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
+//! Transform (DFT) over a slice of field elements.
+
+use crate::field::FftFriendlyFieldElement;
+use crate::fp::{log2, MAX_ROOTS};
+
+use std::convert::TryFrom;
+
+/// An error returned by an FFT operation.
+#[derive(Debug, PartialEq, Eq, thiserror::Error)]
+pub enum FftError {
+ /// The output is too small.
+ #[error("output slice is smaller than specified size")]
+ OutputTooSmall,
+ /// The specified size is too large.
+ #[error("size is larger than than maximum permitted")]
+ SizeTooLarge,
+ /// The specified size is not a power of 2.
+ #[error("size is not a power of 2")]
+ SizeInvalid,
+}
+
+/// Sets `outp` to the DFT of `inp`.
+///
+/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
+/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of
+/// unity.
+#[allow(clippy::many_single_char_names)]
+pub fn discrete_fourier_transform<F: FftFriendlyFieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?;
+
+ if size > outp.len() {
+ return Err(FftError::OutputTooSmall);
+ }
+
+ if size > 1 << MAX_ROOTS {
+ return Err(FftError::SizeTooLarge);
+ }
+
+ if size != 1 << d {
+ return Err(FftError::SizeInvalid);
+ }
+
+ for (i, outp_val) in outp[..size].iter_mut().enumerate() {
+ let j = bitrev(d, i);
+ *outp_val = if j < inp.len() { inp[j] } else { F::zero() };
+ }
+
+ let mut w: F;
+ for l in 1..d + 1 {
+ w = F::one();
+ let r = F::root(l).unwrap();
+ let y = 1 << (l - 1);
+ let chunk = (size / y) >> 1;
+
+ // unrolling first iteration of i-loop.
+ for j in 0..chunk {
+ let x = j << l;
+ let u = outp[x];
+ let v = outp[x + y];
+ outp[x] = u + v;
+ outp[x + y] = u - v;
+ }
+
+ for i in 1..y {
+ w *= r;
+ for j in 0..chunk {
+ let x = (j << l) + i;
+ let u = outp[x];
+ let v = w * outp[x + y];
+ outp[x] = u + v;
+ outp[x + y] = u - v;
+ }
+ }
+ }
+
+ Ok(())
+}
+
+/// Sets `outp` to the inverse of the DFT of `inp`.
+#[cfg(test)]
+pub(crate) fn discrete_fourier_transform_inv<F: FftFriendlyFieldElement>(
+ outp: &mut [F],
+ inp: &[F],
+ size: usize,
+) -> Result<(), FftError> {
+ let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv();
+ discrete_fourier_transform(outp, inp, size)?;
+ discrete_fourier_transform_inv_finish(outp, size, size_inv);
+ Ok(())
+}
+
+/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to
+/// amortize the cost the modular inverse across multiple inverse DFT operations.
+pub(crate) fn discrete_fourier_transform_inv_finish<F: FftFriendlyFieldElement>(
+ outp: &mut [F],
+ size: usize,
+ size_inv: F,
+) {
+ let mut tmp: F;
+ outp[0] *= size_inv;
+ outp[size >> 1] *= size_inv;
+ for i in 1..size >> 1 {
+ tmp = outp[i] * size_inv;
+ outp[i] = outp[size - i] * size_inv;
+ outp[size - i] = tmp;
+ }
+}
+
+// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
+fn bitrev(d: usize, x: usize) -> usize {
+ let mut y = 0;
+ for i in 0..d {
+ y += ((x >> i) & 1) << (d - i);
+ }
+ y >> 1
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, split_vector, Field128, Field64, FieldElement, FieldPrio2};
+ use crate::polynomial::{poly_fft, PolyAuxMemory};
+
+ fn discrete_fourier_transform_then_inv_test<F: FftFriendlyFieldElement>() -> Result<(), FftError>
+ {
+ let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];
+
+ for size in test_sizes.iter() {
+ let mut tmp = vec![F::zero(); *size];
+ let mut got = vec![F::zero(); *size];
+ let want = random_vector(*size).unwrap();
+
+ discrete_fourier_transform(&mut tmp, &want, want.len())?;
+ discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?;
+ assert_eq!(got, want);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_priov2_field32() {
+ discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field64() {
+ discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_field128() {
+ discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error");
+ }
+
+ #[test]
+ fn test_recursive_fft() {
+ let size = 128;
+ let mut mem = PolyAuxMemory::new(size / 2);
+
+ let inp = random_vector(size).unwrap();
+ let mut want = vec![FieldPrio2::zero(); size];
+ let mut got = vec![FieldPrio2::zero(); size];
+
+ discrete_fourier_transform::<FieldPrio2>(&mut want, &inp, inp.len()).unwrap();
+
+ poly_fft(
+ &mut got,
+ &inp,
+ &mem.roots_2n,
+ size,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(got, want);
+ }
+
+ // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial
+ // over secret shares and summing up the coefficients is equivalent to interpolating a
+ // polynomial over the plaintext data.
+ #[test]
+ fn test_fft_linearity() {
+ let len = 16;
+ let num_shares = 3;
+ let x: Vec<Field64> = random_vector(len).unwrap();
+ let mut x_shares = split_vector(&x, num_shares).unwrap();
+
+ // Just for fun, let's do something different with a subset of the inputs. For the first
+ // share, every odd element is set to the plaintext value. For all shares but the first,
+ // every odd element is set to 0.
+ for (i, x_val) in x.iter().enumerate() {
+ if i % 2 != 0 {
+ x_shares[0][i] = *x_val;
+ for x_share in x_shares[1..num_shares].iter_mut() {
+ x_share[i] = Field64::zero();
+ }
+ }
+ }
+
+ let mut got = vec![Field64::zero(); len];
+ let mut buf = vec![Field64::zero(); len];
+ for share in x_shares {
+ discrete_fourier_transform_inv(&mut buf, &share, len).unwrap();
+ for i in 0..len {
+ got[i] += buf[i];
+ }
+ }
+
+ let mut want = vec![Field64::zero(); len];
+ discrete_fourier_transform_inv(&mut want, &x, len).unwrap();
+
+ assert_eq!(got, want);
+ }
+}
diff --git a/third_party/rust/prio/src/field.rs b/third_party/rust/prio/src/field.rs
new file mode 100644
index 0000000000..fb931de2d3
--- /dev/null
+++ b/third_party/rust/prio/src/field.rs
@@ -0,0 +1,1190 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Finite field arithmetic.
+//!
+//! Basic field arithmetic is captured in the [`FieldElement`] trait. Fields used in Prio implement
+//! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that
+//! generates a multiplicative subgroup of order `2^n` for some `n`.
+
+#[cfg(feature = "crypto-dependencies")]
+use crate::prng::{Prng, PrngError};
+use crate::{
+ codec::{CodecError, Decode, Encode},
+ fp::{FP128, FP32, FP64},
+};
+use serde::{
+ de::{DeserializeOwned, Visitor},
+ Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::{
+ cmp::min,
+ convert::{TryFrom, TryInto},
+ fmt::{self, Debug, Display, Formatter},
+ hash::{Hash, Hasher},
+ io::{Cursor, Read},
+ marker::PhantomData,
+ ops::{
+ Add, AddAssign, BitAnd, ControlFlow, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub,
+ SubAssign,
+ },
+};
+use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
+
+#[cfg(feature = "experimental")]
+mod field255;
+
+#[cfg(feature = "experimental")]
+pub use field255::Field255;
+
+/// Possible errors from finite field operations.
+#[derive(Debug, thiserror::Error)]
+pub enum FieldError {
+ /// Input sizes do not match.
+ #[error("input sizes do not match")]
+ InputSizeMismatch,
+ /// Returned when decoding a [`FieldElement`] from a too-short byte string.
+ #[error("short read from bytes")]
+ ShortRead,
+ /// Returned when decoding a [`FieldElement`] from a byte string that encodes an integer greater
+ /// than or equal to the field modulus.
+ #[error("read from byte slice exceeds modulus")]
+ ModulusOverflow,
+ /// Error while performing I/O.
+ #[error("I/O error")]
+ Io(#[from] std::io::Error),
+ /// Error encoding or decoding a field.
+ #[error("Codec error")]
+ Codec(#[from] CodecError),
+ /// Error converting to [`FieldElementWithInteger::Integer`].
+ #[error("Integer TryFrom error")]
+ IntegerTryFrom,
+}
+
+/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
+pub trait FieldElement:
+ Sized
+ + Debug
+ + Copy
+ + PartialEq
+ + Eq
+ + ConstantTimeEq
+ + ConditionallySelectable
+ + ConditionallyNegatable
+ + Add<Output = Self>
+ + AddAssign
+ + Sub<Output = Self>
+ + SubAssign
+ + Mul<Output = Self>
+ + MulAssign
+ + Div<Output = Self>
+ + DivAssign
+ + Neg<Output = Self>
+ + Display
+ + for<'a> TryFrom<&'a [u8], Error = FieldError>
+ // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`,
+ // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that
+ // isn't possible yet[1]. However we can provide the impl on FieldElement implementations.
+ // [1]: https://github.com/rust-lang/rust/issues/60551
+ + Into<Vec<u8>>
+ + Serialize
+ + DeserializeOwned
+ + Encode
+ + Decode
+ + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type.
+{
+ /// Size in bytes of an encoded field element.
+ const ENCODED_SIZE: usize;
+
+ /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
+ fn inv(&self) -> Self;
+
+ /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the
+ /// field. The `m` most significant bits are cleared, where `m` is equal to the length of
+ /// [`Self::Integer`] in bits minus the length of the modulus in bits.
+ ///
+ /// # Errors
+ ///
+ /// An error is returned if the provided slice is too small to encode a field element or if the
+ /// result encodes an integer larger than or equal to the field modulus.
+ ///
+ /// # Warnings
+ ///
+ /// This function should only be used within [`prng::Prng`] to convert a random byte string into
+ /// a field element. Use [`Self::decode`] to deserialize field elements. Use
+ /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements.
+ #[doc(hidden)]
+ fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>;
+
+ /// Returns the additive identity.
+ fn zero() -> Self;
+
+ /// Returns the multiplicative identity.
+ fn one() -> Self;
+
+ /// Convert a slice of field elements into a vector of bytes.
+ ///
+ /// # Notes
+ ///
+ /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding
+ /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
+ /// impossible.
+ fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> {
+ let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE);
+ encode_fieldvec(values, &mut vec);
+ vec
+ }
+
+ /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a
+ /// sequence of [`Self::ENCODED_SIZE`]-byte sequences.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if the length of the provided byte slice is not a multiple of the size of a
+ /// field element, or if any of the values in the byte slice are invalid encodings of a field
+ /// element, because the encoded integer is larger than or equal to the field modulus.
+ ///
+ /// # Notes
+ ///
+ /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding
+ /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
+ /// impossible.
+ fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> {
+ if bytes.len() % Self::ENCODED_SIZE != 0 {
+ return Err(FieldError::ShortRead);
+ }
+ let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE);
+ for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) {
+ vec.push(Self::get_decoded(chunk)?);
+ }
+ Ok(vec)
+ }
+}
+
+/// Extension trait for field elements that can be converted back and forth to an integer type.
+///
+/// The `Integer` associated type is an integer (primitive or otherwise) that supports various
+/// arithmetic operations. The order of the field is guaranteed to fit inside the range of the
+/// integer type. This trait also defines methods on field elements, `pow` and `modulus`, that make
+/// use of the associated integer type.
+pub trait FieldElementWithInteger: FieldElement + From<Self::Integer> {
+ /// The error returned if converting `usize` to an `Integer` fails.
+ type IntegerTryFromError: std::error::Error;
+
+ /// The error returned if converting an `Integer` to a `u64` fails.
+ type TryIntoU64Error: std::error::Error;
+
+ /// The integer representation of a field element.
+ type Integer: Copy
+ + Debug
+ + Eq
+ + Ord
+ + BitAnd<Output = Self::Integer>
+ + Div<Output = Self::Integer>
+ + Shl<usize, Output = Self::Integer>
+ + Shr<usize, Output = Self::Integer>
+ + Add<Output = Self::Integer>
+ + Sub<Output = Self::Integer>
+ + From<Self>
+ + TryFrom<usize, Error = Self::IntegerTryFromError>
+ + TryInto<u64, Error = Self::TryIntoU64Error>;
+
+ /// Modular exponentation, i.e., `self^exp (mod p)`.
+ fn pow(&self, exp: Self::Integer) -> Self;
+
+ /// Returns the prime modulus `p`.
+ fn modulus() -> Self::Integer;
+}
+
+/// Methods common to all `FieldElementWithInteger` implementations that are private to the crate.
+pub(crate) trait FieldElementWithIntegerExt: FieldElementWithInteger {
+ /// Encode `input` as bitvector of elements of `Self`. Output is written into the `output` slice.
+ /// If `output.len()` is smaller than the number of bits required to respresent `input`,
+ /// an error is returned.
+ ///
+ /// # Arguments
+ ///
+ /// * `input` - The field element to encode
+ /// * `output` - The slice to write the encoded bits into. Least signicant bit comes first
+ fn fill_with_bitvector_representation(
+ input: &Self::Integer,
+ output: &mut [Self],
+ ) -> Result<(), FieldError> {
+ // Create a mutable copy of `input`. In each iteration of the following loop we take the
+ // least significant bit, and shift input to the right by one bit.
+ let mut i = *input;
+
+ let one = Self::Integer::from(Self::one());
+ for bit in output.iter_mut() {
+ let w = Self::from(i & one);
+ *bit = w;
+ i = i >> 1;
+ }
+
+ // If `i` is still not zero, this means that it cannot be encoded by `bits` bits.
+ if i != Self::Integer::from(Self::zero()) {
+ return Err(FieldError::InputSizeMismatch);
+ }
+
+ Ok(())
+ }
+
+ /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough
+ /// to be represented with that many bits.
+ ///
+ /// # Arguments
+ ///
+ /// * `input` - The field element to encode
+ /// * `bits` - The number of bits to use for the encoding
+ fn encode_into_bitvector_representation(
+ input: &Self::Integer,
+ bits: usize,
+ ) -> Result<Vec<Self>, FieldError> {
+ let mut result = vec![Self::zero(); bits];
+ Self::fill_with_bitvector_representation(input, &mut result)?;
+ Ok(result)
+ }
+
+ /// Decode the bitvector-represented value `input` into a simple representation as a single
+ /// field element.
+ ///
+ /// # Errors
+ ///
+ /// This function errors if `2^input.len() - 1` does not fit into the field `Self`.
+ fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> {
+ let fi_one = Self::Integer::from(Self::one());
+
+ if !Self::valid_integer_bitlength(input.len()) {
+ return Err(FieldError::ModulusOverflow);
+ }
+
+ let mut decoded = Self::zero();
+ for (l, bit) in input.iter().enumerate() {
+ let w = fi_one << l;
+ decoded += Self::from(w) * *bit;
+ }
+ Ok(decoded)
+ }
+
+ /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the
+ /// field modulus.
+ fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError>
+ where
+ Self::Integer: TryFrom<N>,
+ {
+ let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?;
+ if Self::modulus() <= i_int {
+ return Err(FieldError::ModulusOverflow);
+ }
+ Ok(i_int)
+ }
+
+ /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is
+ /// representable in this field.
+ fn valid_integer_bitlength(bits: usize) -> bool {
+ if bits >= 8 * Self::ENCODED_SIZE {
+ return false;
+ }
+ if Self::modulus() >> bits != Self::Integer::from(Self::zero()) {
+ return true;
+ }
+ false
+ }
+}
+
+impl<F: FieldElementWithInteger> FieldElementWithIntegerExt for F {}
+
+/// Methods common to all `FieldElement` implementations that are private to the crate.
+pub(crate) trait FieldElementExt: FieldElement {
+ /// Try to interpret a slice of [`Self::ENCODED_SIZE`] random bytes as an element in the field. If
+ /// the input represents an integer greater than or equal to the field modulus, then
+ /// [`ControlFlow::Continue`] is returned instead, to indicate that an enclosing rejection sampling
+ /// loop should try again with different random bytes.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `bytes` is not of length [`Self::ENCODED_SIZE`].
+ fn from_random_rejection(bytes: &[u8]) -> ControlFlow<Self, ()> {
+ match Self::try_from_random(bytes) {
+ Ok(x) => ControlFlow::Break(x),
+ Err(FieldError::ModulusOverflow) => ControlFlow::Continue(()),
+ Err(err) => panic!("unexpected error: {err}"),
+ }
+ }
+}
+
+impl<F: FieldElement> FieldElementExt for F {}
+
+/// serde Visitor implementation used to generically deserialize `FieldElement`
+/// values from byte arrays.
+pub(crate) struct FieldElementVisitor<F: FieldElement> {
+ pub(crate) phantom: PhantomData<F>,
+}
+
+impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> {
+ type Value = F;
+
+ fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
+ formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE))
+ }
+
+ fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ Self::Value::try_from(v).map_err(E::custom)
+ }
+
+ fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::SeqAccess<'de>,
+ {
+ let mut bytes = vec![];
+ while let Some(byte) = seq.next_element()? {
+ bytes.push(byte);
+ }
+
+ self.visit_bytes(&bytes)
+ }
+}
+
+/// Objects with this trait represent an element of `GF(p)`, where `p` is some prime and the
+/// field's multiplicative group has a subgroup with an order that is a power of 2, and at least
+/// `2^20`.
+pub trait FftFriendlyFieldElement: FieldElementWithInteger {
+ /// Returns the size of the multiplicative subgroup generated by
+ /// [`FftFriendlyFieldElement::generator`].
+ fn generator_order() -> Self::Integer;
+
+ /// Returns the generator of the multiplicative subgroup of size
+ /// [`FftFriendlyFieldElement::generator_order`].
+ fn generator() -> Self;
+
+ /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
+ /// prinicpal root of unity is `1` by definition.
+ fn root(l: usize) -> Option<Self>;
+}
+
+macro_rules! make_field {
+ (
+ $(#[$meta:meta])*
+ $elem:ident, $int:ident, $fp:ident, $encoding_size:literal,
+ ) => {
+ $(#[$meta])*
+ ///
+ /// This structure represents a field element in a prime order field. The concrete
+ /// representation of the element is via the Montgomery domain. For an element `n` in
+ /// `GF(p)`, we store `n * R^-1 mod p` (where `R` is a given power of two). This
+ /// representation enables using a more efficient (and branchless) multiplication algorithm,
+ /// at the expense of having to convert elements between their Montgomery domain
+ /// representation and natural representation. For calculations with many multiplications or
+ /// exponentiations, this is worthwhile.
+ ///
+ /// As an invariant, this integer representing the field element in the Montgomery domain
+ /// must be less than the field modulus, `p`.
+ #[derive(Clone, Copy, PartialOrd, Ord, Default)]
+ pub struct $elem(u128);
+
+ impl $elem {
+ /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the
+ /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing
+ /// it using the field modulus.
+ ///
+ /// # Errors
+ ///
+ /// An error is returned if the provided slice is not long enough to encode a field
+ /// element or if the decoded value is greater than the field prime.
+ ///
+ /// # Notes
+ ///
+ /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions
+ /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is
+ /// more compact.
+ fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> {
+ if Self::ENCODED_SIZE > bytes.len() {
+ return Err(FieldError::ShortRead);
+ }
+
+ let mut int = 0;
+ for i in 0..Self::ENCODED_SIZE {
+ int |= (bytes[i] as u128) << (i << 3);
+ }
+
+ int &= mask;
+
+ if int >= $fp.p {
+ return Err(FieldError::ModulusOverflow);
+ }
+ // FieldParameters::montgomery() will return a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Ok(Self($fp.montgomery(int)))
+ }
+ }
+
+ impl PartialEq for $elem {
+ fn eq(&self, rhs: &Self) -> bool {
+ // The fields included in this comparison MUST match the fields
+ // used in Hash::hash
+ // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
+
+ // Check the invariant that the integer representation is fully reduced.
+ debug_assert!(self.0 < $fp.p);
+ debug_assert!(rhs.0 < $fp.p);
+
+ self.0 == rhs.0
+ }
+ }
+
+ impl ConstantTimeEq for $elem {
+ fn ct_eq(&self, rhs: &Self) -> Choice {
+ self.0.ct_eq(&rhs.0)
+ }
+ }
+
+ impl ConditionallySelectable for $elem {
+ fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
+ Self(u128::conditional_select(&a.0, &b.0, choice))
+ }
+ }
+
+ impl Hash for $elem {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ // The fields included in this hash MUST match the fields used
+ // in PartialEq::eq
+ // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
+
+ // Check the invariant that the integer representation is fully reduced.
+ debug_assert!(self.0 < $fp.p);
+
+ self.0.hash(state);
+ }
+ }
+
+ impl Eq for $elem {}
+
+ impl Add for $elem {
+ type Output = $elem;
+ fn add(self, rhs: Self) -> Self {
+ // FieldParameters::add() returns a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Self($fp.add(self.0, rhs.0))
+ }
+ }
+
+ impl Add for &$elem {
+ type Output = $elem;
+ fn add(self, rhs: Self) -> $elem {
+ *self + *rhs
+ }
+ }
+
+ impl AddAssign for $elem {
+ fn add_assign(&mut self, rhs: Self) {
+ *self = *self + rhs;
+ }
+ }
+
+ impl Sub for $elem {
+ type Output = $elem;
+ fn sub(self, rhs: Self) -> Self {
+ // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub()
+ // returns a value less than p, satisfying the invariant on Self.
+ Self($fp.sub(self.0, rhs.0))
+ }
+ }
+
+ impl Sub for &$elem {
+ type Output = $elem;
+ fn sub(self, rhs: Self) -> $elem {
+ *self - *rhs
+ }
+ }
+
+ impl SubAssign for $elem {
+ fn sub_assign(&mut self, rhs: Self) {
+ *self = *self - rhs;
+ }
+ }
+
+ impl Mul for $elem {
+ type Output = $elem;
+ fn mul(self, rhs: Self) -> Self {
+ // FieldParameters::mul() always returns a value less than p, so the invariant on
+ // Self is satisfied.
+ Self($fp.mul(self.0, rhs.0))
+ }
+ }
+
+ impl Mul for &$elem {
+ type Output = $elem;
+ fn mul(self, rhs: Self) -> $elem {
+ *self * *rhs
+ }
+ }
+
+ impl MulAssign for $elem {
+ fn mul_assign(&mut self, rhs: Self) {
+ *self = *self * rhs;
+ }
+ }
+
+ impl Div for $elem {
+ type Output = $elem;
+ #[allow(clippy::suspicious_arithmetic_impl)]
+ fn div(self, rhs: Self) -> Self {
+ self * rhs.inv()
+ }
+ }
+
+ impl Div for &$elem {
+ type Output = $elem;
+ fn div(self, rhs: Self) -> $elem {
+ *self / *rhs
+ }
+ }
+
+ impl DivAssign for $elem {
+ fn div_assign(&mut self, rhs: Self) {
+ *self = *self / rhs;
+ }
+ }
+
+ impl Neg for $elem {
+ type Output = $elem;
+ fn neg(self) -> Self {
+ // FieldParameters::neg() will return a value less than p because self.0 is less
+ // than p, and neg() dispatches to sub().
+ Self($fp.neg(self.0))
+ }
+ }
+
+ impl Neg for &$elem {
+ type Output = $elem;
+ fn neg(self) -> $elem {
+ -(*self)
+ }
+ }
+
+ impl From<$int> for $elem {
+ fn from(x: $int) -> Self {
+ // FieldParameters::montgomery() will return a value that has been fully reduced
+ // mod p, satisfying the invariant on Self.
+ Self($fp.montgomery(u128::try_from(x).unwrap()))
+ }
+ }
+
+ impl From<$elem> for $int {
+ fn from(x: $elem) -> Self {
+ $int::try_from($fp.residue(x.0)).unwrap()
+ }
+ }
+
+ impl PartialEq<$int> for $elem {
+ fn eq(&self, rhs: &$int) -> bool {
+ $fp.residue(self.0) == u128::try_from(*rhs).unwrap()
+ }
+ }
+
+ impl<'a> TryFrom<&'a [u8]> for $elem {
+ type Error = FieldError;
+
+ fn try_from(bytes: &[u8]) -> Result<Self, FieldError> {
+ Self::try_from_bytes(bytes, u128::MAX)
+ }
+ }
+
+ impl From<$elem> for [u8; $elem::ENCODED_SIZE] {
+ fn from(elem: $elem) -> Self {
+ let int = $fp.residue(elem.0);
+ let mut slice = [0; $elem::ENCODED_SIZE];
+ for i in 0..$elem::ENCODED_SIZE {
+ slice[i] = ((int >> (i << 3)) & 0xff) as u8;
+ }
+ slice
+ }
+ }
+
+ impl From<$elem> for Vec<u8> {
+ fn from(elem: $elem) -> Self {
+ <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec()
+ }
+ }
+
+ impl Display for $elem {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "{}", $fp.residue(self.0))
+ }
+ }
+
+ impl Debug for $elem {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", $fp.residue(self.0))
+ }
+ }
+
+ // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because
+ // the derived implementations would represent `FieldElement` values as the backing `u128`,
+ // which is not what we want because (1) we can be more efficient in all cases and (2) in
+ // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625).
+ impl Serialize for $elem {
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+ let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into();
+ serializer.serialize_bytes(&bytes)
+ }
+ }
+
+ impl<'de> Deserialize<'de> for $elem {
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> {
+ deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData })
+ }
+ }
+
+ impl Encode for $elem {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self);
+ bytes.extend_from_slice(&slice);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(Self::ENCODED_SIZE)
+ }
+ }
+
+ impl Decode for $elem {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; $elem::ENCODED_SIZE];
+ bytes.read_exact(&mut value)?;
+ $elem::try_from_bytes(&value, u128::MAX).map_err(|e| {
+ CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>)
+ })
+ }
+ }
+
+ impl FieldElement for $elem {
+ const ENCODED_SIZE: usize = $encoding_size;
+ fn inv(&self) -> Self {
+ // FieldParameters::inv() ultimately relies on mul(), and will always return a
+ // value less than p.
+ Self($fp.inv(self.0))
+ }
+
+ fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> {
+ $elem::try_from_bytes(bytes, $fp.bit_mask)
+ }
+
+ fn zero() -> Self {
+ Self(0)
+ }
+
+ fn one() -> Self {
+ Self($fp.roots[0])
+ }
+ }
+
+ impl FieldElementWithInteger for $elem {
+ type Integer = $int;
+ type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
+ type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;
+
+ fn pow(&self, exp: Self::Integer) -> Self {
+ // FieldParameters::pow() relies on mul(), and will always return a value less
+ // than p.
+ Self($fp.pow(self.0, u128::try_from(exp).unwrap()))
+ }
+
+ fn modulus() -> Self::Integer {
+ $fp.p as $int
+ }
+ }
+
+ impl FftFriendlyFieldElement for $elem {
+ fn generator() -> Self {
+ Self($fp.g)
+ }
+
+ fn generator_order() -> Self::Integer {
+ 1 << (Self::Integer::try_from($fp.num_roots).unwrap())
+ }
+
+ fn root(l: usize) -> Option<Self> {
+ if l < min($fp.roots.len(), $fp.num_roots+1) {
+ Some(Self($fp.roots[l]))
+ } else {
+ None
+ }
+ }
+ }
+ };
+}
+
+make_field!(
+ /// Same as Field32, but encoded in little endian for compatibility with Prio v2.
+ FieldPrio2,
+ u32,
+ FP32,
+ 4,
+);
+
+make_field!(
+ /// `GF(18446744069414584321)`, a 64-bit field.
+ Field64,
+ u64,
+ FP64,
+ 8,
+);
+
+make_field!(
+ /// `GF(340282366920938462946865773367900766209)`, a 128-bit field.
+ Field128,
+ u128,
+ FP128,
+ 16,
+);
+
+/// Merge two vectors of fields by summing other_vector into accumulator.
+///
+/// # Errors
+///
+/// Fails if the two vectors do not have the same length.
+pub(crate) fn merge_vector<F: FieldElement>(
+ accumulator: &mut [F],
+ other_vector: &[F],
+) -> Result<(), FieldError> {
+ if accumulator.len() != other_vector.len() {
+ return Err(FieldError::InputSizeMismatch);
+ }
+ for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) {
+ *a += *o;
+ }
+
+ Ok(())
+}
+
+/// Outputs an additive secret sharing of the input.
+#[cfg(all(feature = "crypto-dependencies", test))]
+pub(crate) fn split_vector<F: FieldElement>(
+ inp: &[F],
+ num_shares: usize,
+) -> Result<Vec<Vec<F>>, PrngError> {
+ if num_shares == 0 {
+ return Ok(vec![]);
+ }
+
+ let mut outp = Vec::with_capacity(num_shares);
+ outp.push(inp.to_vec());
+
+ for _ in 1..num_shares {
+ let share: Vec<F> = random_vector(inp.len())?;
+ for (x, y) in outp[0].iter_mut().zip(&share) {
+ *x -= *y;
+ }
+ outp.push(share);
+ }
+
+ Ok(outp)
+}
+
+/// Generate a vector of uniformly distributed random field elements.
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "crypto-dependencies")))]
+pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> {
+ Ok(Prng::new()?.take(len).collect())
+}
+
+/// `encode_fieldvec` serializes a type that is equivalent to a vector of field elements.
+#[inline(always)]
+pub(crate) fn encode_fieldvec<F: FieldElement, T: AsRef<[F]>>(val: T, bytes: &mut Vec<u8>) {
+ for elem in val.as_ref() {
+ elem.encode(bytes);
+ }
+}
+
+/// `decode_fieldvec` deserializes some number of field elements from a cursor, and advances the
+/// cursor's position.
+pub(crate) fn decode_fieldvec<F: FieldElement>(
+ count: usize,
+ input: &mut Cursor<&[u8]>,
+) -> Result<Vec<F>, CodecError> {
+ let mut vec = Vec::with_capacity(count);
+ let mut buffer = [0u8; 64];
+ assert!(
+ buffer.len() >= F::ENCODED_SIZE,
+ "field is too big for buffer"
+ );
+ for _ in 0..count {
+ input.read_exact(&mut buffer[..F::ENCODED_SIZE])?;
+ vec.push(
+ F::try_from(&buffer[..F::ENCODED_SIZE]).map_err(|e| CodecError::Other(Box::new(e)))?,
+ );
+ }
+ Ok(vec)
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::{FieldElement, FieldElementWithInteger};
+ use crate::{codec::CodecError, field::FieldError, prng::Prng};
+ use assert_matches::assert_matches;
+ use std::{
+ collections::hash_map::DefaultHasher,
+ convert::{TryFrom, TryInto},
+ fmt::Debug,
+ hash::{Hash, Hasher},
+ io::Cursor,
+ ops::{Add, BitAnd, Div, Shl, Shr, Sub},
+ };
+
+ /// A test-only copy of `FieldElementWithInteger`.
+ ///
+ /// This trait is only used in tests, and it is implemented on some fields that do not have
+ /// `FieldElementWithInteger` implementations. This separate trait is used in order to avoid
+ /// affecting trait resolution with conditional compilation. Additionally, this trait only
+ /// requires the `Integer` associated type satisfy `Clone`, not `Copy`, so that it may be used
+ /// with arbitrary precision integer implementations.
+ pub(crate) trait TestFieldElementWithInteger:
+ FieldElement + From<Self::Integer>
+ {
+ type IntegerTryFromError: std::error::Error;
+ type TryIntoU64Error: std::error::Error;
+ type Integer: Clone
+ + Debug
+ + Eq
+ + Ord
+ + BitAnd<Output = Self::Integer>
+ + Div<Output = Self::Integer>
+ + Shl<usize, Output = Self::Integer>
+ + Shr<usize, Output = Self::Integer>
+ + Add<Output = Self::Integer>
+ + Sub<Output = Self::Integer>
+ + From<Self>
+ + TryFrom<usize, Error = Self::IntegerTryFromError>
+ + TryInto<u64, Error = Self::TryIntoU64Error>;
+
+ fn pow(&self, exp: Self::Integer) -> Self;
+
+ fn modulus() -> Self::Integer;
+ }
+
+ impl<F> TestFieldElementWithInteger for F
+ where
+ F: FieldElementWithInteger,
+ {
+ type IntegerTryFromError = <F as FieldElementWithInteger>::IntegerTryFromError;
+ type TryIntoU64Error = <F as FieldElementWithInteger>::TryIntoU64Error;
+ type Integer = <F as FieldElementWithInteger>::Integer;
+
+ fn pow(&self, exp: Self::Integer) -> Self {
+ <F as FieldElementWithInteger>::pow(self, exp)
+ }
+
+ fn modulus() -> Self::Integer {
+ <F as FieldElementWithInteger>::modulus()
+ }
+ }
+
+ pub(crate) fn field_element_test_common<F: TestFieldElementWithInteger>() {
+ let mut prng: Prng<F, _> = Prng::new().unwrap();
+ let int_modulus = F::modulus();
+ let int_one = F::Integer::try_from(1).unwrap();
+ let zero = F::zero();
+ let one = F::one();
+ let two = F::from(F::Integer::try_from(2).unwrap());
+ let four = F::from(F::Integer::try_from(4).unwrap());
+
+ // add
+ assert_eq!(F::from(int_modulus.clone() - int_one.clone()) + one, zero);
+ assert_eq!(one + one, two);
+ assert_eq!(two + F::from(int_modulus.clone()), two);
+
+ // add w/ assignment
+ let mut a = prng.get();
+ let b = prng.get();
+ let c = a + b;
+ a += b;
+ assert_eq!(a, c);
+
+ // sub
+ assert_eq!(zero - one, F::from(int_modulus.clone() - int_one.clone()));
+ #[allow(clippy::eq_op)]
+ {
+ assert_eq!(one - one, zero);
+ }
+ assert_eq!(one + (-one), zero);
+ assert_eq!(two - F::from(int_modulus.clone()), two);
+ assert_eq!(one - F::from(int_modulus.clone() - int_one.clone()), two);
+
+ // sub w/ assignment
+ let mut a = prng.get();
+ let b = prng.get();
+ let c = a - b;
+ a -= b;
+ assert_eq!(a, c);
+
+ // add + sub
+ for _ in 0..100 {
+ let f = prng.get();
+ let g = prng.get();
+ assert_eq!(f + g - f - g, zero);
+ assert_eq!(f + g - g, f);
+ assert_eq!(f + g - f, g);
+ }
+
+ // mul
+ assert_eq!(two * two, four);
+ assert_eq!(two * one, two);
+ assert_eq!(two * zero, zero);
+ assert_eq!(one * F::from(int_modulus.clone()), zero);
+
+ // mul w/ assignment
+ let mut a = prng.get();
+ let b = prng.get();
+ let c = a * b;
+ a *= b;
+ assert_eq!(a, c);
+
+ // integer conversion
+ assert_eq!(F::Integer::from(zero), F::Integer::try_from(0).unwrap());
+ assert_eq!(F::Integer::from(one), F::Integer::try_from(1).unwrap());
+ assert_eq!(F::Integer::from(two), F::Integer::try_from(2).unwrap());
+ assert_eq!(F::Integer::from(four), F::Integer::try_from(4).unwrap());
+
+ // serialization
+ let test_inputs = vec![
+ zero,
+ one,
+ prng.get(),
+ F::from(int_modulus.clone() - int_one.clone()),
+ ];
+ for want in test_inputs.iter() {
+ let mut bytes = vec![];
+ want.encode(&mut bytes);
+
+ assert_eq!(bytes.len(), F::ENCODED_SIZE);
+ assert_eq!(want.encoded_len().unwrap(), F::ENCODED_SIZE);
+
+ let got = F::get_decoded(&bytes).unwrap();
+ assert_eq!(got, *want);
+ }
+
+ let serialized_vec = F::slice_into_byte_vec(&test_inputs);
+ let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap();
+ assert_eq!(deserialized, test_inputs);
+
+ let test_input = prng.get();
+ let json = serde_json::to_string(&test_input).unwrap();
+ let deserialized = serde_json::from_str::<F>(&json).unwrap();
+ assert_eq!(deserialized, test_input);
+
+ let value = serde_json::from_str::<serde_json::Value>(&json).unwrap();
+ let array = value.as_array().unwrap();
+ for element in array {
+ element.as_u64().unwrap();
+ }
+
+ let err = F::byte_slice_into_vec(&[0]).unwrap_err();
+ assert_matches!(err, FieldError::ShortRead);
+
+ let err = F::byte_slice_into_vec(&vec![0xffu8; F::ENCODED_SIZE]).unwrap_err();
+ assert_matches!(err, FieldError::Codec(CodecError::Other(err)) => {
+ assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow));
+ });
+
+ let insufficient = vec![0u8; F::ENCODED_SIZE - 1];
+ let err = F::try_from(insufficient.as_ref()).unwrap_err();
+ assert_matches!(err, FieldError::ShortRead);
+ let err = F::decode(&mut Cursor::new(&insufficient)).unwrap_err();
+ assert_matches!(err, CodecError::Io(_));
+
+ let err = F::decode(&mut Cursor::new(&vec![0xffu8; F::ENCODED_SIZE])).unwrap_err();
+ assert_matches!(err, CodecError::Other(err) => {
+ assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow));
+ });
+
+ // equality and hash: Generate many elements, confirm they are not equal, and confirm
+ // various products that should be equal have the same hash. Three is chosen as a generator
+ // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four
+ // primes.
+ let three = F::from(F::Integer::try_from(3).unwrap());
+ let mut powers_of_three = Vec::with_capacity(500);
+ let mut power = one;
+ for _ in 0..500 {
+ powers_of_three.push(power);
+ power *= three;
+ }
+ // Check all these elements are mutually not equal.
+ for i in 0..powers_of_three.len() {
+ let first = &powers_of_three[i];
+ for second in &powers_of_three[0..i] {
+ assert_ne!(first, second);
+ }
+ }
+
+ // Construct an element from a number that needs to be reduced, and test comparisons on it,
+ // confirming that it is reduced correctly.
+ let p = F::from(int_modulus.clone());
+ assert_eq!(p, zero);
+ let p_plus_one = F::from(int_modulus + int_one);
+ assert_eq!(p_plus_one, one);
+ }
+
+ pub(super) fn hash_helper<H: Hash>(input: H) -> u64 {
+ let mut hasher = DefaultHasher::new();
+ input.hash(&mut hasher);
+ hasher.finish()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::test_utils::{field_element_test_common, hash_helper};
+ use crate::fp::MAX_ROOTS;
+ use crate::prng::Prng;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn test_accumulate() {
+ let mut lhs = vec![FieldPrio2(1); 10];
+ let rhs = vec![FieldPrio2(2); 10];
+
+ merge_vector(&mut lhs, &rhs).unwrap();
+
+ lhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(3)));
+ rhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(2)));
+
+ let wrong_len = vec![FieldPrio2::zero(); 9];
+ let result = merge_vector(&mut lhs, &wrong_len);
+ assert_matches!(result, Err(FieldError::InputSizeMismatch));
+ }
+
+ fn field_element_test<F: FftFriendlyFieldElement + Hash>() {
+ field_element_test_common::<F>();
+
+ let mut prng: Prng<F, _> = Prng::new().unwrap();
+ let int_modulus = F::modulus();
+ let int_one = F::Integer::try_from(1).unwrap();
+ let zero = F::zero();
+ let one = F::one();
+ let two = F::from(F::Integer::try_from(2).unwrap());
+ let four = F::from(F::Integer::try_from(4).unwrap());
+
+ // div
+ assert_eq!(four / two, two);
+ #[allow(clippy::eq_op)]
+ {
+ assert_eq!(two / two, one);
+ }
+ assert_eq!(zero / two, zero);
+ assert_eq!(two / zero, zero); // Undefined behavior
+ assert_eq!(zero.inv(), zero); // Undefined behavior
+
+ // div w/ assignment
+ let mut a = prng.get();
+ let b = prng.get();
+ let c = a / b;
+ a /= b;
+ assert_eq!(a, c);
+ assert_eq!(hash_helper(a), hash_helper(c));
+
+ // mul + div
+ for _ in 0..100 {
+ let f = prng.get();
+ if f == zero {
+ continue;
+ }
+ assert_eq!(f * f.inv(), one);
+ assert_eq!(f.inv() * f, one);
+ }
+
+ // pow
+ assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one);
+ assert_eq!(two.pow(int_one), two);
+ assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four);
+ assert_eq!(two.pow(int_modulus - int_one), one);
+ assert_eq!(two.pow(int_modulus), two);
+
+ // roots
+ let mut int_order = F::generator_order();
+ for l in 0..MAX_ROOTS + 1 {
+ assert_eq!(
+ F::generator().pow(int_order),
+ F::root(l).unwrap(),
+ "failure for F::root({l})"
+ );
+ int_order = int_order >> 1;
+ }
+
+ // formatting
+ assert_eq!(format!("{zero}"), "0");
+ assert_eq!(format!("{one}"), "1");
+ assert_eq!(format!("{zero:?}"), "0");
+ assert_eq!(format!("{one:?}"), "1");
+
+ let three = F::from(F::Integer::try_from(3).unwrap());
+ let mut powers_of_three = Vec::with_capacity(500);
+ let mut power = one;
+ for _ in 0..500 {
+ powers_of_three.push(power);
+ power *= three;
+ }
+
+ // Check that 3^i is the same whether it's calculated with pow() or repeated
+ // multiplication, with both equality and hash equality.
+ for (i, power) in powers_of_three.iter().enumerate() {
+ let result = three.pow(F::Integer::try_from(i).unwrap());
+ assert_eq!(result, *power);
+ let hash1 = hash_helper(power);
+ let hash2 = hash_helper(result);
+ assert_eq!(hash1, hash2);
+ }
+
+ // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality.
+ let expected_product = powers_of_three[powers_of_three.len() - 1];
+ let expected_hash = hash_helper(expected_product);
+ for i in 0..powers_of_three.len() {
+ let a = powers_of_three[i];
+ let b = powers_of_three[powers_of_three.len() - 1 - i];
+ let product = a * b;
+ assert_eq!(product, expected_product);
+ assert_eq!(hash_helper(product), expected_hash);
+ }
+ }
+
+ #[test]
+ fn test_field_prio2() {
+ field_element_test::<FieldPrio2>();
+ }
+
+ #[test]
+ fn test_field64() {
+ field_element_test::<Field64>();
+ }
+
+ #[test]
+ fn test_field128() {
+ field_element_test::<Field128>();
+ }
+
+ #[test]
+ fn test_encode_into_bitvector() {
+ let zero = Field128::zero();
+ let one = Field128::one();
+ let zero_enc = Field128::encode_into_bitvector_representation(&0, 4).unwrap();
+ let one_enc = Field128::encode_into_bitvector_representation(&1, 4).unwrap();
+ let fifteen_enc = Field128::encode_into_bitvector_representation(&15, 4).unwrap();
+ assert_eq!(zero_enc, [zero; 4]);
+ assert_eq!(one_enc, [one, zero, zero, zero]);
+ assert_eq!(fifteen_enc, [one; 4]);
+ Field128::encode_into_bitvector_representation(&16, 4).unwrap_err();
+ }
+
+ #[test]
+ fn test_fill_bitvector() {
+ let zero = Field128::zero();
+ let one = Field128::one();
+ let mut output: Vec<Field128> = vec![zero; 6];
+ Field128::fill_with_bitvector_representation(&9, &mut output[1..5]).unwrap();
+ assert_eq!(output, [zero, one, zero, zero, one, zero]);
+ Field128::fill_with_bitvector_representation(&16, &mut output[1..5]).unwrap_err();
+ }
+}
diff --git a/third_party/rust/prio/src/field/field255.rs b/third_party/rust/prio/src/field/field255.rs
new file mode 100644
index 0000000000..fd06a6334a
--- /dev/null
+++ b/third_party/rust/prio/src/field/field255.rs
@@ -0,0 +1,543 @@
+// Copyright (c) 2023 ISRG
+// SPDX-License-Identifier: MPL-2.0
+
+//! Finite field arithmetic for `GF(2^255 - 19)`.
+
+use crate::{
+ codec::{CodecError, Decode, Encode},
+ field::{FieldElement, FieldElementVisitor, FieldError},
+};
+use fiat_crypto::curve25519_64::{
+ fiat_25519_add, fiat_25519_carry, fiat_25519_carry_mul, fiat_25519_from_bytes,
+ fiat_25519_loose_field_element, fiat_25519_opp, fiat_25519_relax, fiat_25519_selectznz,
+ fiat_25519_sub, fiat_25519_tight_field_element, fiat_25519_to_bytes,
+};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::{
+ convert::TryFrom,
+ fmt::{self, Debug, Display, Formatter},
+ io::{Cursor, Read},
+ marker::PhantomData,
+ mem::size_of,
+ ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
+};
+use subtle::{
+ Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
+};
+
+// `python3 -c "print(', '.join(hex(x) for x in (2**255-19).to_bytes(32, 'little')))"`
+const MODULUS_LITTLE_ENDIAN: [u8; 32] = [
+ 0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f,
+];
+
+/// `GF(2^255 - 19)`, a 255-bit field.
+#[derive(Clone, Copy)]
+#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
+pub struct Field255(fiat_25519_tight_field_element);
+
+impl Field255 {
+ /// Attempts to instantiate a `Field255` from the first `Self::ENCODED_SIZE` bytes in the
+ /// provided slice.
+ ///
+ /// # Errors
+ ///
+ /// An error is returned if the provided slice is not long enough to encode a field element or
+ /// if the decoded value is greater than the field prime.
+ fn try_from_bytes(bytes: &[u8], mask_top_bit: bool) -> Result<Self, FieldError> {
+ if Self::ENCODED_SIZE > bytes.len() {
+ return Err(FieldError::ShortRead);
+ }
+
+ let mut value = [0u8; Self::ENCODED_SIZE];
+ value.copy_from_slice(&bytes[..Self::ENCODED_SIZE]);
+
+ if mask_top_bit {
+ value[31] &= 0b0111_1111;
+ }
+
+ // Walk through the bytes of the provided value from most significant to least,
+ // and identify whether the first byte that differs from the field's modulus is less than
+ // the corresponding byte or greater than the corresponding byte, in constant time. (Or
+ // whether the provided value is equal to the field modulus.)
+ let mut less_than_modulus = Choice::from(0u8);
+ let mut greater_than_modulus = Choice::from(0u8);
+ for (value_byte, modulus_byte) in value.iter().rev().zip(MODULUS_LITTLE_ENDIAN.iter().rev())
+ {
+ less_than_modulus |= value_byte.ct_lt(modulus_byte) & !greater_than_modulus;
+ greater_than_modulus |= value_byte.ct_gt(modulus_byte) & !less_than_modulus;
+ }
+
+ if bool::from(!less_than_modulus) {
+ return Err(FieldError::ModulusOverflow);
+ }
+
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_from_bytes(&mut output, &value);
+
+ Ok(Field255(output))
+ }
+}
+
+impl ConstantTimeEq for Field255 {
+ fn ct_eq(&self, rhs: &Self) -> Choice {
+ // The internal representation used by fiat-crypto is not 1-1 with the field, so it is
+ // necessary to compare field elements via their canonical encodings.
+
+ let mut self_encoded = [0; 32];
+ fiat_25519_to_bytes(&mut self_encoded, &self.0);
+ let mut rhs_encoded = [0; 32];
+ fiat_25519_to_bytes(&mut rhs_encoded, &rhs.0);
+
+ self_encoded.ct_eq(&rhs_encoded)
+ }
+}
+
+impl ConditionallySelectable for Field255 {
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+ let mut output = [0; 5];
+ fiat_25519_selectznz(&mut output, choice.unwrap_u8(), &(a.0).0, &(b.0).0);
+ Field255(fiat_25519_tight_field_element(output))
+ }
+}
+
+impl PartialEq for Field255 {
+ fn eq(&self, rhs: &Self) -> bool {
+ self.ct_eq(rhs).into()
+ }
+}
+
+impl Eq for Field255 {}
+
+impl Add for Field255 {
+ type Output = Field255;
+
+ fn add(self, rhs: Self) -> Field255 {
+ let mut loose_output = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_add(&mut loose_output, &self.0, &rhs.0);
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_carry(&mut output, &loose_output);
+ Field255(output)
+ }
+}
+
+impl AddAssign for Field255 {
+ fn add_assign(&mut self, rhs: Self) {
+ let mut loose_output = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_add(&mut loose_output, &self.0, &rhs.0);
+ fiat_25519_carry(&mut self.0, &loose_output);
+ }
+}
+
+impl Sub for Field255 {
+ type Output = Field255;
+
+ fn sub(self, rhs: Self) -> Field255 {
+ let mut loose_output = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_sub(&mut loose_output, &self.0, &rhs.0);
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_carry(&mut output, &loose_output);
+ Field255(output)
+ }
+}
+
+impl SubAssign for Field255 {
+ fn sub_assign(&mut self, rhs: Self) {
+ let mut loose_output = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_sub(&mut loose_output, &self.0, &rhs.0);
+ fiat_25519_carry(&mut self.0, &loose_output);
+ }
+}
+
+impl Mul for Field255 {
+ type Output = Field255;
+
+ fn mul(self, rhs: Self) -> Field255 {
+ let mut self_relaxed = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_relax(&mut self_relaxed, &self.0);
+ let mut rhs_relaxed = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_relax(&mut rhs_relaxed, &rhs.0);
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_carry_mul(&mut output, &self_relaxed, &rhs_relaxed);
+ Field255(output)
+ }
+}
+
+impl MulAssign for Field255 {
+ fn mul_assign(&mut self, rhs: Self) {
+ let mut self_relaxed = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_relax(&mut self_relaxed, &self.0);
+ let mut rhs_relaxed = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_relax(&mut rhs_relaxed, &rhs.0);
+ fiat_25519_carry_mul(&mut self.0, &self_relaxed, &rhs_relaxed);
+ }
+}
+
+impl Div for Field255 {
+ type Output = Field255;
+
+ fn div(self, _rhs: Self) -> Self::Output {
+ unimplemented!("Div is not implemented for Field255 because it's not needed yet")
+ }
+}
+
+impl DivAssign for Field255 {
+ fn div_assign(&mut self, _rhs: Self) {
+ unimplemented!("DivAssign is not implemented for Field255 because it's not needed yet")
+ }
+}
+
+impl Neg for Field255 {
+ type Output = Field255;
+
+ fn neg(self) -> Field255 {
+ -&self
+ }
+}
+
+impl<'a> Neg for &'a Field255 {
+ type Output = Field255;
+
+ fn neg(self) -> Field255 {
+ let mut loose_output = fiat_25519_loose_field_element([0; 5]);
+ fiat_25519_opp(&mut loose_output, &self.0);
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_carry(&mut output, &loose_output);
+ Field255(output)
+ }
+}
+
+impl From<u64> for Field255 {
+ fn from(value: u64) -> Self {
+ let input_bytes = value.to_le_bytes();
+ let mut field_bytes = [0u8; Self::ENCODED_SIZE];
+ field_bytes[..input_bytes.len()].copy_from_slice(&input_bytes);
+ Self::try_from_bytes(&field_bytes, false).unwrap()
+ }
+}
+
+impl<'a> TryFrom<&'a [u8]> for Field255 {
+ type Error = FieldError;
+
+ fn try_from(bytes: &[u8]) -> Result<Self, FieldError> {
+ Self::try_from_bytes(bytes, false)
+ }
+}
+
+impl From<Field255> for [u8; Field255::ENCODED_SIZE] {
+ fn from(element: Field255) -> Self {
+ let mut array = [0; Field255::ENCODED_SIZE];
+ fiat_25519_to_bytes(&mut array, &element.0);
+ array
+ }
+}
+
+impl From<Field255> for Vec<u8> {
+ fn from(elem: Field255) -> Vec<u8> {
+ <[u8; Field255::ENCODED_SIZE]>::from(elem).to_vec()
+ }
+}
+
+impl Display for Field255 {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ let encoded: [u8; Self::ENCODED_SIZE] = (*self).into();
+ write!(f, "0x")?;
+ for byte in encoded {
+ write!(f, "{byte:02x}")?;
+ }
+ Ok(())
+ }
+}
+
+impl Debug for Field255 {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ <Self as Display>::fmt(self, f)
+ }
+}
+
+impl Serialize for Field255 {
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+ let bytes: [u8; Self::ENCODED_SIZE] = (*self).into();
+ serializer.serialize_bytes(&bytes)
+ }
+}
+
+impl<'de> Deserialize<'de> for Field255 {
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Field255, D::Error> {
+ deserializer.deserialize_bytes(FieldElementVisitor {
+ phantom: PhantomData,
+ })
+ }
+}
+
+impl Encode for Field255 {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&<[u8; Self::ENCODED_SIZE]>::from(*self));
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(Self::ENCODED_SIZE)
+ }
+}
+
+impl Decode for Field255 {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut value = [0u8; Self::ENCODED_SIZE];
+ bytes.read_exact(&mut value)?;
+ Field255::try_from_bytes(&value, false).map_err(|e| {
+ CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>)
+ })
+ }
+}
+
+impl FieldElement for Field255 {
+ const ENCODED_SIZE: usize = 32;
+
+ fn inv(&self) -> Self {
+ unimplemented!("Field255::inv() is not implemented because it's not needed yet")
+ }
+
+ fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> {
+ Field255::try_from_bytes(bytes, true)
+ }
+
+ fn zero() -> Self {
+ Field255(fiat_25519_tight_field_element([0, 0, 0, 0, 0]))
+ }
+
+ fn one() -> Self {
+ Field255(fiat_25519_tight_field_element([1, 0, 0, 0, 0]))
+ }
+}
+
+impl Default for Field255 {
+ fn default() -> Self {
+ Field255::zero()
+ }
+}
+
+impl TryFrom<Field255> for u64 {
+ type Error = FieldError;
+
+ fn try_from(elem: Field255) -> Result<u64, FieldError> {
+ const PREFIX_LEN: usize = size_of::<u64>();
+ let mut le_bytes = [0; 32];
+
+ fiat_25519_to_bytes(&mut le_bytes, &elem.0);
+ if !bool::from(le_bytes[PREFIX_LEN..].ct_eq(&[0_u8; 32 - PREFIX_LEN])) {
+ return Err(FieldError::IntegerTryFrom);
+ }
+
+ Ok(u64::from_le_bytes(
+ le_bytes[..PREFIX_LEN].try_into().unwrap(),
+ ))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{Field255, MODULUS_LITTLE_ENDIAN};
+ use crate::{
+ codec::Encode,
+ field::{
+ test_utils::{field_element_test_common, TestFieldElementWithInteger},
+ FieldElement, FieldError,
+ },
+ };
+ use assert_matches::assert_matches;
+ use fiat_crypto::curve25519_64::{
+ fiat_25519_from_bytes, fiat_25519_tight_field_element, fiat_25519_to_bytes,
+ };
+ use num_bigint::BigUint;
+ use once_cell::sync::Lazy;
+ use std::convert::{TryFrom, TryInto};
+
+ static MODULUS: Lazy<BigUint> = Lazy::new(|| BigUint::from_bytes_le(&MODULUS_LITTLE_ENDIAN));
+
+ impl From<BigUint> for Field255 {
+ fn from(value: BigUint) -> Self {
+ let le_bytes_vec = (value % &*MODULUS).to_bytes_le();
+
+ let mut le_bytes_array = [0u8; 32];
+ le_bytes_array[..le_bytes_vec.len()].copy_from_slice(&le_bytes_vec);
+
+ let mut output = fiat_25519_tight_field_element([0; 5]);
+ fiat_25519_from_bytes(&mut output, &le_bytes_array);
+ Field255(output)
+ }
+ }
+
+ impl From<Field255> for BigUint {
+ fn from(value: Field255) -> Self {
+ let mut le_bytes = [0u8; 32];
+ fiat_25519_to_bytes(&mut le_bytes, &value.0);
+ BigUint::from_bytes_le(&le_bytes)
+ }
+ }
+
+ impl TestFieldElementWithInteger for Field255 {
+ type Integer = BigUint;
+ type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
+ type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;
+
+ fn pow(&self, _exp: Self::Integer) -> Self {
+ unimplemented!("Field255::pow() is not implemented because it's not needed yet")
+ }
+
+ fn modulus() -> Self::Integer {
+ MODULUS.clone()
+ }
+ }
+
+ #[test]
+ fn check_modulus() {
+ let modulus = Field255::modulus();
+ let element = Field255::from(modulus);
+ // Note that these two objects represent the same field element, they encode to the same
+ // canonical value (32 zero bytes), but they do not have the same internal representation.
+ assert_eq!(element, Field255::zero());
+ }
+
+ #[test]
+ fn check_identities() {
+ let zero_bytes: [u8; 32] = Field255::zero().into();
+ assert_eq!(zero_bytes, [0; 32]);
+ let one_bytes: [u8; 32] = Field255::one().into();
+ assert_eq!(
+ one_bytes,
+ [
+ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0
+ ]
+ );
+ }
+
+ #[test]
+ fn encode_endianness() {
+ let mut one_encoded = Vec::new();
+ Field255::one().encode(&mut one_encoded);
+ assert_eq!(
+ one_encoded,
+ [
+ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0
+ ]
+ );
+ }
+
+ #[test]
+ fn test_field255() {
+ field_element_test_common::<Field255>();
+ }
+
+ #[test]
+ fn try_from_bytes() {
+ assert_matches!(
+ Field255::try_from_bytes(
+ &[
+ 0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f,
+ ],
+ false,
+ ),
+ Err(FieldError::ModulusOverflow)
+ );
+ assert_matches!(
+ Field255::try_from_bytes(
+ &[
+ 0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f,
+ ],
+ false,
+ ),
+ Ok(_)
+ );
+ assert_matches!(
+ Field255::try_from_bytes(
+ &[
+ 0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f,
+ ],
+ true,
+ ),
+ Ok(element) => assert_eq!(element + Field255::one(), Field255::zero())
+ );
+ assert_matches!(
+ Field255::try_from_bytes(
+ &[
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+ ],
+ false
+ ),
+ Err(FieldError::ModulusOverflow)
+ );
+ assert_matches!(
+ Field255::try_from_bytes(
+ &[
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+ ],
+ true
+ ),
+ Ok(element) => assert_eq!(element, Field255::zero())
+ );
+ }
+
+ #[test]
+ fn u64_conversion() {
+ assert_eq!(Field255::from(0u64), Field255::zero());
+ assert_eq!(Field255::from(1u64), Field255::one());
+
+ let max_bytes: [u8; 32] = Field255::from(u64::MAX).into();
+ assert_eq!(
+ max_bytes,
+ [
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00
+ ]
+ );
+
+ let want: u64 = 0xffffffffffffffff;
+ assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want);
+
+ let want: u64 = 0x7000000000000001;
+ assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want);
+
+ let want: u64 = 0x1234123412341234;
+ assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want);
+
+ assert!(u64::try_from(Field255::try_from_bytes(&[1; 32], false).unwrap()).is_err());
+ assert!(u64::try_from(Field255::try_from_bytes(&[2; 32], false).unwrap()).is_err());
+ }
+
+ #[test]
+ fn formatting() {
+ assert_eq!(
+ format!("{}", Field255::zero()),
+ "0x0000000000000000000000000000000000000000000000000000000000000000"
+ );
+ assert_eq!(
+ format!("{}", Field255::one()),
+ "0x0100000000000000000000000000000000000000000000000000000000000000"
+ );
+ assert_eq!(
+ format!("{}", -Field255::one()),
+ "0xecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f"
+ );
+ assert_eq!(
+ format!("{:?}", Field255::zero()),
+ "0x0000000000000000000000000000000000000000000000000000000000000000"
+ );
+ assert_eq!(
+ format!("{:?}", Field255::one()),
+ "0x0100000000000000000000000000000000000000000000000000000000000000"
+ );
+ }
+}
diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs
new file mode 100644
index 0000000000..1912ebab14
--- /dev/null
+++ b/third_party/rust/prio/src/flp.rs
@@ -0,0 +1,1059 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of the generic Fully Linear Proof (FLP) system specified in
+//! [[draft-irtf-cfrg-vdaf-07]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
+//!
+//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation
+//! specifies a validity circuit that defines the set of valid measurements, as well as the finite
+//! field in which the validity circuit is evaluated. It also determines how raw measurements are
+//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of
+//! measurements.
+//!
+//! # Overview
+//!
+//! The proof system is comprised of three algorithms. The first, `prove`, is run by the prover in
+//! order to generate a proof of a statement's validity. The second and third, `query` and
+//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input
+//! is an element of a language recognized by the arithmetic circuit. If an input is _not_ valid,
+//! then the verification step will fail with high probability:
+//!
+//! ```
+//! use prio::flp::types::Count;
+//! use prio::flp::Type;
+//! use prio::field::{random_vector, FieldElement, Field64};
+//!
+//! // The prover chooses a measurement.
+//! let count = Count::new();
+//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap();
+//!
+//! // The prover and verifier agree on "joint randomness" used to generate and
+//! // check the proof. The application needs to ensure that the prover
+//! // "commits" to the input before this point. In Prio3, the joint
+//! // randomness is derived from additive shares of the input.
+//! let joint_rand = random_vector(count.joint_rand_len()).unwrap();
+//!
+//! // The prover generates the proof.
+//! let prove_rand = random_vector(count.prove_rand_len()).unwrap();
+//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap();
+//!
+//! // The verifier checks the proof. In the first step, the verifier "queries"
+//! // the input and proof, getting the "verifier message" in response. It then
+//! // inspects the verifier to decide if the input is valid.
+//! let query_rand = random_vector(count.query_rand_len()).unwrap();
+//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap();
+//! assert!(count.decide(&verifier).unwrap());
+//! ```
+//!
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+#[cfg(feature = "experimental")]
+use crate::dp::DifferentialPrivacyStrategy;
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError};
+use crate::field::{FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldError};
+use crate::fp::log2;
+use crate::polynomial::poly_eval;
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+
+pub mod gadgets;
+pub mod types;
+
+/// Errors propagated by methods in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum FlpError {
+ /// Calling [`Type::prove`] returned an error.
+ #[error("prove error: {0}")]
+ Prove(String),
+
+ /// Calling [`Type::query`] returned an error.
+ #[error("query error: {0}")]
+ Query(String),
+
+ /// Calling [`Type::decide`] returned an error.
+ #[error("decide error: {0}")]
+ Decide(String),
+
+ /// Calling a gadget returned an error.
+ #[error("gadget error: {0}")]
+ Gadget(String),
+
+ /// Calling the validity circuit returned an error.
+ #[error("validity circuit error: {0}")]
+ Valid(String),
+
+ /// Calling [`Type::encode_measurement`] returned an error.
+ #[error("value error: {0}")]
+ Encode(String),
+
+ /// Calling [`Type::decode_result`] returned an error.
+ #[error("value error: {0}")]
+ Decode(String),
+
+ /// Calling [`Type::truncate`] returned an error.
+ #[error("truncate error: {0}")]
+ Truncate(String),
+
+ /// Generic invalid parameter. This may be returned when an FLP type cannot be constructed.
+ #[error("invalid paramter: {0}")]
+ InvalidParameter(String),
+
+ /// Returned if an FFT operation propagates an error.
+ #[error("FFT error: {0}")]
+ Fft(#[from] FftError),
+
+ /// Returned if a field operation encountered an error.
+ #[error("Field error: {0}")]
+ Field(#[from] FieldError),
+
+ #[cfg(feature = "experimental")]
+ /// An error happened during noising.
+ #[error("differential privacy error: {0}")]
+ DifferentialPrivacy(#[from] crate::dp::DpError),
+
+ /// Unit test error.
+ #[cfg(test)]
+ #[error("test failed: {0}")]
+ Test(String),
+}
+
+/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
+/// as a vector of field elements and how validity of the encoded measurement is determined.
+/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
+pub trait Type: Sized + Eq + Clone + Debug {
+ /// The Prio3 VDAF identifier corresponding to this type.
+ const ID: u32;
+
+ /// The type of raw measurement to be encoded.
+ type Measurement: Clone + Debug;
+
+ /// The type of aggregate result for this type.
+ type AggregateResult: Clone + Debug;
+
+ /// The finite field used for this type.
+ type Field: FftFriendlyFieldElement;
+
+ /// Encodes a measurement as a vector of [`Self::input_len`] field elements.
+ fn encode_measurement(
+ &self,
+ measurement: &Self::Measurement,
+ ) -> Result<Vec<Self::Field>, FlpError>;
+
+ /// Decode an aggregate result.
+ fn decode_result(
+ &self,
+ data: &[Self::Field],
+ num_measurements: usize,
+ ) -> Result<Self::AggregateResult, FlpError>;
+
+ /// Returns the sequence of gadgets associated with the validity circuit.
+ ///
+ /// # Notes
+ ///
+ /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many. The
+ /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in
+ /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here
+ /// requires security analysis.
+ ///
+ /// [BBCG+19]: https://ia.cr/2019/188
+ fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>;
+
+ /// Evaluates the validity circuit on an input and returns the output.
+ ///
+ /// # Parameters
+ ///
+ /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`].
+ /// * `input` is the input to be validated.
+ /// * `joint_rand` is the joint randomness shared by the prover and verifier.
+ /// * `num_shares` is the number of input shares.
+ ///
+ /// # Example usage
+ ///
+ /// Applications typically do not call this method directly. It is used internally by
+ /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively.
+ ///
+ /// ```
+ /// use prio::flp::types::Count;
+ /// use prio::flp::Type;
+ /// use prio::field::{random_vector, FieldElement, Field64};
+ ///
+ /// let count = Count::new();
+ /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap();
+ /// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
+ /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
+ /// assert_eq!(v, Field64::zero());
+ /// ```
+ fn valid(
+ &self,
+ gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>,
+ input: &[Self::Field],
+ joint_rand: &[Self::Field],
+ num_shares: usize,
+ ) -> Result<Self::Field, FlpError>;
+
+ /// Constructs an aggregatable output from an encoded input. Calling this method is only safe
+ /// once `input` has been validated.
+ fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>;
+
+ /// The length in field elements of the encoded input returned by [`Self::encode_measurement`].
+ fn input_len(&self) -> usize;
+
+ /// The length in field elements of the proof generated for this type.
+ fn proof_len(&self) -> usize;
+
+ /// The length in field elements of the verifier message constructed by [`Self::query`].
+ fn verifier_len(&self) -> usize;
+
+ /// The length of the truncated output (i.e., the output of [`Type::truncate`]).
+ fn output_len(&self) -> usize;
+
+ /// The length of the joint random input.
+ fn joint_rand_len(&self) -> usize;
+
+ /// The length in field elements of the random input consumed by the prover to generate a
+ /// proof. This is the same as the sum of the arity of each gadget in the validity circuit.
+ fn prove_rand_len(&self) -> usize;
+
+ /// The length in field elements of the random input consumed by the verifier to make queries
+ /// against inputs and proofs. This is the same as the number of gadgets in the validity
+ /// circuit.
+ fn query_rand_len(&self) -> usize;
+
+ /// Generate a proof of an input's validity. The return value is a sequence of
+ /// [`Self::proof_len`] field elements.
+ ///
+ /// # Parameters
+ ///
+ /// * `input` is the input.
+ /// * `prove_rand` is the prover' randomness.
+ /// * `joint_rand` is the randomness shared by the prover and verifier.
+ fn prove(
+ &self,
+ input: &[Self::Field],
+ prove_rand: &[Self::Field],
+ joint_rand: &[Self::Field],
+ ) -> Result<Vec<Self::Field>, FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ if prove_rand.len() != self.prove_rand_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected prove randomness length: got {}; want {}",
+ prove_rand.len(),
+ self.prove_rand_len()
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Prove(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ let mut prove_rand_len = 0;
+ let mut shims = self
+ .gadget()
+ .into_iter()
+ .map(|inner| {
+ let inner_arity = inner.arity();
+ if prove_rand_len + inner_arity > prove_rand.len() {
+ return Err(FlpError::Prove(format!(
+ "short prove randomness: got {}; want at least {}",
+ prove_rand.len(),
+ prove_rand_len + inner_arity
+ )));
+ }
+
+ let gadget = Box::new(ProveShimGadget::new(
+ inner,
+ &prove_rand[prove_rand_len..prove_rand_len + inner_arity],
+ )?) as Box<dyn Gadget<Self::Field>>;
+ prove_rand_len += inner_arity;
+
+ Ok(gadget)
+ })
+ .collect::<Result<Vec<_>, FlpError>>()?;
+ assert_eq!(prove_rand_len, self.prove_rand_len());
+
+ // Create a buffer for storing the proof. The buffer is longer than the proof itself; the extra
+ // length is to accommodate the computation of each gadget polynomial.
+ let data_len = shims
+ .iter()
+ .map(|shim| {
+ let gadget_poly_len = gadget_poly_len(shim.degree(), wire_poly_len(shim.calls()));
+
+ // Computing the gadget polynomial using FFT requires an amount of memory that is a
+ // power of 2. Thus we choose the smallest power of 2 that is at least as large as
+ // the gadget polynomial. The wire seeds are encoded in the proof, too, so we
+ // include the arity of the gadget to ensure there is always enough room at the end
+ // of the buffer to compute the next gadget polynomial. It's likely that the
+ // memory footprint here can be reduced, with a bit of care.
+ shim.arity() + gadget_poly_len.next_power_of_two()
+ })
+ .sum();
+ let mut proof = vec![Self::Field::zero(); data_len];
+
+ // Run the validity circuit with a sequence of "shim" gadgets that record the value of each
+ // input wire of each gadget evaluation. These values are used to construct the wire
+ // polynomials for each gadget in the next step.
+ let _ = self.valid(&mut shims, input, joint_rand, 1)?;
+
+ // Construct the proof.
+ let mut proof_len = 0;
+ for shim in shims.iter_mut() {
+ let gadget = shim
+ .as_any()
+ .downcast_mut::<ProveShimGadget<Self::Field>>()
+ .unwrap();
+
+ // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of each
+ // evaluation of the gadget.
+ let m = wire_poly_len(gadget.calls());
+ let m_inv = Self::Field::from(
+ <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(),
+ )
+ .inv();
+ let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()];
+ for ((coefficients, values), proof_val) in f[..gadget.arity()]
+ .iter_mut()
+ .zip(gadget.f_vals[..gadget.arity()].iter())
+ .zip(proof[proof_len..proof_len + gadget.arity()].iter_mut())
+ {
+ discrete_fourier_transform(coefficients, values, m)?;
+ discrete_fourier_transform_inv_finish(coefficients, m, m_inv);
+
+ // The first point on each wire polynomial is a random value chosen by the prover. This
+ // point is stored in the proof so that the verifier can reconstruct the wire
+ // polynomials.
+ *proof_val = values[0];
+ }
+
+ // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to `proof`.
+ let gadget_poly_len = gadget_poly_len(gadget.degree(), m);
+ let start = proof_len + gadget.arity();
+ let end = start + gadget_poly_len.next_power_of_two();
+ gadget.call_poly(&mut proof[start..end], &f)?;
+ proof_len += gadget.arity() + gadget_poly_len;
+ }
+
+ // Truncate the buffer to the size of the proof.
+ assert_eq!(proof_len, self.proof_len());
+ proof.truncate(proof_len);
+ Ok(proof)
+ }
+
+ /// Query an input and proof and return the verifier message. The return value has length
+ /// [`Self::verifier_len`].
+ ///
+ /// # Parameters
+ ///
+ /// * `input` is the input or input share.
+ /// * `proof` is the proof or proof share.
+ /// * `query_rand` is the verifier's randomness.
+ /// * `joint_rand` is the randomness shared by the prover and verifier.
+ /// * `num_shares` is the total number of input shares.
+ fn query(
+ &self,
+ input: &[Self::Field],
+ proof: &[Self::Field],
+ query_rand: &[Self::Field],
+ joint_rand: &[Self::Field],
+ num_shares: usize,
+ ) -> Result<Vec<Self::Field>, FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ if proof.len() != self.proof_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected proof length: got {}; want {}",
+ proof.len(),
+ self.proof_len()
+ )));
+ }
+
+ if query_rand.len() != self.query_rand_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected query randomness length: got {}; want {}",
+ query_rand.len(),
+ self.query_rand_len()
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Query(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ let mut proof_len = 0;
+ let mut shims = self
+ .gadget()
+ .into_iter()
+ .enumerate()
+ .map(|(idx, gadget)| {
+ let gadget_degree = gadget.degree();
+ let gadget_arity = gadget.arity();
+ let m = (1 + gadget.calls()).next_power_of_two();
+ let r = query_rand[idx];
+
+ // Make sure the query randomness isn't a root of unity. Evaluating the gadget
+ // polynomial at any of these points would be a privacy violation, since these points
+ // were used by the prover to construct the wire polynomials.
+ if r.pow(<Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap())
+ == Self::Field::one()
+ {
+ return Err(FlpError::Query(format!(
+ "invalid query randomness: encountered 2^{m}-th root of unity"
+ )));
+ }
+
+ // Compute the length of the sub-proof corresponding to the `idx`-th gadget.
+ let next_len = gadget_arity + gadget_degree * (m - 1) + 1;
+ let proof_data = &proof[proof_len..proof_len + next_len];
+ proof_len += next_len;
+
+ Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?)
+ as Box<dyn Gadget<Self::Field>>)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // Create a buffer for the verifier data. This includes the output of the validity circuit and,
+ // for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query randomness
+ // `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`.
+ let data_len = 1 + shims.iter().map(|shim| shim.arity() + 1).sum::<usize>();
+ let mut verifier = Vec::with_capacity(data_len);
+
+ // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to each
+ // wire for each gadget call. Record the output of the circuit and append it to the verifier
+ // message.
+ //
+ // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit is
+ // equal to the output of the last gadget evaluation. Here we relax this assumption. This
+ // should be OK, since it's possible to transform any circuit into one for which this is true.
+ // (Needs security analysis.)
+ let validity = self.valid(&mut shims, input, joint_rand, num_shares)?;
+ verifier.push(validity);
+
+ // Fill the buffer with the verifier message.
+ for (query_rand_val, shim) in query_rand[..shims.len()].iter().zip(shims.iter_mut()) {
+ let gadget = shim
+ .as_any()
+ .downcast_ref::<QueryShimGadget<Self::Field>>()
+ .unwrap();
+
+ // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire
+ // polynomial at query randomness value.
+ let m = (1 + gadget.calls()).next_power_of_two();
+ let m_inv = Self::Field::from(
+ <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(),
+ )
+ .inv();
+ let mut f = vec![Self::Field::zero(); m];
+ for wire in 0..gadget.arity() {
+ discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?;
+ discrete_fourier_transform_inv_finish(&mut f, m, m_inv);
+ verifier.push(poly_eval(&f, *query_rand_val));
+ }
+
+ // Add the value of the gadget polynomial evaluated at the query randomness value.
+ verifier.push(gadget.p_at_r);
+ }
+
+ assert_eq!(verifier.len(), self.verifier_len());
+ Ok(verifier)
+ }
+
+ /// Returns true if the verifier message indicates that the input from which it was generated is valid.
+ fn decide(&self, verifier: &[Self::Field]) -> Result<bool, FlpError> {
+ if verifier.len() != self.verifier_len() {
+ return Err(FlpError::Decide(format!(
+ "unexpected verifier length: got {}; want {}",
+ verifier.len(),
+ self.verifier_len()
+ )));
+ }
+
+ // Check if the output of the circuit is 0.
+ if verifier[0] != Self::Field::zero() {
+ return Ok(false);
+ }
+
+ // Check that each of the proof polynomials are well-formed.
+ let mut gadgets = self.gadget();
+ let mut verifier_len = 1;
+ for gadget in gadgets.iter_mut() {
+ let next_len = 1 + gadget.arity();
+
+ let e = gadget.call(&verifier[verifier_len..verifier_len + next_len - 1])?;
+ if e != verifier[verifier_len + next_len - 1] {
+ return Ok(false);
+ }
+
+ verifier_len += next_len;
+ }
+
+ Ok(true)
+ }
+
+ /// Check whether `input` and `joint_rand` have the length expected by `self`,
+ /// return [`FlpError::Valid`] otherwise.
+ fn valid_call_check(
+ &self,
+ input: &[Self::Field],
+ joint_rand: &[Self::Field],
+ ) -> Result<(), FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Valid(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len(),
+ )));
+ }
+
+ if joint_rand.len() != self.joint_rand_len() {
+ return Err(FlpError::Valid(format!(
+ "unexpected joint randomness length: got {}; want {}",
+ joint_rand.len(),
+ self.joint_rand_len()
+ )));
+ }
+
+ Ok(())
+ }
+
+ /// Check if the length of `input` matches `self`'s `input_len()`,
+ /// return [`FlpError::Truncate`] otherwise.
+ fn truncate_call_check(&self, input: &[Self::Field]) -> Result<(), FlpError> {
+ if input.len() != self.input_len() {
+ return Err(FlpError::Truncate(format!(
+ "Unexpected input length: got {}; want {}",
+ input.len(),
+ self.input_len()
+ )));
+ }
+
+ Ok(())
+ }
+}
+
+/// A type which supports adding noise to aggregate shares for Server Differential Privacy.
+#[cfg(feature = "experimental")]
+pub trait TypeWithNoise<S>: Type
+where
+ S: DifferentialPrivacyStrategy,
+{
+ /// Add noise to the aggregate share to obtain differential privacy.
+ fn add_noise_to_result(
+ &self,
+ dp_strategy: &S,
+ agg_result: &mut [Self::Field],
+ num_measurements: usize,
+ ) -> Result<(), FlpError>;
+}
+
+/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
+pub trait Gadget<F: FftFriendlyFieldElement>: Debug {
+ /// Evaluates the gadget on input `inp` and returns the output.
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError>;
+
+ /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`.
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;
+
+ /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
+ /// `call_poly`.
+ fn arity(&self) -> usize;
+
+ /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp`
+ /// buffer passed to `call_poly`.
+ fn degree(&self) -> usize;
+
+ /// Returns the number of times the gadget is expected to be called.
+ fn calls(&self) -> usize;
+
+ /// This call is used to downcast a `Box<dyn Gadget<F>>` to a concrete type.
+ fn as_any(&mut self) -> &mut dyn Any;
+}
+
+// A "shim" gadget used during proof generation to record the input wires each time a gadget is
+// evaluated.
+#[derive(Debug)]
+struct ProveShimGadget<F: FftFriendlyFieldElement> {
+ inner: Box<dyn Gadget<F>>,
+
+ /// Points at which the wire polynomials are interpolated.
+ f_vals: Vec<Vec<F>>,
+
+ /// The number of times the gadget has been called so far.
+ ct: usize,
+}
+
+impl<F: FftFriendlyFieldElement> ProveShimGadget<F> {
+ fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> {
+ let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()];
+
+ for (prove_rand_val, wire_poly_vals) in
+ prove_rand[..f_vals.len()].iter().zip(f_vals.iter_mut())
+ {
+ // Choose a random field element as the first point on the wire polynomial.
+ wire_poly_vals[0] = *prove_rand_val;
+ }
+
+ Ok(Self {
+ inner,
+ f_vals,
+ ct: 1,
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
+ wire_poly_vals[self.ct] = *inp_val;
+ }
+ self.ct += 1;
+ self.inner.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ self.inner.call_poly(outp, inp)
+ }
+
+ fn arity(&self) -> usize {
+ self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// A "shim" gadget used during proof verification to record the points at which the intermediate
+// proof polynomials are evaluated.
+#[derive(Debug)]
+struct QueryShimGadget<F: FftFriendlyFieldElement> {
+ inner: Box<dyn Gadget<F>>,
+
+ /// Points at which intermediate proof polynomials are interpolated.
+ f_vals: Vec<Vec<F>>,
+
+ /// Points at which the gadget polynomial is interpolated.
+ p_vals: Vec<F>,
+
+ /// The gadget polynomial evaluated on a random input `r`.
+ p_at_r: F,
+
+ /// Used to compute an index into `p_val`.
+ step: usize,
+
+ /// The number of times the gadget has been called so far.
+ ct: usize,
+}
+
+impl<F: FftFriendlyFieldElement> QueryShimGadget<F> {
+ fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> {
+ let gadget_degree = inner.degree();
+ let gadget_arity = inner.arity();
+ let m = (1 + inner.calls()).next_power_of_two();
+ let p = m * gadget_degree;
+
+ // Each call to this gadget records the values at which intermediate proof polynomials were
+ // interpolated. The first point was a random value chosen by the prover and transmitted in
+ // the proof.
+ let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; gadget_arity];
+ for wire in 0..gadget_arity {
+ f_vals[wire][0] = proof_data[wire];
+ }
+
+ // Evaluate the gadget polynomial at roots of unity.
+ let size = p.next_power_of_two();
+ let mut p_vals = vec![F::zero(); size];
+ discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?;
+
+ // The step is used to compute the element of `p_val` that will be returned by a call to
+ // the gadget.
+ let step = (1 << (log2(p as u128) - log2(m as u128))) as usize;
+
+ // Evaluate the gadget polynomial `p` at query randomness `r`.
+ let p_at_r = poly_eval(&proof_data[gadget_arity..], r);
+
+ Ok(Self {
+ inner,
+ f_vals,
+ p_vals,
+ p_at_r,
+ step,
+ ct: 1,
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
+ wire_poly_vals[self.ct] = *inp_val;
+ }
+ let outp = self.p_vals[self.ct * self.step];
+ self.ct += 1;
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> {
+ panic!("no-op");
+ }
+
+ fn arity(&self) -> usize {
+ self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Compute the length of the wire polynomial constructed from the given number of gadget calls.
+#[inline]
+pub(crate) fn wire_poly_len(num_calls: usize) -> usize {
+ (1 + num_calls).next_power_of_two()
+}
+
+/// Compute the length of the gadget polynomial for a gadget with the given degree and from wire
+/// polynomials of the given length.
+#[inline]
+pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usize {
+ gadget_degree * (wire_poly_len - 1) + 1
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, split_vector, Field128};
+ use crate::flp::gadgets::{Mul, PolyEval};
+ use crate::polynomial::poly_range_check;
+
+ use std::marker::PhantomData;
+
+ // Simple integration test for the core FLP logic. You'll find more extensive unit tests for
+ // each implemented data type in src/types.rs.
+ #[test]
+ fn test_flp() {
+ const NUM_SHARES: usize = 2;
+
+ let typ: TestType<Field128> = TestType::new();
+ let input = typ.encode_measurement(&3).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+
+ let input_shares: Vec<Vec<Field128>> = split_vector(input.as_slice(), NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ assert_eq!(proof.len(), typ.proof_len());
+
+ let proof_shares: Vec<Vec<Field128>> = split_vector(&proof, NUM_SHARES)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let verifier: Vec<Field128> = (0..NUM_SHARES)
+ .map(|i| {
+ typ.query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ NUM_SHARES,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ /// A toy type used for testing multiple gadgets. Valid inputs of this type consist of a pair
+ /// of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`.
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct TestType<F>(PhantomData<F>);
+
+ impl<F> TestType<F> {
+ fn new() -> Self {
+ Self(PhantomData)
+ }
+ }
+
+ impl<F: FftFriendlyFieldElement> Type for TestType<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ let r = joint_rand[0];
+ let mut res = F::zero();
+
+ // Check that `data[0]^3 == data[1]`.
+ let mut inp = [input[0], input[0]];
+ inp[0] = g[0].call(&inp)?;
+ inp[0] = g[0].call(&inp)?;
+ let x3_diff = inp[0] - input[1];
+ res += r * x3_diff;
+
+ // Check that `data[0]` is in the correct range.
+ let x_checked = g[1].call(&[input[0]])?;
+ res += (r * r) * x_checked;
+
+ Ok(res)
+ }
+
+ fn input_len(&self) -> usize {
+ 2
+ }
+
+ fn proof_len(&self) -> usize {
+ // First chunk
+ let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + 2_usize /* gadget calls */).next_power_of_two() - 1) + 1;
+
+ // Second chunk
+ let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * (
+ (1 + 1_usize /* gadget calls */).next_power_of_two() - 1) + 1;
+
+ mul + poly
+ }
+
+ fn verifier_len(&self) -> usize {
+ // First chunk
+ let mul = 1 + 2 /* gadget arity */;
+
+ // Second chunk
+ let poly = 1 + 1 /* gadget arity */;
+
+ 1 + mul + poly
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 3
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![
+ Box::new(Mul::new(2)),
+ Box::new(PolyEval::new(poly_range_check(2, 5), 1)),
+ ]
+ }
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ Ok(vec![
+ F::from(*measurement),
+ F::from(*measurement).pow(F::Integer::try_from(3).unwrap()),
+ ])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ Ok(input)
+ }
+
+ fn decode_result(
+ &self,
+ _data: &[F],
+ _num_measurements: usize,
+ ) -> Result<F::Integer, FlpError> {
+ panic!("not implemented");
+ }
+ }
+
+ // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that
+ // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than
+ // needed for computing the gadget polynomial.
+ #[test]
+ fn issue254() {
+ let typ: Issue254Type<Field128> = Issue254Type::new();
+ let input = typ.encode_measurement(&0).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct Issue254Type<F> {
+ num_gadget_calls: [usize; 2],
+ phantom: PhantomData<F>,
+ }
+
+ impl<F> Issue254Type<F> {
+ fn new() -> Self {
+ Self {
+ // The bug is triggered when there are two gadgets, but it doesn't matter how many
+ // times the second gadget is called.
+ num_gadget_calls: [100, 0],
+ phantom: PhantomData,
+ }
+ }
+ }
+
+ impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ _joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the
+ // use of multiple gadgets, each of which is called an arbitrary number of times.
+ let mut res = F::zero();
+ for _ in 0..self.num_gadget_calls[0] {
+ res += g[0].call(&[input[0]])?;
+ }
+ for _ in 0..self.num_gadget_calls[1] {
+ res += g[1].call(&[input[0]])?;
+ }
+ Ok(res)
+ }
+
+ fn input_len(&self) -> usize {
+ 1
+ }
+
+ fn proof_len(&self) -> usize {
+ // First chunk
+ let first = 1 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + self.num_gadget_calls[0]).next_power_of_two() - 1) + 1;
+
+ // Second chunk
+ let second = 1 /* gadget arity */ + 2 /* gadget degree */ * (
+ (1 + self.num_gadget_calls[1]).next_power_of_two() - 1) + 1;
+
+ first + second
+ }
+
+ fn verifier_len(&self) -> usize {
+ // First chunk
+ let first = 1 + 1 /* gadget arity */;
+
+ // Second chunk
+ let second = 1 + 1 /* gadget arity */;
+
+ 1 + first + second
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 0
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ // First chunk
+ let first = 1; // gadget arity
+
+ // Second chunk
+ let second = 1; // gadget arity
+
+ first + second
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 2 // number of gadgets
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ let poly = poly_range_check(0, 2); // A polynomial with degree 2
+ vec![
+ Box::new(PolyEval::new(poly.clone(), self.num_gadget_calls[0])),
+ Box::new(PolyEval::new(poly, self.num_gadget_calls[1])),
+ ]
+ }
+
+ fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
+ Ok(vec![F::from(*measurement)])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ Ok(input)
+ }
+
+ fn decode_result(
+ &self,
+ _data: &[F],
+ _num_measurements: usize,
+ ) -> Result<F::Integer, FlpError> {
+ panic!("not implemented");
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs
new file mode 100644
index 0000000000..c2696665f4
--- /dev/null
+++ b/third_party/rust/prio/src/flp/gadgets.rs
@@ -0,0 +1,591 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of gadgets.
+
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
+use crate::field::FftFriendlyFieldElement;
+use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
+use crate::polynomial::{poly_deg, poly_eval, poly_mul};
+
+#[cfg(feature = "multithreaded")]
+use rayon::prelude::*;
+
+use std::any::Any;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
+/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
+const FFT_THRESHOLD: usize = 60;
+
+/// An arity-2 gadget that multiples its inputs.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Mul<F: FftFriendlyFieldElement> {
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FftFriendlyFieldElement> Mul<F> {
+ /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
+ /// called by the validity circuit.
+ pub fn new(num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(2, num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+
+ // Multiply input polynomials directly.
+ pub(crate) fn call_poly_direct(
+ &mut self,
+ outp: &mut [F],
+ inp: &[Vec<F>],
+ ) -> Result<(), FlpError> {
+ let v = poly_mul(&inp[0], &inp[1]);
+ outp[..v.len()].clone_from_slice(&v);
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let mut buf = vec![F::zero(); n];
+
+ discrete_fourier_transform(&mut buf, &inp[0], n)?;
+ discrete_fourier_transform(outp, &inp[1], n)?;
+
+ for i in 0..n {
+ buf[i] *= outp[i];
+ }
+
+ discrete_fourier_transform(outp, &buf, n)?;
+ discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
+ Ok(())
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Gadget<F> for Mul<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(inp[0] * inp[1])
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 2
+ }
+
+ fn degree(&self) -> usize {
+ 2
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// An arity-1 gadget that evaluates its input on some polynomial.
+//
+// TODO Make `poly` an array of length determined by a const generic.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct PolyEval<F: FftFriendlyFieldElement> {
+ poly: Vec<F>,
+ /// Size of buffer for FFT operations.
+ n: usize,
+ /// Inverse of `n` in `F`.
+ n_inv: F,
+ /// The number of times this gadget will be called.
+ num_calls: usize,
+}
+
+impl<F: FftFriendlyFieldElement> PolyEval<F> {
+ /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
+ /// this gadget is called by the validity circuit.
+ pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
+ let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
+ let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
+ Self {
+ poly,
+ n,
+ n_inv,
+ num_calls,
+ }
+ }
+}
+
+impl<F: FftFriendlyFieldElement> PolyEval<F> {
+ // Multiply input polynomials directly.
+ fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ outp[0] = self.poly[0];
+ let mut x = inp[0].to_vec();
+ for i in 1..self.poly.len() {
+ for j in 0..x.len() {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ x = poly_mul(&x, &inp[0]);
+ }
+ }
+ Ok(())
+ }
+
+ // Multiply input polynomials using FFT.
+ fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ let n = self.n;
+ let inp = &inp[0];
+
+ let mut inp_vals = vec![F::zero(); n];
+ discrete_fourier_transform(&mut inp_vals, inp, n)?;
+
+ let mut x_vals = inp_vals.clone();
+ let mut x = vec![F::zero(); n];
+ x[..inp.len()].clone_from_slice(inp);
+
+ outp[0] = self.poly[0];
+ for i in 1..self.poly.len() {
+ for j in 0..n {
+ outp[j] += self.poly[i] * x[j];
+ }
+
+ if i < self.poly.len() - 1 {
+ for j in 0..n {
+ x_vals[j] *= inp_vals[j];
+ }
+
+ discrete_fourier_transform(&mut x, &x_vals, n)?;
+ discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Gadget<F> for PolyEval<F> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ Ok(poly_eval(&self.poly, inp[0]))
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for item in outp.iter_mut() {
+ *item = F::zero();
+ }
+
+ if inp[0].len() >= FFT_THRESHOLD {
+ self.call_poly_fft(outp, inp)
+ } else {
+ self.call_poly_direct(outp, inp)
+ }
+ }
+
+ fn arity(&self) -> usize {
+ 1
+ }
+
+ fn degree(&self) -> usize {
+ poly_deg(&self.poly)
+ }
+
+ fn calls(&self) -> usize {
+ self.num_calls
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// Trait for abstracting over [`ParallelSum`].
+pub trait ParallelSumGadget<F: FftFriendlyFieldElement, G>: Gadget<F> + Debug {
+ /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts.
+ fn new(inner: G, chunks: usize) -> Self;
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of times it is
+/// called.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSum<F: FftFriendlyFieldElement, G: Gadget<F>> {
+ inner: G,
+ chunks: usize,
+ phantom: PhantomData<F>,
+}
+
+impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G>
+ for ParallelSum<F, G>
+{
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ inner,
+ chunks,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ gadget_call_check(self, inp.len())?;
+ let mut outp = F::zero();
+ for chunk in inp.chunks(self.inner.arity()) {
+ outp += self.inner.call(chunk)?;
+ }
+ Ok(outp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ for x in outp.iter_mut() {
+ *x = F::zero();
+ }
+
+ let mut partial_outp = vec![F::zero(); outp.len()];
+
+ for chunk in inp.chunks(self.inner.arity()) {
+ self.inner.call_poly(&mut partial_outp, chunk)?;
+ for i in 0..outp.len() {
+ outp[i] += partial_outp[i]
+ }
+ }
+
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.chunks * self.inner.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.inner.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.inner.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
+/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
+/// evaluation is multithreaded.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ParallelSumMultithreaded<F: FftFriendlyFieldElement, G: Gadget<F>> {
+ serial_sum: ParallelSum<F, G>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
+where
+ F: FftFriendlyFieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn new(inner: G, chunks: usize) -> Self {
+ Self {
+ serial_sum: ParallelSum::new(inner, chunks),
+ }
+ }
+}
+
+/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
+#[cfg(feature = "multithreaded")]
+struct ParallelSumFoldState<F, G> {
+ /// Inner gadget.
+ inner: G,
+ /// Output buffer for `call_poly()`.
+ partial_output: Vec<F>,
+ /// Sum accumulator.
+ partial_sum: Vec<F>,
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> ParallelSumFoldState<F, G> {
+ fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
+ where
+ G: Clone,
+ F: FftFriendlyFieldElement,
+ {
+ ParallelSumFoldState {
+ inner: gadget.clone(),
+ partial_output: vec![F::zero(); length],
+ partial_sum: vec![F::zero(); length],
+ }
+ }
+}
+
+#[cfg(feature = "multithreaded")]
+impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
+where
+ F: FftFriendlyFieldElement + Sync + Send,
+ G: 'static + Gadget<F> + Clone + Sync + Send,
+{
+ fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
+ self.serial_sum.call(inp)
+ }
+
+ fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
+ gadget_call_poly_check(self, outp, inp)?;
+
+ // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
+ // gadget on each input polynomial, using the first temporary buffer as an output buffer.
+ // Then accumulate that result into the second temporary buffer, which acts as a running
+ // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
+ // to the output parameter. This is equivalent to the single threaded calculation in
+ // ParallelSum, since we only rearrange additions, and field addition is associative.
+ let res = inp
+ .par_chunks(self.serial_sum.inner.arity())
+ .fold(
+ || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
+ |mut state, chunk| {
+ state
+ .inner
+ .call_poly(&mut state.partial_output, chunk)
+ .unwrap();
+ for (sum_elem, output_elem) in state
+ .partial_sum
+ .iter_mut()
+ .zip(state.partial_output.iter())
+ {
+ *sum_elem += *output_elem;
+ }
+ state
+ },
+ )
+ .map(|state| state.partial_sum)
+ .reduce(
+ || vec![F::zero(); outp.len()],
+ |mut x, y| {
+ for (xi, yi) in x.iter_mut().zip(y.iter()) {
+ *xi += *yi;
+ }
+ x
+ },
+ );
+
+ outp.copy_from_slice(&res[..]);
+ Ok(())
+ }
+
+ fn arity(&self) -> usize {
+ self.serial_sum.arity()
+ }
+
+ fn degree(&self) -> usize {
+ self.serial_sum.degree()
+ }
+
+ fn calls(&self) -> usize {
+ self.serial_sum.calls()
+ }
+
+ fn as_any(&mut self) -> &mut dyn Any {
+ self
+ }
+}
+
+// Check that the input parameters of g.call() are well-formed.
+fn gadget_call_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
+ gadget: &G,
+ in_len: usize,
+) -> Result<(), FlpError> {
+ if in_len != gadget.arity() {
+ return Err(FlpError::Gadget(format!(
+ "unexpected number of inputs: got {}; want {}",
+ in_len,
+ gadget.arity()
+ )));
+ }
+
+ if in_len == 0 {
+ return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
+ }
+
+ Ok(())
+}
+
+// Check that the input parameters of g.call_poly() are well-formed.
+fn gadget_call_poly_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
+ gadget: &G,
+ outp: &[F],
+ inp: &[Vec<F>],
+) -> Result<(), FlpError>
+where
+ G: Gadget<F>,
+{
+ gadget_call_check(gadget, inp.len())?;
+
+ for i in 1..inp.len() {
+ if inp[i].len() != inp[0].len() {
+ return Err(FlpError::Gadget(
+ "gadget called on wire polynomials with different lengths".to_string(),
+ ));
+ }
+ }
+
+ let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
+ if outp.len() != expected {
+ return Err(FlpError::Gadget(format!(
+ "incorrect output length: got {}; want {}",
+ outp.len(),
+ expected
+ )));
+ }
+
+ Ok(())
+}
+
+#[inline]
+fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
+ gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "multithreaded")]
+ use crate::field::FieldElement;
+ use crate::field::{random_vector, Field64 as TestField};
+ use crate::prng::Prng;
+
+ #[test]
+ fn test_mul() {
+ // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
+ // naive multiplication code path.
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+
+ // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
+ // FFT-based polynomial multiplication.
+ let num_calls = FFT_THRESHOLD;
+ let mut g: Mul<TestField> = Mul::new(num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_poly_eval() {
+ let poly: Vec<TestField> = random_vector(10).unwrap();
+
+ let num_calls = FFT_THRESHOLD / 2;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
+ gadget_test(&mut g, num_calls);
+
+ let num_calls = FFT_THRESHOLD;
+ let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ fn test_parallel_sum() {
+ let num_calls = 10;
+ let chunks = 23;
+
+ let mut g = ParallelSum::new(Mul::<TestField>::new(num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_parallel_sum_multithreaded() {
+ use std::iter;
+
+ for num_calls in [1, 10, 100] {
+ let chunks = 23;
+
+ let mut g = ParallelSumMultithreaded::new(Mul::new(num_calls), chunks);
+ gadget_test(&mut g, num_calls);
+
+ // Test that the multithreaded version has the same output as the normal version.
+ let mut g_serial = ParallelSum::new(Mul::new(num_calls), chunks);
+ assert_eq!(g.arity(), g_serial.arity());
+ assert_eq!(g.degree(), g_serial.degree());
+ assert_eq!(g.calls(), g_serial.calls());
+
+ let arity = g.arity();
+ let degree = g.degree();
+
+ // Test that both gadgets evaluate to the same value when run on scalar inputs.
+ let inp: Vec<TestField> = random_vector(arity).unwrap();
+ let result = g.call(&inp).unwrap();
+ let result_serial = g_serial.call(&inp).unwrap();
+ assert_eq!(result, result_serial);
+
+ // Test that both gadgets evaluate to the same value when run on polynomial inputs.
+ let mut poly_outp =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut poly_outp_serial =
+ vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
+ let mut prng: Prng<TestField, _> = Prng::new().unwrap();
+ let poly_inp: Vec<_> = iter::repeat_with(|| {
+ iter::repeat_with(|| prng.get())
+ .take(1 + num_calls)
+ .collect::<Vec<_>>()
+ })
+ .take(arity)
+ .collect();
+
+ g.call_poly(&mut poly_outp, &poly_inp).unwrap();
+ g_serial
+ .call_poly(&mut poly_outp_serial, &poly_inp)
+ .unwrap();
+ assert_eq!(poly_outp, poly_outp_serial);
+ }
+ }
+
+ // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
+ // to evaluating each of the inputs at the same point and applying g.call() on the results.
+ fn gadget_test<F: FftFriendlyFieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
+ let wire_poly_len = (1 + num_calls).next_power_of_two();
+ let mut prng = Prng::new().unwrap();
+ let mut inp = vec![F::zero(); g.arity()];
+ let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
+ let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];
+
+ let r = prng.get();
+ for i in 0..g.arity() {
+ for j in 0..wire_poly_len {
+ wire_polys[i][j] = prng.get();
+ }
+ inp[i] = poly_eval(&wire_polys[i], r);
+ }
+
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ let want = g.call(&inp).unwrap();
+ assert_eq!(got, want);
+
+ // Repeat the call to make sure that the gadget's memory is reset properly between calls.
+ g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
+ let got = poly_eval(&gadget_poly, r);
+ assert_eq!(got, want);
+ }
+}
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs
new file mode 100644
index 0000000000..18c290355c
--- /dev/null
+++ b/third_party/rust/prio/src/flp/types.rs
@@ -0,0 +1,1415 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A collection of [`Type`] implementations.
+
+use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt};
+use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
+use crate::flp::{FlpError, Gadget, Type};
+use crate::polynomial::poly_range_check;
+use std::convert::TryInto;
+use std::fmt::{self, Debug};
+use std::marker::PhantomData;
+/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`).
+#[derive(Clone, PartialEq, Eq)]
+pub struct Count<F> {
+ range_checker: Vec<F>,
+}
+
+impl<F> Debug for Count<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Count").finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Count<F> {
+ /// Return a new [`Count`] type instance.
+ pub fn new() -> Self {
+ Self {
+ range_checker: poly_range_check(0, 2),
+ }
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Default for Count<F> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Type for Count<F> {
+ const ID: u32 = 0x00000000;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let max = F::valid_integer_try_from(1)?;
+ if *value > max {
+ return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
+ }
+
+ Ok(vec![F::from(*value)])
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(Mul::new(1))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ Ok(g[0].call(&[input[0], input[0]])? - input[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ 1
+ }
+
+ fn proof_len(&self) -> usize {
+ 5
+ }
+
+ fn verifier_len(&self) -> usize {
+ 4
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 0
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of
+/// the measurements.
+///
+/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3].
+///
+/// [BBCG+19]: https://ia.cr/2019/188
+#[derive(Clone, PartialEq, Eq)]
+pub struct Sum<F: FftFriendlyFieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FftFriendlyFieldElement> Debug for Sum<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Sum").field("bits", &self.bits).finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Sum<F> {
+ /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Type for Sum<F> {
+ const ID: u32 = 0x00000001;
+ type Measurement = F::Integer;
+ type AggregateResult = F::Integer;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
+ decode_result(data)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the arithmetic average.
+#[derive(Clone, PartialEq, Eq)]
+pub struct Average<F: FftFriendlyFieldElement> {
+ bits: usize,
+ range_checker: Vec<F>,
+}
+
+impl<F: FftFriendlyFieldElement> Debug for Average<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Average").field("bits", &self.bits).finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Average<F> {
+ /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0,
+ /// 2^bits)`.
+ pub fn new(bits: usize) -> Result<Self, FlpError> {
+ if !F::valid_integer_bitlength(bits) {
+ return Err(FlpError::Encode(
+ "invalid bits: number of bits exceeds maximum number of bits in this field"
+ .to_string(),
+ ));
+ }
+ Ok(Self {
+ bits,
+ range_checker: poly_range_check(0, 2),
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement> Type for Average<F> {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = F::Integer;
+ type AggregateResult = f64;
+ type Field = F;
+
+ fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
+ let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ Ok(v)
+ }
+
+ fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> {
+ // Compute the average from the aggregated sum.
+ let data = decode_result(data)?;
+ let data: u64 = data.try_into().map_err(|err| {
+ FlpError::Decode(format!("failed to convert {data:?} to u64: {err}",))
+ })?;
+ let result = (data as f64) / (num_measurements as f64);
+ Ok(result)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(PolyEval::new(
+ self.range_checker.clone(),
+ self.bits,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ _num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+ call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0])
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let res = F::decode_from_bitvector_representation(&input)?;
+ Ok(vec![res])
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits
+ }
+
+ fn proof_len(&self) -> usize {
+ 2 * ((1 + self.bits).next_power_of_two() - 1) + 2
+ }
+
+ fn verifier_len(&self) -> usize {
+ 3
+ }
+
+ fn output_len(&self) -> usize {
+ 1
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// The histogram type. Each measurement is an integer in `[0, length)` and the aggregate is a
+/// histogram counting the number of occurrences of each measurement.
+#[derive(PartialEq, Eq)]
+pub struct Histogram<F, S> {
+ length: usize,
+ chunk_length: usize,
+ gadget_calls: usize,
+ phantom: PhantomData<(F, S)>,
+}
+
+impl<F: FftFriendlyFieldElement, S> Debug for Histogram<F, S> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Histogram")
+ .field("length", &self.length)
+ .field("chunk_length", &self.chunk_length)
+ .finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> Histogram<F, S> {
+ /// Return a new [`Histogram`] type with the given number of buckets.
+ pub fn new(length: usize, chunk_length: usize) -> Result<Self, FlpError> {
+ if length >= u32::MAX as usize {
+ return Err(FlpError::Encode(
+ "invalid length: number of buckets exceeds maximum permitted".to_string(),
+ ));
+ }
+ if length == 0 {
+ return Err(FlpError::InvalidParameter(
+ "length cannot be zero".to_string(),
+ ));
+ }
+ if chunk_length == 0 {
+ return Err(FlpError::InvalidParameter(
+ "chunk_length cannot be zero".to_string(),
+ ));
+ }
+
+ let mut gadget_calls = length / chunk_length;
+ if length % chunk_length != 0 {
+ gadget_calls += 1;
+ }
+
+ Ok(Self {
+ length,
+ chunk_length,
+ gadget_calls,
+ phantom: PhantomData,
+ })
+ }
+}
+
+impl<F, S> Clone for Histogram<F, S> {
+ fn clone(&self) -> Self {
+ Self {
+ length: self.length,
+ chunk_length: self.chunk_length,
+ gadget_calls: self.gadget_calls,
+ phantom: self.phantom,
+ }
+ }
+}
+
+impl<F, S> Type for Histogram<F, S>
+where
+ F: FftFriendlyFieldElement,
+ S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
+{
+ const ID: u32 = 0x00000003;
+ type Measurement = usize;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &usize) -> Result<Vec<F>, FlpError> {
+ let mut data = vec![F::zero(); self.length];
+
+ data[*measurement] = F::one();
+ Ok(data)
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.length)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(S::new(
+ Mul::new(self.gadget_calls),
+ self.chunk_length,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ // Check that each element of `input` is a 0 or 1.
+ let range_check = parallel_sum_range_checks(
+ &mut g[0],
+ input,
+ joint_rand[0],
+ self.chunk_length,
+ num_shares,
+ )?;
+
+ // Check that the elements of `input` sum to 1.
+ let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?));
+ for val in input.iter() {
+ sum_check += *val;
+ }
+
+ // Take a random linear combination of both checks.
+ let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check;
+ Ok(out)
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ Ok(input)
+ }
+
+ fn input_len(&self) -> usize {
+ self.length
+ }
+
+ fn proof_len(&self) -> usize {
+ (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
+ }
+
+ fn verifier_len(&self) -> usize {
+ 2 + self.chunk_length * 2
+ }
+
+ fn output_len(&self) -> usize {
+ self.input_len()
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ self.chunk_length * 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19],
+/// Corollary 4.9] to reduce the proof size to roughly the square root of the input size.
+///
+/// [BBCG+19]: https://eprint.iacr.org/2019/188
+#[derive(PartialEq, Eq)]
+pub struct SumVec<F: FftFriendlyFieldElement, S> {
+ len: usize,
+ bits: usize,
+ flattened_len: usize,
+ max: F::Integer,
+ chunk_length: usize,
+ gadget_calls: usize,
+ phantom: PhantomData<S>,
+}
+
+impl<F: FftFriendlyFieldElement, S> Debug for SumVec<F, S> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SumVec")
+ .field("len", &self.len)
+ .field("bits", &self.bits)
+ .field("chunk_length", &self.chunk_length)
+ .finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> SumVec<F, S> {
+ /// Returns a new [`SumVec`] with the desired bit width and vector length.
+ ///
+ /// # Errors
+ ///
+ /// * The length of the encoded measurement, i.e., `bits * len`, overflows addressable memory.
+ /// * The bit width cannot be encoded, i.e., `bits` is larger than or equal to the number of
+ /// bits required to encode field elements.
+ /// * Any of `bits`, `len`, or `chunk_length` are zero.
+ pub fn new(bits: usize, len: usize, chunk_length: usize) -> Result<Self, FlpError> {
+ let flattened_len = bits.checked_mul(len).ok_or_else(|| {
+ FlpError::InvalidParameter("`bits*len` overflows addressable memory".into())
+ })?;
+
+ // Check if the bit width is too large. This limit is defined to be one bit less than the
+ // number of bits required to encode `F::Integer`. (One less so that we can compute `1 <<
+ // bits` without overflowing.)
+ let limit = std::mem::size_of::<F::Integer>() * 8 - 1;
+ if bits > limit {
+ return Err(FlpError::InvalidParameter(format!(
+ "bit wdith exceeds limit of {limit}"
+ )));
+ }
+
+ // Check for degenerate parameters.
+ if bits == 0 {
+ return Err(FlpError::InvalidParameter(
+ "bits cannot be zero".to_string(),
+ ));
+ }
+ if len == 0 {
+ return Err(FlpError::InvalidParameter("len cannot be zero".to_string()));
+ }
+ if chunk_length == 0 {
+ return Err(FlpError::InvalidParameter(
+ "chunk_length cannot be zero".to_string(),
+ ));
+ }
+
+ // Compute the largest encodable measurement.
+ let one = F::Integer::from(F::one());
+ let max = (one << bits) - one;
+
+ let mut gadget_calls = flattened_len / chunk_length;
+ if flattened_len % chunk_length != 0 {
+ gadget_calls += 1;
+ }
+
+ Ok(Self {
+ len,
+ bits,
+ flattened_len,
+ max,
+ chunk_length,
+ gadget_calls,
+ phantom: PhantomData,
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement, S> Clone for SumVec<F, S> {
+ fn clone(&self) -> Self {
+ Self {
+ len: self.len,
+ bits: self.bits,
+ flattened_len: self.flattened_len,
+ max: self.max,
+ chunk_length: self.chunk_length,
+ gadget_calls: self.gadget_calls,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl<F, S> Type for SumVec<F, S>
+where
+ F: FftFriendlyFieldElement,
+ S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
+{
+ const ID: u32 = 0x00000002;
+ type Measurement = Vec<F::Integer>;
+ type AggregateResult = Vec<F::Integer>;
+ type Field = F;
+
+ fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> {
+ if measurement.len() != self.len {
+ return Err(FlpError::Encode(format!(
+ "unexpected measurement length: got {}; want {}",
+ measurement.len(),
+ self.len
+ )));
+ }
+
+ let mut flattened = vec![F::zero(); self.flattened_len];
+ for (summand, chunk) in measurement
+ .iter()
+ .zip(flattened.chunks_exact_mut(self.bits))
+ {
+ if summand > &self.max {
+ return Err(FlpError::Encode(format!(
+ "summand exceeds maximum of 2^{}-1",
+ self.bits
+ )));
+ }
+ F::fill_with_bitvector_representation(summand, chunk)?;
+ }
+
+ Ok(flattened)
+ }
+
+ fn decode_result(
+ &self,
+ data: &[F],
+ _num_measurements: usize,
+ ) -> Result<Vec<F::Integer>, FlpError> {
+ decode_result_vec(data, self.len)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
+ vec![Box::new(S::new(
+ Mul::new(self.gadget_calls),
+ self.chunk_length,
+ ))]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<F>>>,
+ input: &[F],
+ joint_rand: &[F],
+ num_shares: usize,
+ ) -> Result<F, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ parallel_sum_range_checks(
+ &mut g[0],
+ input,
+ joint_rand[0],
+ self.chunk_length,
+ num_shares,
+ )
+ }
+
+ fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
+ self.truncate_call_check(&input)?;
+ let mut unflattened = Vec::with_capacity(self.len);
+ for chunk in input.chunks(self.bits) {
+ unflattened.push(F::decode_from_bitvector_representation(chunk)?);
+ }
+ Ok(unflattened)
+ }
+
+ fn input_len(&self) -> usize {
+ self.flattened_len
+ }
+
+ fn proof_len(&self) -> usize {
+ (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
+ }
+
+ fn verifier_len(&self) -> usize {
+ 2 + self.chunk_length * 2
+ }
+
+ fn output_len(&self) -> usize {
+ self.len
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 1
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ self.chunk_length * 2
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 1
+ }
+}
+
+/// Compute a random linear combination of the result of calls of `g` on each element of `input`.
+///
+/// # Arguments
+///
+/// * `g` - The gadget to be applied elementwise
+/// * `input` - The vector on whose elements to apply `g`
+/// * `rnd` - The randomness used for the linear combination
+pub(crate) fn call_gadget_on_vec_entries<F: FftFriendlyFieldElement>(
+ g: &mut Box<dyn Gadget<F>>,
+ input: &[F],
+ rnd: F,
+) -> Result<F, FlpError> {
+ let mut range_check = F::zero();
+ let mut r = rnd;
+ for chunk in input.chunks(1) {
+ range_check += r * g.call(chunk)?;
+ r *= rnd;
+ }
+ Ok(range_check)
+}
+
+/// Given a vector `data` of field elements which should contain exactly one entry, return the
+/// integer representation of that entry.
+pub(crate) fn decode_result<F: FftFriendlyFieldElement>(
+ data: &[F],
+) -> Result<F::Integer, FlpError> {
+ if data.len() != 1 {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(F::Integer::from(data[0]))
+}
+
+/// Given a vector `data` of field elements, return a vector containing the corresponding integer
+/// representations, if the number of entries matches `expected_len`.
+pub(crate) fn decode_result_vec<F: FftFriendlyFieldElement>(
+ data: &[F],
+ expected_len: usize,
+) -> Result<Vec<F::Integer>, FlpError> {
+ if data.len() != expected_len {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect())
+}
+
+/// This evaluates range checks on a slice of field elements, using a ParallelSum gadget evaluating
+/// many multiplication gates.
+///
+/// # Arguments
+///
+/// * `gadget`: A `ParallelSumGadget<F, Mul<F>>` gadget, or a shim wrapping the same.
+/// * `input`: A slice of inputs. This calculation will check that all inputs were zero or one
+/// before secret sharing.
+/// * `joint_randomness`: A joint randomness value, used to compute a random linear combination of
+/// individual range checks.
+/// * `chunk_length`: How many multiplication gates per ParallelSum gadget. This must match what the
+/// gadget was constructed with.
+/// * `num_shares`: The number of shares that the inputs were secret shared into. This is needed to
+/// correct constant terms in the circuit.
+///
+/// # Returns
+///
+/// This returns (additive shares of) zero if all inputs were zero or one, and otherwise returns a
+/// non-zero value with high probability.
+pub(crate) fn parallel_sum_range_checks<F: FftFriendlyFieldElement>(
+ gadget: &mut Box<dyn Gadget<F>>,
+ input: &[F],
+ joint_randomness: F,
+ chunk_length: usize,
+ num_shares: usize,
+) -> Result<F, FlpError> {
+ let f_num_shares = F::from(F::valid_integer_try_from::<usize>(num_shares)?);
+ let num_shares_inverse = f_num_shares.inv();
+
+ let mut output = F::zero();
+ let mut r_power = joint_randomness;
+ let mut padded_chunk = vec![F::zero(); 2 * chunk_length];
+
+ for chunk in input.chunks(chunk_length) {
+ // Construct arguments for the Mul subcircuits.
+ for (input, args) in chunk.iter().zip(padded_chunk.chunks_exact_mut(2)) {
+ args[0] = r_power * *input;
+ args[1] = *input - num_shares_inverse;
+ r_power *= joint_randomness;
+ }
+ // If the chunk of the input is smaller than chunk_length, use zeros instead of measurement
+ // inputs for the remaining calls.
+ for args in padded_chunk[chunk.len() * 2..].chunks_exact_mut(2) {
+ args[0] = F::zero();
+ args[1] = -num_shares_inverse;
+ // Skip updating r_power. This inner loop is only used during the last iteration of the
+ // outer loop, if the last input chunk is a partial chunk. Thus, r_power won't be
+ // accessed again before returning.
+ }
+
+ output += gadget.call(&padded_chunk)?;
+ }
+
+ Ok(output)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::field::{random_vector, Field64 as TestField, FieldElement};
+ use crate::flp::gadgets::ParallelSum;
+ #[cfg(feature = "multithreaded")]
+ use crate::flp::gadgets::ParallelSumMultithreaded;
+ use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
+ use std::cmp;
+
+ #[test]
+ fn test_count() {
+ let count: Count<TestField> = Count::new();
+ let zero = TestField::zero();
+ let one = TestField::one();
+
+ // Round trip
+ assert_eq!(
+ count
+ .decode_result(
+ &count
+ .truncate(count.encode_measurement(&1).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 1,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&1).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &count,
+ &count.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &count,
+ &[TestField::from(1337)],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Try running the validity circuit on an input that's too short.
+ count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err();
+ count
+ .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1)
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_sum() {
+ let sum = Sum::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Round trip
+ assert_eq!(
+ sum.decode_result(
+ &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(),
+ 1
+ )
+ .unwrap(),
+ 27,
+ );
+
+ // Test FLP on valid input.
+ flp_validity_test(
+ &sum,
+ &sum.encode_measurement(&1337).unwrap(),
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(1337)]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(0).unwrap(),
+ &[],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(2).unwrap(),
+ &[one, zero],
+ &ValidityTestCase {
+ expect_valid: true,
+ expected_output: Some(vec![one]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(9).unwrap(),
+ &[one, zero, one, one, zero, one, one, one, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![TestField::from(237)]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Test FLP on invalid input.
+ flp_validity_test(
+ &Sum::new(3).unwrap(),
+ &[one, nine, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &Sum::new(5).unwrap(),
+ &[zero, zero, zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_average() {
+ let average = Average::new(11).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let ten = TestField::from(10);
+
+ // Testing that average correctly quotients the sum of the measurements
+ // by the number of measurements.
+ assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0);
+ assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0);
+ assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5);
+ assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25);
+ assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25);
+
+ // round trip of 12 with `num_measurements`=1
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ 12.0
+ );
+
+ // round trip of 12 with `num_measurements`=24
+ assert_eq!(
+ average
+ .decode_result(
+ &average
+ .truncate(average.encode_measurement(&12).unwrap())
+ .unwrap(),
+ 24
+ )
+ .unwrap(),
+ 0.5
+ );
+ }
+
+ fn test_histogram<F, S>(f: F)
+ where
+ F: Fn(usize, usize) -> Result<Histogram<TestField, S>, FlpError>,
+ S: ParallelSumGadget<TestField, Mul<TestField>> + Eq + 'static,
+ {
+ let hist = f(3, 2).unwrap();
+ let zero = TestField::zero();
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ assert_eq!(&hist.encode_measurement(&0).unwrap(), &[one, zero, zero]);
+ assert_eq!(&hist.encode_measurement(&1).unwrap(), &[zero, one, zero]);
+ assert_eq!(&hist.encode_measurement(&2).unwrap(), &[zero, zero, one]);
+
+ // Round trip
+ assert_eq!(
+ hist.decode_result(
+ &hist.truncate(hist.encode_measurement(&2).unwrap()).unwrap(),
+ 1
+ )
+ .unwrap(),
+ [0, 0, 1]
+ );
+
+ // Test valid inputs.
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&0).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one, zero, zero]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&1).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, one, zero]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &hist.encode_measurement(&2).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![zero, zero, one]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Test invalid inputs.
+ flp_validity_test(
+ &hist,
+ &[zero, zero, nine],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[one, one, one],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ flp_validity_test(
+ &hist,
+ &[zero, zero, zero],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_histogram_serial() {
+ test_histogram(Histogram::<TestField, ParallelSum<TestField, Mul<TestField>>>::new);
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_histogram_parallel() {
+ test_histogram(
+ Histogram::<TestField, ParallelSumMultithreaded<TestField, Mul<TestField>>>::new,
+ );
+ }
+
+ fn test_sum_vec<F, S>(f: F)
+ where
+ F: Fn(usize, usize, usize) -> Result<SumVec<TestField, S>, FlpError>,
+ S: 'static + ParallelSumGadget<TestField, Mul<TestField>> + Eq,
+ {
+ let one = TestField::one();
+ let nine = TestField::from(9);
+
+ // Test on valid inputs.
+ for len in 1..10 {
+ let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
+ let sum_vec = f(1, len, chunk_length).unwrap();
+ flp_validity_test(
+ &sum_vec,
+ &sum_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+ }
+
+ let len = 100;
+ let sum_vec = f(1, len, 10).unwrap();
+ flp_validity_test(
+ &sum_vec,
+ &sum_vec.encode_measurement(&vec![1; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![one; len]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ let len = 23;
+ let sum_vec = f(4, len, 4).unwrap();
+ flp_validity_test(
+ &sum_vec,
+ &sum_vec.encode_measurement(&vec![9; len]).unwrap(),
+ &ValidityTestCase::<TestField> {
+ expect_valid: true,
+ expected_output: Some(vec![nine; len]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Test on invalid inputs.
+ for len in 1..10 {
+ let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
+ let sum_vec = f(1, len, chunk_length).unwrap();
+ flp_validity_test(
+ &sum_vec,
+ &vec![nine; len],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+ }
+
+ let len = 23;
+ let sum_vec = f(2, len, 4).unwrap();
+ flp_validity_test(
+ &sum_vec,
+ &vec![nine; 2 * len],
+ &ValidityTestCase::<TestField> {
+ expect_valid: false,
+ expected_output: None,
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // Round trip
+ let want = vec![1; len];
+ assert_eq!(
+ sum_vec
+ .decode_result(
+ &sum_vec
+ .truncate(sum_vec.encode_measurement(&want).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ want
+ );
+ }
+
+ #[test]
+ fn test_sum_vec_serial() {
+ test_sum_vec(SumVec::<TestField, ParallelSum<TestField, Mul<TestField>>>::new)
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_sum_vec_parallel() {
+ test_sum_vec(SumVec::<TestField, ParallelSumMultithreaded<TestField, Mul<TestField>>>::new)
+ }
+
+ #[test]
+ fn sum_vec_serial_long() {
+ let typ: SumVec<TestField, ParallelSum<TestField, _>> = SumVec::new(1, 1000, 31).unwrap();
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn sum_vec_parallel_long() {
+ let typ: SumVec<TestField, ParallelSumMultithreaded<TestField, _>> =
+ SumVec::new(1, 1000, 31).unwrap();
+ let input = typ.encode_measurement(&vec![0; 1000]).unwrap();
+ assert_eq!(input.len(), typ.input_len());
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+ let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
+ let verifier = typ
+ .query(&input, &proof, &query_rand, &joint_rand, 1)
+ .unwrap();
+ assert_eq!(verifier.len(), typ.verifier_len());
+ assert!(typ.decide(&verifier).unwrap());
+ }
+}
+
+#[cfg(test)]
+mod test_utils {
+ use super::*;
+ use crate::field::{random_vector, split_vector, FieldElement};
+
+ pub(crate) struct ValidityTestCase<F> {
+ pub(crate) expect_valid: bool,
+ pub(crate) expected_output: Option<Vec<F>>,
+ // Number of shares to split input and proofs into in `flp_test`.
+ pub(crate) num_shares: usize,
+ }
+
+ pub(crate) fn flp_validity_test<T: Type>(
+ typ: &T,
+ input: &[T::Field],
+ t: &ValidityTestCase<T::Field>,
+ ) -> Result<(), FlpError> {
+ let mut gadgets = typ.gadget();
+
+ if input.len() != typ.input_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected input length: got {}; want {}",
+ input.len(),
+ typ.input_len()
+ )));
+ }
+
+ if typ.query_rand_len() != gadgets.len() {
+ return Err(FlpError::Test(format!(
+ "query rand length: got {}; want {}",
+ typ.query_rand_len(),
+ gadgets.len()
+ )));
+ }
+
+ let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
+ let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
+ let query_rand = random_vector(typ.query_rand_len()).unwrap();
+
+ // Run the validity circuit.
+ let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?;
+ if v != T::Field::zero() && t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected valid input: valid() returned {v}"
+ )));
+ }
+ if v == T::Field::zero() && !t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "expected invalid input: valid() returned {v}"
+ )));
+ }
+
+ // Generate the proof.
+ let proof = typ.prove(input, &prove_rand, &joint_rand)?;
+ if proof.len() != typ.proof_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected proof length: got {}; want {}",
+ proof.len(),
+ typ.proof_len()
+ )));
+ }
+
+ // Query the proof.
+ let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?;
+ if verifier.len() != typ.verifier_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected verifier length: got {}; want {}",
+ verifier.len(),
+ typ.verifier_len()
+ )));
+ }
+
+ // Decide if the input is valid.
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Run distributed FLP.
+ let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares)
+ .unwrap()
+ .into_iter()
+ .collect();
+
+ let verifier: Vec<T::Field> = (0..t.num_shares)
+ .map(|i| {
+ typ.query(
+ &input_shares[i],
+ &proof_shares[i],
+ &query_rand,
+ &joint_rand,
+ t.num_shares,
+ )
+ .unwrap()
+ })
+ .reduce(|mut left, right| {
+ for (x, y) in left.iter_mut().zip(right.iter()) {
+ *x += *y;
+ }
+ left
+ })
+ .unwrap();
+
+ let res = typ.decide(&verifier)?;
+ if res != t.expect_valid {
+ return Err(FlpError::Test(format!(
+ "distributed decision is {}; want {}",
+ res, t.expect_valid,
+ )));
+ }
+
+ // Try verifying various proof mutants.
+ for i in 0..proof.len() {
+ let mut mutated_proof = proof.clone();
+ mutated_proof[i] += T::Field::one();
+ let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?;
+ if typ.decide(&verifier)? {
+ return Err(FlpError::Test(format!(
+ "decision for proof mutant {} is {}; want {}",
+ i, true, false,
+ )));
+ }
+ }
+
+ // Try verifying a proof that is too short.
+ let mut mutated_proof = proof.clone();
+ mutated_proof.truncate(gadgets[0].arity() - 1);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on short proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ // Try verifying a proof that is too long.
+ let mut mutated_proof = proof;
+ mutated_proof.extend_from_slice(&[T::Field::one(); 17]);
+ if typ
+ .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
+ .is_ok()
+ {
+ return Err(FlpError::Test(
+ "query on long proof succeeded; want failure".to_string(),
+ ));
+ }
+
+ if let Some(ref want) = t.expected_output {
+ let got = typ.truncate(input.to_vec())?;
+
+ if got.len() != typ.output_len() {
+ return Err(FlpError::Test(format!(
+ "unexpected output length: got {}; want {}",
+ got.len(),
+ typ.output_len()
+ )));
+ }
+
+ if &got != want {
+ return Err(FlpError::Test(format!(
+ "unexpected output: got {got:?}; want {want:?}"
+ )));
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(feature = "experimental")]
+#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
+pub mod fixedpoint_l2;
diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
new file mode 100644
index 0000000000..b5aa2fd116
--- /dev/null
+++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
@@ -0,0 +1,899 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! A [`Type`] for summing vectors of fixed point numbers where the
+//! [L2 norm](https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm)
+//! of each vector is bounded by `1` and adding [discrete Gaussian
+//! noise](https://arxiv.org/abs/2004.00010) in order to achieve server
+//! differential privacy.
+//!
+//! In the following a high level overview over the inner workings of this type
+//! is given and implementation details are discussed. It is not necessary for
+//! using the type, but it should be helpful when trying to understand the
+//! implementation.
+//!
+//! ### Overview
+//!
+//! Clients submit a vector of numbers whose values semantically lie in `[-1,1)`,
+//! together with a norm in the range `[0,1)`. The validation circuit checks that
+//! the norm of the vector is equal to the submitted norm, while the encoding
+//! guarantees that the submitted norm lies in the correct range.
+//!
+//! The bound on the L2 norm allows calibration of discrete Gaussian noise added
+//! after aggregation, making the procedure differentially private.
+//!
+//! ### Submission layout
+//!
+//! The client submissions contain a share of their vector and the norm
+//! they claim it has.
+//! The submission is a vector of field elements laid out as follows:
+//! ```text
+//! |---- bits_per_entry * entries ----|---- bits_for_norm ----|
+//! ^ ^
+//! \- the input vector entries |
+//! \- the encoded norm
+//! ```
+//!
+//! ### Different number encodings
+//!
+//! Let `n` denote the number of bits of the chosen fixed-point type.
+//! Numbers occur in 5 different representations:
+//! 1. Clients have a vector whose entries are fixed point numbers. Only those
+//! fixed point types are supported where the numbers lie in the range
+//! `[-1,1)`.
+//! 2. Because norm computation happens in the validation circuit, it is done
+//! on entries encoded as field elements. That is, the same vector entries
+//! are now represented by integers in the range `[0,2^n)`, where `-1` is
+//! represented by `0` and `+1` by `2^n`.
+//! 3. Because the field is not necessarily exactly of size `2^n`, but might be
+//! larger, it is not enough to encode a vector entry as in (2.) and submit
+//! it to the aggregator. Instead, in order to make sure that all submitted
+//! values are in the correct range, they are bit-encoded. (This is the same
+//! as what happens in the [`Sum`](crate::flp::types::Sum) type.)
+//! This means that instead of sending a field element in the range `[0,2^n)`,
+//! we send `n` field elements representing the bit encoding. The validation
+//! circuit can verify that all submitted "bits" are indeed either `0` or `1`.
+//! 4. The computed and submitted norms are treated similar to the vector
+//! entries, but they have a different number of bits, namely `2n-2`.
+//! 5. As the aggregation result is a pointwise sum of the client vectors,
+//! the numbers no longer (semantically) lie in the range `[-1,1)`, and cannot
+//! be represented by the same fixed point type as the input. Instead the
+//! decoding happens directly into a vector of floats.
+//!
+//! ### Fixed point encoding
+//!
+//! Submissions consist of encoded fixed-point numbers in `[-1,1)` represented as
+//! field elements in `[0,2^n)`, where n is the number of bits the fixed-point
+//! representation has. Encoding and decoding is handled by the associated functions
+//! of the [`CompatibleFloat`] trait. Semantically, the following function describes
+//! how a fixed-point value `x` in range `[-1,1)` is converted to a field integer:
+//! ```text
+//! enc : [-1,1) -> [0,2^n)
+//! enc(x) = 2^(n-1) * x + 2^(n-1)
+//! ```
+//! The inverse is:
+//! ```text
+//! dec : [0,2^n) -> [-1,1)
+//! dec(y) = (y - 2^(n-1)) * 2^(1-n)
+//! ```
+//! Note that these functions only make sense when interpreting all occuring
+//! numbers as real numbers. Since our signed fixed-point numbers are encoded as
+//! two's complement integers, the computation that happens in
+//! [`CompatibleFloat::to_field_integer`] is actually simpler.
+//!
+//! ### Value `1`
+//!
+//! We actually do not allow the submitted norm or vector entries to be
+//! exactly `1`, but rather require them to be strictly less. Supporting `1` would
+//! entail a more fiddly encoding and is not necessary for our usecase.
+//! The largest representable vector entry can be computed by `dec(2^n-1)`.
+//! For example, it is `0.999969482421875` for `n = 16`.
+//!
+//! ### Norm computation
+//!
+//! The L2 norm of a vector xs of numbers in `[-1,1)` is given by:
+//! ```text
+//! norm(xs) = sqrt(sum_{x in xs} x^2)
+//! ```
+//! Instead of computing the norm, we make two simplifications:
+//! 1. We ignore the square root, which means that we are actually computing
+//! the square of the norm.
+//! 2. We want our norm computation result to be integral and in the range `[0, 2^(2n-2))`,
+//! so we can represent it in our field integers. We achieve this by multiplying with `2^(2n-2)`.
+//! This means that what is actually computed in this type is the following:
+//! ```text
+//! our_norm(xs) = 2^(2n-2) * norm(xs)^2
+//! ```
+//!
+//! Explained more visually, `our_norm()` is a composition of three functions:
+//!
+//! ```text
+//! map of dec() norm() "mult with 2^(2n-2)"
+//! vector of [0,2^n) -> vector of [-1,1) -> [0,1) -> [0,2^(2n-2))
+//! ^ ^
+//! | |
+//! fractions with denom of 2^(n-1) fractions with denom of 2^(2n-2)
+//! ```
+//! (Note that the ranges on the LHS and RHS of `"mult with 2^(2n-2)"` are stated
+//! here for vectors with a norm less than `1`.)
+//!
+//! Given a vector `ys` of numbers in the field integer encoding (in `[0,2^n)`),
+//! this gives the following equation:
+//! ```text
+//! our_norm_on_encoded(ys) = our_norm([dec(y) for y in ys])
+//! = 2^(2n-2) * sum_{y in ys} ((y - 2^(n-1)) * 2^(1-n))^2
+//! = 2^(2n-2)
+//! * sum_{y in ys} y^2 - 2*y*2^(n-1) + (2^(n-1))^2
+//! * 2^(1-n)^2
+//! = sum_{y in ys} y^2 - (2^n)*y + 2^(2n-2)
+//! ```
+//!
+//! Let `d` denote the number of the vector entries. The maximal value the result
+//! of `our_norm_on_encoded()` can take occurs in the case where all entries are
+//! `2^n-1`, in which case `d * 2^(2n-2)` is an upper bound to the result. The
+//! finite field used for encoding must be at least as large as this.
+//! For validating that the norm of the submitted vector lies in the correct
+//! range, consider the following:
+//! - The result of `norm(xs)` should be in `[0,1)`.
+//! - Thus, the result of `our_norm(xs)` should be in `[0,2^(2n-2))`.
+//! - The result of `our_norm_on_encoded(ys)` should be in `[0,2^(2n-2))`.
+//! This means that the valid norms are exactly those representable with `2n-2`
+//! bits.
+//!
+//! ### Noise and Differential Privacy
+//!
+//! Bounding the submission norm bounds the impact that changing a single
+//! client's submission can have on the aggregate. That is, the so-called
+//! L2-sensitivity of the procedure is equal to two times the norm bound, namely
+//! `2^n`. Therefore, adding discrete Gaussian noise with standard deviation
+//! `sigma = `(2^n)/epsilon` for some `epsilon` will make the procedure [`(epsilon^2)/2`
+//! zero-concentrated differentially private](https://arxiv.org/abs/2004.00010).
+//! `epsilon` is given as a parameter to the `add_noise_to_result` function, as part of the
+//! `dp_strategy` argument of type [`ZCdpDiscreteGaussian`].
+//!
+//! ### Differences in the computation because of distribution
+//!
+//! In `decode_result()`, what is decoded are not the submitted entries of a
+//! single client, but the sum of the the entries of all clients. We have to
+//! take this into account, and cannot directly use the `dec()` function from
+//! above. Instead we use:
+//! ```text
+//! dec'(y) = y * 2^(1-n) - c
+//! ```
+//! Here, `c` is the number of clients.
+//!
+//! ### Naming in the implementation
+//!
+//! The following names are used:
+//! - `self.bits_per_entry` is `n`
+//! - `self.entries` is `d`
+//! - `self.bits_for_norm` is `2n-2`
+//!
+
+pub mod compatible_float;
+
+use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError};
+use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt};
+use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
+use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
+use crate::flp::types::parallel_sum_range_checks;
+use crate::flp::{FlpError, Gadget, Type, TypeWithNoise};
+use crate::vdaf::xof::SeedStreamSha3;
+use fixed::traits::Fixed;
+use num_bigint::{BigInt, BigUint, TryFromBigIntError};
+use num_integer::Integer;
+use num_rational::Ratio;
+use rand::{distributions::Distribution, Rng};
+use rand_core::SeedableRng;
+use std::{convert::TryFrom, convert::TryInto, fmt::Debug, marker::PhantomData};
+
+/// The fixed point vector sum data type. Each measurement is a vector of fixed point numbers of
+/// type `T`, and the aggregate result is the float vector of the sum of the measurements.
+///
+/// The validity circuit verifies that the L2 norm of each measurement is bounded by 1.
+///
+/// The [*fixed* crate](https://crates.io/crates/fixed) is used for fixed point numbers, in
+/// particular, exactly the following types are supported:
+/// `FixedI16<U15>`, `FixedI32<U31>` and `FixedI64<U63>`.
+///
+/// The type implements the [`TypeWithNoise`] trait. The `add_noise_to_result` function adds
+/// discrete Gaussian noise to an aggregate share, calibrated to the passed privacy budget.
+/// This will result in the aggregate satisfying zero-concentrated differential privacy.
+///
+/// Depending on the size of the vector that needs to be transmitted, a corresponding field type has
+/// to be chosen for `F`. For a `n`-bit fixed point type and a `d`-dimensional vector, the field
+/// modulus needs to be larger than `d * 2^(2n-2)` so there are no overflows during norm validity
+/// computation.
+#[derive(Clone, PartialEq, Eq)]
+pub struct FixedPointBoundedL2VecSum<
+ T: Fixed,
+ SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone,
+ SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone,
+> {
+ bits_per_entry: usize,
+ entries: usize,
+ bits_for_norm: usize,
+ norm_summand_poly: Vec<Field128>,
+ phantom: PhantomData<(T, SPoly, SMul)>,
+
+ // range/position constants
+ range_norm_begin: usize,
+ range_norm_end: usize,
+
+ // configuration of parallel sum gadgets
+ gadget0_calls: usize,
+ gadget0_chunk_length: usize,
+ gadget1_calls: usize,
+ gadget1_chunk_length: usize,
+}
+
+impl<T, SPoly, SMul> Debug for FixedPointBoundedL2VecSum<T, SPoly, SMul>
+where
+ T: Fixed,
+ SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone,
+ SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone,
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("FixedPointBoundedL2VecSum")
+ .field("bits_per_entry", &self.bits_per_entry)
+ .field("entries", &self.entries)
+ .finish()
+ }
+}
+
+impl<T, SPoly, SMul> FixedPointBoundedL2VecSum<T, SPoly, SMul>
+where
+ T: Fixed,
+ SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone,
+ SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone,
+{
+ /// Return a new [`FixedPointBoundedL2VecSum`] type parameter. Each value of this type is a
+ /// fixed point vector with `entries` entries.
+ pub fn new(entries: usize) -> Result<Self, FlpError> {
+ // (0) initialize constants
+ let fi_one = u128::from(Field128::one());
+
+ // (I) Check that the fixed type is compatible.
+ //
+ // We only support fixed types that encode values in [-1,1].
+ // These have a single integer bit.
+ if <T as Fixed>::INT_NBITS != 1 {
+ return Err(FlpError::Encode(format!(
+ "Expected fixed point type with one integer bit, but got {}.",
+ <T as Fixed>::INT_NBITS,
+ )));
+ }
+
+ // Compute number of bits of an entry, and check that an entry fits
+ // into the field.
+ let bits_per_entry: usize = (<T as Fixed>::INT_NBITS + <T as Fixed>::FRAC_NBITS)
+ .try_into()
+ .map_err(|_| FlpError::Encode("Could not convert u32 into usize.".to_string()))?;
+ if !Field128::valid_integer_bitlength(bits_per_entry) {
+ return Err(FlpError::Encode(format!(
+ "fixed point type bit length ({bits_per_entry}) too large for field modulus",
+ )));
+ }
+
+ // (II) Check that the field is large enough for the norm.
+ //
+ // Valid norms encoded as field integers lie in [0,2^(2*bits - 2)).
+ let bits_for_norm = 2 * bits_per_entry - 2;
+ if !Field128::valid_integer_bitlength(bits_for_norm) {
+ return Err(FlpError::Encode(format!(
+ "maximal norm bit length ({bits_for_norm}) too large for field modulus",
+ )));
+ }
+
+ // In order to compare the actual norm of the vector with the claimed
+ // norm, the field needs to be able to represent all numbers that can
+ // occur during the computation of the norm of any submitted vector,
+ // even if its norm is not bounded by 1. Because of our encoding, an
+ // upper bound to that value is `entries * 2^(2*bits - 2)` (see docs of
+ // compute_norm_of_entries for details). It has to fit into the field.
+ let err = Err(FlpError::Encode(format!(
+ "number of entries ({entries}) not compatible with field size",
+ )));
+
+ if let Some(val) = (entries as u128).checked_mul(1 << bits_for_norm) {
+ if val >= Field128::modulus() {
+ return err;
+ }
+ } else {
+ return err;
+ }
+
+ // Construct the polynomial that computes a part of the norm for a
+ // single vector entry.
+ //
+ // the linear part is 2^n,
+ // the constant part is 2^(2n-2),
+ // the polynomial is:
+ // p(y) = 2^(2n-2) + -(2^n) * y + 1 * y^2
+ let linear_part = fi_one << bits_per_entry;
+ let constant_part = fi_one << (bits_per_entry + bits_per_entry - 2);
+ let norm_summand_poly = vec![
+ Field128::from(constant_part),
+ -Field128::from(linear_part),
+ Field128::one(),
+ ];
+
+ // Compute chunk length and number of calls for parallel sum gadgets.
+ let len0 = bits_per_entry * entries + bits_for_norm;
+ let gadget0_chunk_length = std::cmp::max(1, (len0 as f64).sqrt() as usize);
+ let gadget0_calls = (len0 + gadget0_chunk_length - 1) / gadget0_chunk_length;
+
+ let len1 = entries;
+ let gadget1_chunk_length = std::cmp::max(1, (len1 as f64).sqrt() as usize);
+ let gadget1_calls = (len1 + gadget1_chunk_length - 1) / gadget1_chunk_length;
+
+ Ok(Self {
+ bits_per_entry,
+ entries,
+ bits_for_norm,
+ norm_summand_poly,
+ phantom: PhantomData,
+
+ // range constants
+ range_norm_begin: entries * bits_per_entry,
+ range_norm_end: entries * bits_per_entry + bits_for_norm,
+
+ // configuration of parallel sum gadgets
+ gadget0_calls,
+ gadget0_chunk_length,
+ gadget1_calls,
+ gadget1_chunk_length,
+ })
+ }
+
+ /// This noising function can be called on the aggregate share to make
+ /// the entire aggregation process differentially private. The noise is
+ /// calibrated to result in a guarantee of `1/2 * epsilon^2` zero-concentrated
+ /// differential privacy, where `epsilon` is given by `dp_strategy.budget`.
+ fn add_noise<R: Rng>(
+ &self,
+ dp_strategy: &ZCdpDiscreteGaussian,
+ agg_result: &mut [Field128],
+ rng: &mut R,
+ ) -> Result<(), FlpError> {
+ // generate and add discrete gaussian noise for each entry
+
+ // 0. Compute sensitivity of aggregation, namely 2^n.
+ let sensitivity = BigUint::from(2u128).pow(self.bits_per_entry as u32);
+ // Also create a BigInt containing the field modulus.
+ let modulus = BigInt::from(Field128::modulus());
+
+ // 1. initialize sampler
+ let sampler = dp_strategy.create_distribution(Ratio::from_integer(sensitivity))?;
+
+ // 2. Generate noise for each slice entry and apply it.
+ for entry in agg_result.iter_mut() {
+ // (a) Generate noise.
+ let noise: BigInt = sampler.sample(rng);
+
+ // (b) Put noise into field.
+ //
+ // The noise is generated as BigInt, but has to fit into the Field128,
+ // which has modulus `Field128::modulus()`. Thus we use `BigInt::mod_floor()`
+ // to calculate `noise mod modulus`. This value fits into `u128`, and
+ // can be then put into the field.
+ //
+ // Note: we cannot use the operator `%` here, since it is not the mathematical
+ // modulus operation: for negative inputs and positive modulus it gives a
+ // negative result!
+ let noise: BigInt = noise.mod_floor(&modulus);
+ let noise: u128 = noise.try_into().map_err(|e: TryFromBigIntError<BigInt>| {
+ FlpError::DifferentialPrivacy(DpError::BigIntConversion(e))
+ })?;
+ let f_noise = Field128::from(Field128::valid_integer_try_from::<u128>(noise)?);
+
+ // (c) Apply noise to each entry of the aggregate share.
+ *entry += f_noise;
+ }
+
+ Ok(())
+ }
+}
+
+impl<T, SPoly, SMul> Type for FixedPointBoundedL2VecSum<T, SPoly, SMul>
+where
+ T: Fixed + CompatibleFloat,
+ SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static,
+ SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static,
+{
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<T>;
+ type AggregateResult = Vec<f64>;
+ type Field = Field128;
+
+ fn encode_measurement(&self, fp_entries: &Vec<T>) -> Result<Vec<Field128>, FlpError> {
+ if fp_entries.len() != self.entries {
+ return Err(FlpError::Encode("unexpected input length".into()));
+ }
+
+ // Convert the fixed-point encoded input values to field integers. We do
+ // this once here because we need them for encoding but also for
+ // computing the norm.
+ let integer_entries = fp_entries.iter().map(|x| x.to_field_integer());
+
+ // (I) Vector entries.
+ // Encode the integer entries bitwise, and write them into the `encoded`
+ // vector.
+ let mut encoded: Vec<Field128> =
+ vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm];
+ for (l, entry) in integer_entries.clone().enumerate() {
+ Field128::fill_with_bitvector_representation(
+ &entry,
+ &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry],
+ )?;
+ }
+
+ // (II) Vector norm.
+ // Compute the norm of the input vector.
+ let field_entries = integer_entries.map(Field128::from);
+ let norm = compute_norm_of_entries(field_entries, self.bits_per_entry)?;
+ let norm_int = u128::from(norm);
+
+ // Write the norm into the `entries` vector.
+ Field128::fill_with_bitvector_representation(
+ &norm_int,
+ &mut encoded[self.range_norm_begin..self.range_norm_end],
+ )?;
+
+ Ok(encoded)
+ }
+
+ fn decode_result(
+ &self,
+ data: &[Field128],
+ num_measurements: usize,
+ ) -> Result<Vec<f64>, FlpError> {
+ if data.len() != self.entries {
+ return Err(FlpError::Decode("unexpected input length".into()));
+ }
+ let num_measurements = match u128::try_from(num_measurements) {
+ Ok(m) => m,
+ Err(_) => {
+ return Err(FlpError::Decode(
+ "number of clients is too large to fit into u128".into(),
+ ))
+ }
+ };
+ let mut res = Vec::with_capacity(data.len());
+ for d in data {
+ let decoded = <T as CompatibleFloat>::to_float(*d, num_measurements);
+ res.push(decoded);
+ }
+ Ok(res)
+ }
+
+ fn gadget(&self) -> Vec<Box<dyn Gadget<Field128>>> {
+ // This gadget checks that a field element is zero or one.
+ // It is called for all the "bits" of the encoded entries
+ // and of the encoded norm.
+ let gadget0 = SMul::new(Mul::new(self.gadget0_calls), self.gadget0_chunk_length);
+
+ // This gadget computes the square of a fixed point number, operating on
+ // its encoding as a field element. It is called on each entry during
+ // norm computation.
+ let gadget1 = SPoly::new(
+ PolyEval::new(self.norm_summand_poly.clone(), self.gadget1_calls),
+ self.gadget1_chunk_length,
+ );
+
+ vec![Box::new(gadget0), Box::new(gadget1)]
+ }
+
+ fn valid(
+ &self,
+ g: &mut Vec<Box<dyn Gadget<Field128>>>,
+ input: &[Field128],
+ joint_rand: &[Field128],
+ num_shares: usize,
+ ) -> Result<Field128, FlpError> {
+ self.valid_call_check(input, joint_rand)?;
+
+ let f_num_shares = Field128::from(Field128::valid_integer_try_from::<usize>(num_shares)?);
+ let num_shares_inverse = Field128::one() / f_num_shares;
+
+ // Ensure that all submitted field elements are either 0 or 1.
+ // This is done for:
+ // (I) all vector entries (each of them encoded in `self.bits_per_entry`
+ // field elements)
+ // (II) the submitted norm (encoded in `self.bits_for_norm` field
+ // elements)
+ //
+ // Since all input vector entry (field-)bits, as well as the norm bits,
+ // are contiguous, we do the check directly for all bits from 0 to
+ // entries*bits_per_entry + bits_for_norm.
+ //
+ // In order to keep the proof size down, this is done using the
+ // `ParallelSum` gadget. For a similar application see the `SumVec`
+ // type.
+ let range_check = parallel_sum_range_checks(
+ &mut g[0],
+ &input[..self.range_norm_end],
+ joint_rand[0],
+ self.gadget0_chunk_length,
+ num_shares,
+ )?;
+
+ // Compute the norm of the entries and ensure that it is the same as the
+ // submitted norm. There are exactly enough bits such that a submitted
+ // norm is always a valid norm (semantically in the range [0,1]). By
+ // comparing submitted with actual, we make sure the actual norm is
+ // valid.
+ //
+ // The function to compute here (see explanatory comment at the top) is
+ // norm(ys) = sum_{y in ys} y^2 - (2^n)*y + 2^(2n-2)
+ //
+ // This is done by the `ParallelSum` gadget `g[1]`, which evaluates the
+ // inner polynomial on each (decoded) vector entry, and then sums the
+ // results. Note that the gadget is not called on the whole vector at
+ // once, but sequentially on chunks of size `self.gadget1_chunk_length` of
+ // it. The results of these calls are accumulated in the `outp` variable.
+ //
+ // decode the bit-encoded entries into elements in the range [0,2^n):
+ let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry]
+ .chunks(self.bits_per_entry)
+ .map(Field128::decode_from_bitvector_representation)
+ .collect();
+
+ // run parallel sum gadget on the decoded entries
+ let computed_norm = {
+ let mut outp = Field128::zero();
+
+ // Chunks which are too short need to be extended with a share of the
+ // encoded zero value, that is: 1/num_shares * (2^(n-1))
+ let fi_one = u128::from(Field128::one());
+ let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1));
+ let zero_enc_share = zero_enc * num_shares_inverse;
+
+ for chunk in decoded_entries?.chunks(self.gadget1_chunk_length) {
+ let d = chunk.len();
+ if d == self.gadget1_chunk_length {
+ outp += g[1].call(chunk)?;
+ } else {
+ // If the chunk is smaller than the chunk length, extend
+ // chunk with zeros.
+ let mut padded_chunk: Vec<_> = chunk.to_owned();
+ padded_chunk.resize(self.gadget1_chunk_length, zero_enc_share);
+ outp += g[1].call(&padded_chunk)?;
+ }
+ }
+
+ outp
+ };
+
+ // The submitted norm is also decoded from its bit-encoding, and
+ // compared with the computed norm.
+ let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end];
+ let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?;
+
+ let norm_check = computed_norm - submitted_norm;
+
+ // Finally, we require both checks to be successful by computing a
+ // random linear combination of them.
+ let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * norm_check;
+ Ok(out)
+ }
+
+ fn truncate(&self, input: Vec<Field128>) -> Result<Vec<Self::Field>, FlpError> {
+ self.truncate_call_check(&input)?;
+
+ let mut decoded_vector = vec![];
+
+ for i_entry in 0..self.entries {
+ let start = i_entry * self.bits_per_entry;
+ let end = (i_entry + 1) * self.bits_per_entry;
+
+ let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?;
+ decoded_vector.push(decoded);
+ }
+ Ok(decoded_vector)
+ }
+
+ fn input_len(&self) -> usize {
+ self.bits_per_entry * self.entries + self.bits_for_norm
+ }
+
+ fn proof_len(&self) -> usize {
+ // computed via
+ // `gadget.arity() + gadget.degree()
+ // * ((1 + gadget.calls()).next_power_of_two() - 1) + 1;`
+ let proof_gadget_0 = (self.gadget0_chunk_length * 2)
+ + 2 * ((1 + self.gadget0_calls).next_power_of_two() - 1)
+ + 1;
+ let proof_gadget_1 = (self.gadget1_chunk_length)
+ + 2 * ((1 + self.gadget1_calls).next_power_of_two() - 1)
+ + 1;
+
+ proof_gadget_0 + proof_gadget_1
+ }
+
+ fn verifier_len(&self) -> usize {
+ self.gadget0_chunk_length * 2 + self.gadget1_chunk_length + 3
+ }
+
+ fn output_len(&self) -> usize {
+ self.entries
+ }
+
+ fn joint_rand_len(&self) -> usize {
+ 2
+ }
+
+ fn prove_rand_len(&self) -> usize {
+ self.gadget0_chunk_length * 2 + self.gadget1_chunk_length
+ }
+
+ fn query_rand_len(&self) -> usize {
+ 2
+ }
+}
+
+impl<T, SPoly, SMul> TypeWithNoise<ZCdpDiscreteGaussian>
+ for FixedPointBoundedL2VecSum<T, SPoly, SMul>
+where
+ T: Fixed + CompatibleFloat,
+ SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static,
+ SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static,
+{
+ fn add_noise_to_result(
+ &self,
+ dp_strategy: &ZCdpDiscreteGaussian,
+ agg_result: &mut [Self::Field],
+ _num_measurements: usize,
+ ) -> Result<(), FlpError> {
+ self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy())
+ }
+}
+
+/// Compute the square of the L2 norm of a vector of fixed-point numbers encoded as field elements.
+///
+/// * `entries` - Iterator over the vector entries.
+/// * `bits_per_entry` - Number of bits one entry has.
+fn compute_norm_of_entries<Fs>(entries: Fs, bits_per_entry: usize) -> Result<Field128, FlpError>
+where
+ Fs: IntoIterator<Item = Field128>,
+{
+ let fi_one = u128::from(Field128::one());
+
+ // The value that is computed here is:
+ // sum_{y in entries} 2^(2n-2) + -(2^n) * y + 1 * y^2
+ //
+ // Check out the norm computation bit in the explanatory comment block for
+ // more information.
+ //
+ // Initialize `norm_accumulator`.
+ let mut norm_accumulator = Field128::zero();
+
+ // constants
+ let linear_part = fi_one << bits_per_entry; // = 2^(2n-2)
+ let constant_part = fi_one << (bits_per_entry + bits_per_entry - 2); // = 2^n
+
+ // Add term for a given `entry` to `norm_accumulator`.
+ for entry in entries.into_iter() {
+ let summand =
+ entry * entry + Field128::from(constant_part) - Field128::from(linear_part) * (entry);
+ norm_accumulator += summand;
+ }
+ Ok(norm_accumulator)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::dp::{Rational, ZCdpBudget};
+ use crate::field::{random_vector, Field128, FieldElement};
+ use crate::flp::gadgets::ParallelSum;
+ use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
+ use crate::vdaf::xof::SeedStreamSha3;
+ use fixed::types::extra::{U127, U14, U63};
+ use fixed::{FixedI128, FixedI16, FixedI64};
+ use fixed_macro::fixed;
+ use rand::SeedableRng;
+
+ #[test]
+ fn test_bounded_fpvec_sum_parallel_fp16() {
+ let fp16_4_inv = fixed!(0.25: I1F15);
+ let fp16_8_inv = fixed!(0.125: I1F15);
+ let fp16_16_inv = fixed!(0.0625: I1F15);
+
+ let fp16_vec = vec![fp16_4_inv, fp16_8_inv, fp16_16_inv];
+
+ // the encoded vector has the following entries:
+ // enc(0.25) = 2^(n-1) * 0.25 + 2^(n-1) = 40960
+ // enc(0.125) = 2^(n-1) * 0.125 + 2^(n-1) = 36864
+ // enc(0.0625) = 2^(n-1) * 0.0625 + 2^(n-1) = 34816
+ test_fixed(fp16_vec, vec![40960, 36864, 34816]);
+ }
+
+ #[test]
+ fn test_bounded_fpvec_sum_parallel_fp32() {
+ let fp32_4_inv = fixed!(0.25: I1F31);
+ let fp32_8_inv = fixed!(0.125: I1F31);
+ let fp32_16_inv = fixed!(0.0625: I1F31);
+
+ let fp32_vec = vec![fp32_4_inv, fp32_8_inv, fp32_16_inv];
+ // computed as above but with n=32
+ test_fixed(fp32_vec, vec![2684354560, 2415919104, 2281701376]);
+ }
+
+ #[test]
+ fn test_bounded_fpvec_sum_parallel_fp64() {
+ let fp64_4_inv = fixed!(0.25: I1F63);
+ let fp64_8_inv = fixed!(0.125: I1F63);
+ let fp64_16_inv = fixed!(0.0625: I1F63);
+
+ let fp64_vec = vec![fp64_4_inv, fp64_8_inv, fp64_16_inv];
+ // computed as above but with n=64
+ test_fixed(
+ fp64_vec,
+ vec![
+ 11529215046068469760,
+ 10376293541461622784,
+ 9799832789158199296,
+ ],
+ );
+ }
+
+ fn test_fixed<F: Fixed>(fp_vec: Vec<F>, enc_vec: Vec<u128>)
+ where
+ F: CompatibleFloat,
+ {
+ let n: usize = (F::INT_NBITS + F::FRAC_NBITS).try_into().unwrap();
+
+ type Ps = ParallelSum<Field128, PolyEval<Field128>>;
+ type Psm = ParallelSum<Field128, Mul<Field128>>;
+
+ let vsum: FixedPointBoundedL2VecSum<F, Ps, Psm> =
+ FixedPointBoundedL2VecSum::new(3).unwrap();
+ let one = Field128::one();
+ // Round trip
+ assert_eq!(
+ vsum.decode_result(
+ &vsum
+ .truncate(vsum.encode_measurement(&fp_vec).unwrap())
+ .unwrap(),
+ 1
+ )
+ .unwrap(),
+ vec!(0.25, 0.125, 0.0625)
+ );
+
+ // Noise
+ let mut v = vsum
+ .truncate(vsum.encode_measurement(&fp_vec).unwrap())
+ .unwrap();
+ let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new(
+ Rational::from_unsigned(100u8, 3u8).unwrap(),
+ ));
+ vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16]))
+ .unwrap();
+ assert_eq!(
+ vsum.decode_result(&v, 1).unwrap(),
+ match n {
+ // sensitivity depends on encoding so the noise differs
+ 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625],
+ 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445],
+ 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425],
+ _ => panic!("unsupported bitsize"),
+ }
+ );
+
+ // encoded norm does not match computed norm
+ let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
+ assert_eq!(input[0], Field128::zero());
+ input[0] = one; // it was zero
+ flp_validity_test(
+ &vsum,
+ &input,
+ &ValidityTestCase::<Field128> {
+ expect_valid: false,
+ expected_output: Some(vec![
+ Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // encoding contains entries that are not zero or one
+ let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
+ input2[0] = one + one;
+ flp_validity_test(
+ &vsum,
+ &input2,
+ &ValidityTestCase::<Field128> {
+ expect_valid: false,
+ expected_output: Some(vec![
+ Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // norm is too big
+ // 2^n - 1, the field element encoded by the all-1 vector
+ let one_enc = Field128::from(((2_u128) << (n - 1)) - 1);
+ flp_validity_test(
+ &vsum,
+ &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors
+ &ValidityTestCase::<Field128> {
+ expect_valid: false,
+ expected_output: Some(vec![one_enc; 3]),
+ num_shares: 3,
+ },
+ )
+ .unwrap();
+
+ // invalid submission length, should be 3n + (2*n - 2) for a
+ // 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm.
+ let joint_rand = random_vector(vsum.joint_rand_len()).unwrap();
+ vsum.valid(
+ &mut vsum.gadget(),
+ &vec![one; 3 * n + 2 * n - 1],
+ &joint_rand,
+ 1,
+ )
+ .unwrap_err();
+
+ // test that the zero vector has correct norm, where zero is encoded as:
+ // enc(0) = 2^(n-1) * 0 + 2^(n-1)
+ let zero_enc = Field128::from((2_u128) << (n - 2));
+ {
+ let entries = vec![zero_enc; 3];
+ let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap();
+ let expected_norm = Field128::from(0);
+ assert_eq!(norm, expected_norm);
+ }
+
+ // ensure that no overflow occurs with largest possible norm
+ {
+ // the largest possible entries (2^n-1)
+ let entries = vec![one_enc; 3];
+ let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap();
+ let expected_norm = Field128::from(3 * (1 + (1 << (2 * n - 2)) - (1 << n)));
+ // = 3 * ((2^n-1)^2 - (2^n-1)*2^16 + 2^(2*n-2))
+ assert_eq!(norm, expected_norm);
+
+ // the smallest possible entries (0)
+ let entries = vec![Field128::from(0), Field128::from(0), Field128::from(0)];
+ let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap();
+ let expected_norm = Field128::from(3 * (1 << (2 * n - 2)));
+ // = 3 * (0^2 - 0*2^n + 2^(2*n-2))
+ assert_eq!(norm, expected_norm);
+ }
+ }
+
+ #[test]
+ fn test_bounded_fpvec_sum_parallel_invalid_args() {
+ // invalid initialization
+ // fixed point too large
+ <FixedPointBoundedL2VecSum<
+ FixedI128<U127>,
+ ParallelSum<Field128, PolyEval<Field128>>,
+ ParallelSum<Field128, Mul<Field128>>,
+ >>::new(3)
+ .unwrap_err();
+ // vector too large
+ <FixedPointBoundedL2VecSum<
+ FixedI64<U63>,
+ ParallelSum<Field128, PolyEval<Field128>>,
+ ParallelSum<Field128, Mul<Field128>>,
+ >>::new(3000000000)
+ .unwrap_err();
+ // fixed point type has more than one int bit
+ <FixedPointBoundedL2VecSum<
+ FixedI16<U14>,
+ ParallelSum<Field128, PolyEval<Field128>>,
+ ParallelSum<Field128, Mul<Field128>>,
+ >>::new(3)
+ .unwrap_err();
+ }
+}
diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs
new file mode 100644
index 0000000000..404bec125a
--- /dev/null
+++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs
@@ -0,0 +1,93 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementations of encoding fixed point types as field elements and field elements as floats
+//! for the [`FixedPointBoundedL2VecSum`](crate::flp::types::fixedpoint_l2::FixedPointBoundedL2VecSum) type.
+
+use crate::field::{Field128, FieldElementWithInteger};
+use fixed::types::extra::{U15, U31, U63};
+use fixed::{FixedI16, FixedI32, FixedI64};
+
+/// Assign a `Float` type to this type and describe how to represent this type as an integer of the
+/// given field, and how to represent a field element as the assigned `Float` type.
+pub trait CompatibleFloat {
+ /// Represent a field element as `Float`, given the number of clients `c`.
+ fn to_float(t: Field128, c: u128) -> f64;
+
+ /// Represent a value of this type as an integer in the given field.
+ fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer;
+}
+
+impl CompatibleFloat for FixedI16<U15> {
+ fn to_float(d: Field128, c: u128) -> f64 {
+ to_float_bits(d, c, 16)
+ }
+
+ fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer {
+ //signed two's complement integer representation
+ let i: i16 = self.to_bits();
+ // reinterpret as unsigned
+ let u = i as u16;
+ // invert the left-most bit to de-two-complement
+ u128::from(u ^ (1 << 15))
+ }
+}
+
+impl CompatibleFloat for FixedI32<U31> {
+ fn to_float(d: Field128, c: u128) -> f64 {
+ to_float_bits(d, c, 32)
+ }
+
+ fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer {
+ //signed two's complement integer representation
+ let i: i32 = self.to_bits();
+ // reinterpret as unsigned
+ let u = i as u32;
+ // invert the left-most bit to de-two-complement
+ u128::from(u ^ (1 << 31))
+ }
+}
+
+impl CompatibleFloat for FixedI64<U63> {
+ fn to_float(d: Field128, c: u128) -> f64 {
+ to_float_bits(d, c, 64)
+ }
+
+ fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer {
+ //signed two's complement integer representation
+ let i: i64 = self.to_bits();
+ // reinterpret as unsigned
+ let u = i as u64;
+ // invert the left-most bit to de-two-complement
+ u128::from(u ^ (1 << 63))
+ }
+}
+
+/// Return an `f64` representation of the field element `s`, assuming it is the computation result
+/// of a `c`-client fixed point vector summation with `n` fractional bits.
+fn to_float_bits(s: Field128, c: u128, n: i32) -> f64 {
+ // get integer representation of field element
+ let s_int: u128 = <Field128 as FieldElementWithInteger>::Integer::from(s);
+
+ // to decode a single integer, we'd use the function
+ // dec(y) = (y - 2^(n-1)) * 2^(1-n) = y * 2^(1-n) - 1
+ // as s is the sum of c encoded vector entries where c is the number of
+ // clients, we have to compute instead
+ // s * 2^(1-n) - c
+ //
+ // Furthermore, for better numerical stability, we reformulate this as
+ // = (s - c*2^(n-1)) * 2^(1-n)
+ // where the subtraction of `c` is done on integers and only afterwards
+ // the conversion to floats is done.
+ //
+ // Since the RHS of the substraction may be larger than the LHS
+ // (when the number we are decoding is going to be negative),
+ // yet we are dealing with unsigned 128-bit integers, we manually
+ // check for the resulting sign while ensuring that the subtraction
+ // does not underflow.
+ let (a, b, sign) = match (s_int, c << (n - 1)) {
+ (x, y) if x < y => (y, x, -1.0f64),
+ (x, y) => (x, y, 1.0f64),
+ };
+
+ ((a - b) as f64) * sign * f64::powi(2.0, 1 - n)
+}
diff --git a/third_party/rust/prio/src/fp.rs b/third_party/rust/prio/src/fp.rs
new file mode 100644
index 0000000000..d4c0dcdc2c
--- /dev/null
+++ b/third_party/rust/prio/src/fp.rs
@@ -0,0 +1,533 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Finite field arithmetic for any field GF(p) for which p < 2^128.
+
+#[cfg(test)]
+use rand::{prelude::*, Rng};
+
+/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots
+/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This
+/// is the largest input size we would ever need for the cryptographic applications in this crate.
+pub(crate) const MAX_ROOTS: usize = 20;
+
+/// This structure represents the parameters of a finite field GF(p) for which p < 2^128.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) struct FieldParameters {
+ /// The prime modulus `p`.
+ pub p: u128,
+ /// `mu = -p^(-1) mod 2^64`.
+ pub mu: u64,
+ /// `r2 = (2^128)^2 mod p`.
+ pub r2: u128,
+ /// The `2^num_roots`-th -principal root of unity. This element is used to generate the
+ /// elements of `roots`.
+ pub g: u128,
+ /// The number of principal roots of unity in `roots`.
+ pub num_roots: usize,
+ /// Equal to `2^b - 1`, where `b` is the length of `p` in bits.
+ pub bit_mask: u128,
+ /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the
+ /// multiplicative group. `roots[0]` is equal to one by definition.
+ pub roots: [u128; MAX_ROOTS + 1],
+}
+
+impl FieldParameters {
+ /// Addition. The result will be in [0, p), so long as both x and y are as well.
+ #[inline(always)]
+ pub fn add(&self, x: u128, y: u128) -> u128 {
+ // 0,x
+ // + 0,y
+ // =====
+ // c,z
+ let (z, carry) = x.overflowing_add(y);
+ // c, z
+ // - 0, p
+ // ========
+ // b1,s1,s0
+ let (s0, b0) = z.overflowing_sub(self.p);
+ let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128);
+ // if b1 == 1: return z
+ // else: return s0
+ let m = 0u128.wrapping_sub(b1 as u128);
+ (z & m) | (s0 & !m)
+ }
+
+ /// Subtraction. The result will be in [0, p), so long as both x and y are as well.
+ #[inline(always)]
+ pub fn sub(&self, x: u128, y: u128) -> u128 {
+ // 0, x
+ // - 0, y
+ // ========
+ // b1,z1,z0
+ let (z0, b0) = x.overflowing_sub(y);
+ let (_z1, b1) = 0u128.overflowing_sub(b0 as u128);
+ let m = 0u128.wrapping_sub(b1 as u128);
+ // z1,z0
+ // + 0, p
+ // ========
+ // s1,s0
+ z0.wrapping_add(m & self.p)
+ // if b1 == 1: return s0
+ // else: return z0
+ }
+
+ /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm
+ /// described
+ /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf).
+ /// The result will be in [0, p).
+ ///
+ /// # Example usage
+ /// ```text
+ /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46);
+ /// ```
+ #[inline(always)]
+ pub fn mul(&self, x: u128, y: u128) -> u128 {
+ let x = [lo64(x), hi64(x)];
+ let y = [lo64(y), hi64(y)];
+ let p = [lo64(self.p), hi64(self.p)];
+ let mut zz = [0; 4];
+
+ // Integer multiplication
+ // z = x * y
+
+ // x1,x0
+ // * y1,y0
+ // ===========
+ // z3,z2,z1,z0
+ let mut result = x[0] * y[0];
+ let mut carry = hi64(result);
+ zz[0] = lo64(result);
+ result = x[0] * y[1];
+ let mut hi = hi64(result);
+ let mut lo = lo64(result);
+ result = lo + carry;
+ zz[1] = lo64(result);
+ let mut cc = hi64(result);
+ result = hi + cc;
+ zz[2] = lo64(result);
+
+ result = x[1] * y[0];
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = x[1] * y[1];
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[2] + lo;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ zz[3] = lo64(result);
+
+ // Montgomery Reduction
+ // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64.
+
+ // z3,z2,z1,z0
+ // + p1,p0
+ // * w = mu*z0
+ // ===========
+ // z3,z2,z1, 0
+ let w = self.mu.wrapping_mul(zz[0] as u64);
+ result = p[0] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[0] + lo;
+ zz[0] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = p[1] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = zz[2] + hi + cc;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = zz[3] + cc;
+ zz[3] = lo64(result);
+
+ // z3,z2,z1
+ // + p1,p0
+ // * w = mu*z1
+ // ===========
+ // z3,z2, 0
+ let w = self.mu.wrapping_mul(zz[1] as u64);
+ result = p[0] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = zz[1] + lo;
+ zz[1] = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ carry = lo64(result);
+
+ result = p[1] * (w as u128);
+ hi = hi64(result);
+ lo = lo64(result);
+ result = lo + carry;
+ lo = lo64(result);
+ cc = hi64(result);
+ result = hi + cc;
+ hi = lo64(result);
+ result = zz[2] + lo;
+ zz[2] = lo64(result);
+ cc = hi64(result);
+ result = zz[3] + hi + cc;
+ zz[3] = lo64(result);
+ cc = hi64(result);
+
+ // z = (z3,z2)
+ let prod = zz[2] | (zz[3] << 64);
+
+ // Final subtraction
+ // If z >= p, then z = z - p
+
+ // 0, z
+ // - 0, p
+ // ========
+ // b1,s1,s0
+ let (s0, b0) = prod.overflowing_sub(self.p);
+ let (_s1, b1) = cc.overflowing_sub(b0 as u128);
+ // if b1 == 1: return z
+ // else: return s0
+ let mask = 0u128.wrapping_sub(b1 as u128);
+ (prod & mask) | (s0 & !mask)
+ }
+
+ /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the
+ /// runtime of this algorithm is linear in the bit length of `exp`.
+ pub fn pow(&self, x: u128, exp: u128) -> u128 {
+ let mut t = self.montgomery(1);
+ for i in (0..128 - exp.leading_zeros()).rev() {
+ t = self.mul(t, t);
+ if (exp >> i) & 1 != 0 {
+ t = self.mul(t, x);
+ }
+ }
+ t
+ }
+
+ /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of
+ /// this algorithm is linear in the bit length of `p`.
+ #[inline(always)]
+ pub fn inv(&self, x: u128) -> u128 {
+ self.pow(x, self.p - 2)
+ }
+
+ /// Negation, i.e., `-x (mod p)` where `p` is the modulus.
+ #[inline(always)]
+ pub fn neg(&self, x: u128) -> u128 {
+ self.sub(0, x)
+ }
+
+ /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery
+ /// domain in order to carry out field arithmetic. The result will be in [0, p).
+ ///
+ /// # Example usage
+ /// ```text
+ /// let integer = 1; // Standard integer representation
+ /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain
+ /// assert_eq!(elem, 2564090464);
+ /// ```
+ #[inline(always)]
+ pub fn montgomery(&self, x: u128) -> u128 {
+ modp(self.mul(x, self.r2), self.p)
+ }
+
+ /// Returns a random field element mapped.
+ #[cfg(test)]
+ pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 {
+ let uniform = rand::distributions::Uniform::from(0..self.p);
+ self.montgomery(uniform.sample(rng))
+ }
+
+ /// Maps a field element to its representation as an integer. The result will be in [0, p).
+ ///
+ /// #Example usage
+ /// ```text
+ /// let elem = 2564090464; // Internal representation in the Montgomery domain
+ /// let integer = fp.residue(elem); // Standard integer representation
+ /// assert_eq!(integer, 1);
+ /// ```
+ #[inline(always)]
+ pub fn residue(&self, x: u128) -> u128 {
+ modp(self.mul(x, 1), self.p)
+ }
+
+ #[cfg(test)]
+ pub fn check(&self, p: u128, g: u128, order: u128) {
+ use modinverse::modinverse;
+ use num_bigint::{BigInt, ToBigInt};
+ use std::cmp::max;
+
+ assert_eq!(self.p, p, "p mismatch");
+
+ let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) {
+ Some(mu) => mu as u64,
+ None => panic!("inverse of -p (mod 2^64) is undefined"),
+ };
+ assert_eq!(self.mu, mu, "mu mismatch");
+
+ let big_p = &p.to_bigint().unwrap();
+ let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p);
+ let big_r2: &BigInt = &(&(big_r * big_r) % big_p);
+ let mut it = big_r2.iter_u64_digits();
+ let mut r2 = 0;
+ r2 |= it.next().unwrap() as u128;
+ if let Some(x) = it.next() {
+ r2 |= (x as u128) << 64;
+ }
+ assert_eq!(self.r2, r2, "r2 mismatch");
+
+ assert_eq!(self.g, self.montgomery(g), "g mismatch");
+ assert_eq!(
+ self.residue(self.pow(self.g, order)),
+ 1,
+ "g order incorrect"
+ );
+
+ let num_roots = log2(order) as usize;
+ assert_eq!(order, 1 << num_roots, "order not a power of 2");
+ assert_eq!(self.num_roots, num_roots, "num_roots mismatch");
+
+ let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1];
+ roots[num_roots] = self.montgomery(g);
+ for i in (0..num_roots).rev() {
+ roots[i] = self.mul(roots[i + 1], roots[i + 1]);
+ }
+ assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch");
+ assert_eq!(self.residue(self.roots[0]), 1, "first root is not one");
+
+ let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1);
+ assert_eq!(
+ self.bit_mask.to_bigint().unwrap(),
+ bit_mask,
+ "bit_mask mismatch"
+ );
+ }
+}
+
+#[inline(always)]
+fn lo64(x: u128) -> u128 {
+ x & ((1 << 64) - 1)
+}
+
+#[inline(always)]
+fn hi64(x: u128) -> u128 {
+ x >> 64
+}
+
+#[inline(always)]
+fn modp(x: u128, p: u128) -> u128 {
+ let (z, carry) = x.overflowing_sub(p);
+ let m = 0u128.wrapping_sub(carry as u128);
+ z.wrapping_add(m & p)
+}
+
+pub(crate) const FP32: FieldParameters = FieldParameters {
+ p: 4293918721, // 32-bit prime
+ mu: 17302828673139736575,
+ r2: 1676699750,
+ g: 1074114499,
+ num_roots: 20,
+ bit_mask: 4294967295,
+ roots: [
+ 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825,
+ 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415,
+ 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499,
+ ],
+};
+
+pub(crate) const FP64: FieldParameters = FieldParameters {
+ p: 18446744069414584321, // 64-bit prime
+ mu: 18446744069414584319,
+ r2: 4294967295,
+ g: 959634606461954525,
+ num_roots: 32,
+ bit_mask: 18446744073709551615,
+ roots: [
+ 18446744065119617025,
+ 4294967296,
+ 18446462594437939201,
+ 72057594037927936,
+ 1152921504338411520,
+ 16384,
+ 18446743519658770561,
+ 18446735273187346433,
+ 6519596376689022014,
+ 9996039020351967275,
+ 15452408553935940313,
+ 15855629130643256449,
+ 8619522106083987867,
+ 13036116919365988132,
+ 1033106119984023956,
+ 16593078884869787648,
+ 16980581328500004402,
+ 12245796497946355434,
+ 8709441440702798460,
+ 8611358103550827629,
+ 8120528636261052110,
+ ],
+};
+
+pub(crate) const FP128: FieldParameters = FieldParameters {
+ p: 340282366920938462946865773367900766209, // 128-bit prime
+ mu: 18446744073709551615,
+ r2: 403909908237944342183153,
+ g: 107630958476043550189608038630704257141,
+ num_roots: 66,
+ bit_mask: 340282366920938463463374607431768211455,
+ roots: [
+ 516508834063867445247,
+ 340282366920938462430356939304033320962,
+ 129526470195413442198896969089616959958,
+ 169031622068548287099117778531474117974,
+ 81612939378432101163303892927894236156,
+ 122401220764524715189382260548353967708,
+ 199453575871863981432000940507837456190,
+ 272368408887745135168960576051472383806,
+ 24863773656265022616993900367764287617,
+ 257882853788779266319541142124730662203,
+ 323732363244658673145040701829006542956,
+ 57532865270871759635014308631881743007,
+ 149571414409418047452773959687184934208,
+ 177018931070866797456844925926211239962,
+ 268896136799800963964749917185333891349,
+ 244556960591856046954834420512544511831,
+ 118945432085812380213390062516065622346,
+ 202007153998709986841225284843501908420,
+ 332677126194796691532164818746739771387,
+ 258279638927684931537542082169183965856,
+ 148221243758794364405224645520862378432,
+ ],
+};
+
+// Compute the ceiling of the base-2 logarithm of `x`.
+pub(crate) fn log2(x: u128) -> u128 {
+ let y = (127 - x.leading_zeros()) as u128;
+ y + ((x > 1 << y) as u128)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use num_bigint::ToBigInt;
+
+ #[test]
+ fn test_log2() {
+ assert_eq!(log2(1), 0);
+ assert_eq!(log2(2), 1);
+ assert_eq!(log2(3), 2);
+ assert_eq!(log2(4), 2);
+ assert_eq!(log2(15), 4);
+ assert_eq!(log2(16), 4);
+ assert_eq!(log2(30), 5);
+ assert_eq!(log2(32), 5);
+ assert_eq!(log2(1 << 127), 127);
+ assert_eq!(log2((1 << 127) + 13), 128);
+ }
+
+ struct TestFieldParametersData {
+ fp: FieldParameters, // The paramters being tested
+ expected_p: u128, // Expected fp.p
+ expected_g: u128, // Expected fp.residue(fp.g)
+ expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1
+ }
+
+ #[test]
+ fn test_fp() {
+ let test_fps = vec![
+ TestFieldParametersData {
+ fp: FP32,
+ expected_p: 4293918721,
+ expected_g: 3925978153,
+ expected_order: 1 << 20,
+ },
+ TestFieldParametersData {
+ fp: FP64,
+ expected_p: 18446744069414584321,
+ expected_g: 1753635133440165772,
+ expected_order: 1 << 32,
+ },
+ TestFieldParametersData {
+ fp: FP128,
+ expected_p: 340282366920938462946865773367900766209,
+ expected_g: 145091266659756586618791329697897684742,
+ expected_order: 1 << 66,
+ },
+ ];
+
+ for t in test_fps.into_iter() {
+ // Check that the field parameters have been constructed properly.
+ t.fp.check(t.expected_p, t.expected_g, t.expected_order);
+
+ // Check that the generator has the correct order.
+ assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1);
+ assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1);
+
+ // Test arithmetic using the field parameters.
+ arithmetic_test(&t.fp);
+ }
+ }
+
+ fn arithmetic_test(fp: &FieldParameters) {
+ let mut rng = rand::thread_rng();
+ let big_p = &fp.p.to_bigint().unwrap();
+
+ for _ in 0..100 {
+ let x = fp.rand_elem(&mut rng);
+ let y = fp.rand_elem(&mut rng);
+ let big_x = &fp.residue(x).to_bigint().unwrap();
+ let big_y = &fp.residue(y).to_bigint().unwrap();
+
+ // Test addition.
+ let got = fp.add(x, y);
+ let want = (big_x + big_y) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test subtraction.
+ let got = fp.sub(x, y);
+ let want = if big_x >= big_y {
+ big_x - big_y
+ } else {
+ big_p - big_y + big_x
+ };
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test multiplication.
+ let got = fp.mul(x, y);
+ let want = (big_x * big_y) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+
+ // Test inversion.
+ let got = fp.inv(x);
+ let want = big_x.modpow(&(big_p - 2u128), big_p);
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+ assert_eq!(fp.residue(fp.mul(got, x)), 1);
+
+ // Test negation.
+ let got = fp.neg(x);
+ let want = (big_p - big_x) % big_p;
+ assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
+ assert_eq!(fp.residue(fp.add(got, x)), 0);
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/idpf.rs b/third_party/rust/prio/src/idpf.rs
new file mode 100644
index 0000000000..2bb73f2159
--- /dev/null
+++ b/third_party/rust/prio/src/idpf.rs
@@ -0,0 +1,2200 @@
+//! This module implements the incremental distributed point function (IDPF) described in
+//! [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+use crate::{
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{FieldElement, FieldElementExt},
+ vdaf::{
+ xof::{Seed, XofFixedKeyAes128Key},
+ VdafError, VERSION,
+ },
+};
+use bitvec::{
+ bitvec,
+ boxed::BitBox,
+ prelude::{Lsb0, Msb0},
+ slice::BitSlice,
+ vec::BitVec,
+ view::BitView,
+};
+use rand_core::RngCore;
+use std::{
+ collections::{HashMap, VecDeque},
+ fmt::Debug,
+ io::{Cursor, Read},
+ ops::{Add, AddAssign, ControlFlow, Index, Sub},
+};
+use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
+
+/// IDPF-related errors.
+#[derive(Debug, thiserror::Error)]
+pub enum IdpfError {
+ /// Error from incompatible shares at different levels.
+ #[error("tried to merge shares from incompatible levels")]
+ MismatchedLevel,
+
+ /// Invalid parameter, indicates an invalid input to either [`Idpf::gen`] or [`Idpf::eval`].
+ #[error("invalid parameter: {0}")]
+ InvalidParameter(String),
+}
+
+/// An index used as the input to an IDPF evaluation.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct IdpfInput {
+ /// The index as a boxed bit slice.
+ index: BitBox,
+}
+
+impl IdpfInput {
+ /// Convert a slice of bytes into an IDPF input, where the bits of each byte are processed in
+ /// MSB-to-LSB order. (Subsequent bytes are processed in their natural order.)
+ pub fn from_bytes(bytes: &[u8]) -> IdpfInput {
+ let bit_slice_u8_storage = bytes.view_bits::<Msb0>();
+ let mut bit_vec_usize_storage = bitvec![0; bit_slice_u8_storage.len()];
+ bit_vec_usize_storage.clone_from_bitslice(bit_slice_u8_storage);
+ IdpfInput {
+ index: bit_vec_usize_storage.into_boxed_bitslice(),
+ }
+ }
+
+ /// Convert a slice of booleans into an IDPF input.
+ pub fn from_bools(bools: &[bool]) -> IdpfInput {
+ let bits = bools.iter().collect::<BitVec>();
+ IdpfInput {
+ index: bits.into_boxed_bitslice(),
+ }
+ }
+
+ /// Create a new IDPF input by appending to this input.
+ pub fn clone_with_suffix(&self, suffix: &[bool]) -> IdpfInput {
+ let mut vec = BitVec::with_capacity(self.index.len() + suffix.len());
+ vec.extend_from_bitslice(&self.index);
+ vec.extend(suffix);
+ IdpfInput {
+ index: vec.into_boxed_bitslice(),
+ }
+ }
+
+ /// Get the length of the input in bits.
+ pub fn len(&self) -> usize {
+ self.index.len()
+ }
+
+ /// Check if the input is empty, i.e. it does not contain any bits.
+ pub fn is_empty(&self) -> bool {
+ self.index.is_empty()
+ }
+
+ /// Get an iterator over the bits that make up this input.
+ pub fn iter(&self) -> impl DoubleEndedIterator<Item = bool> + '_ {
+ self.index.iter().by_vals()
+ }
+
+ /// Convert the IDPF into a byte slice. If the length of the underlying bit vector is not a
+ /// multiple of `8`, then the least significant bits of the last byte are `0`-padded.
+ pub fn to_bytes(&self) -> Vec<u8> {
+ let mut vec = BitVec::<u8, Msb0>::with_capacity(self.index.len());
+ vec.extend_from_bitslice(&self.index);
+ vec.set_uninitialized(false);
+ vec.into_vec()
+ }
+
+ /// Return the `level`-bit prefix of this IDPF input.
+ pub fn prefix(&self, level: usize) -> Self {
+ Self {
+ index: self.index[..=level].to_owned().into(),
+ }
+ }
+}
+
+impl From<BitVec<usize, Lsb0>> for IdpfInput {
+ fn from(bit_vec: BitVec<usize, Lsb0>) -> Self {
+ IdpfInput {
+ index: bit_vec.into_boxed_bitslice(),
+ }
+ }
+}
+
+impl From<BitBox<usize, Lsb0>> for IdpfInput {
+ fn from(bit_box: BitBox<usize, Lsb0>) -> Self {
+ IdpfInput { index: bit_box }
+ }
+}
+
+impl<I> Index<I> for IdpfInput
+where
+ BitSlice: Index<I>,
+{
+ type Output = <BitSlice as Index<I>>::Output;
+
+ fn index(&self, index: I) -> &Self::Output {
+ &self.index[index]
+ }
+}
+
+/// Trait for values to be programmed into an IDPF.
+///
+/// Values must form an Abelian group, so that they can be secret-shared, and the group operation
+/// must be represented by [`Add`]. Values must be encodable and decodable, without need for a
+/// decoding parameter. Values can be pseudorandomly generated, with a uniform probability
+/// distribution, from XOF output.
+pub trait IdpfValue:
+ Add<Output = Self>
+ + AddAssign
+ + Sub<Output = Self>
+ + ConditionallyNegatable
+ + Encode
+ + Decode
+ + Sized
+{
+ /// Any run-time parameters needed to produce a value.
+ type ValueParameter;
+
+ /// Generate a pseudorandom value from a seed stream.
+ fn generate<S>(seed_stream: &mut S, parameter: &Self::ValueParameter) -> Self
+ where
+ S: RngCore;
+
+ /// Returns the additive identity.
+ fn zero(parameter: &Self::ValueParameter) -> Self;
+
+ /// Conditionally select between two values. Implementations must perform this operation in
+ /// constant time.
+ ///
+ /// This is the same as in [`subtle::ConditionallySelectable`], but without the [`Copy`] bound.
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self;
+}
+
+impl<F> IdpfValue for F
+where
+ F: FieldElement,
+{
+ type ValueParameter = ();
+
+ fn generate<S>(seed_stream: &mut S, _: &()) -> Self
+ where
+ S: RngCore,
+ {
+ // This is analogous to `Prng::get()`, but does not make use of a persistent buffer of
+ // output.
+ let mut buffer = [0u8; 64];
+ assert!(
+ buffer.len() >= F::ENCODED_SIZE,
+ "field is too big for buffer"
+ );
+ loop {
+ seed_stream.fill_bytes(&mut buffer[..F::ENCODED_SIZE]);
+ match F::from_random_rejection(&buffer[..F::ENCODED_SIZE]) {
+ ControlFlow::Break(x) => return x,
+ ControlFlow::Continue(()) => continue,
+ }
+ }
+ }
+
+ fn zero(_: &()) -> Self {
+ <Self as FieldElement>::zero()
+ }
+
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+ <F as ConditionallySelectable>::conditional_select(a, b, choice)
+ }
+}
+
+/// An output from evaluation of an IDPF at some level and index.
+#[derive(Debug, PartialEq, Eq)]
+pub enum IdpfOutputShare<VI, VL> {
+ /// An IDPF output share corresponding to an inner tree node.
+ Inner(VI),
+ /// An IDPF output share corresponding to a leaf tree node.
+ Leaf(VL),
+}
+
+impl<VI, VL> IdpfOutputShare<VI, VL>
+where
+ VI: IdpfValue,
+ VL: IdpfValue,
+{
+ /// Combine two output share values into one.
+ pub fn merge(self, other: Self) -> Result<IdpfOutputShare<VI, VL>, IdpfError> {
+ match (self, other) {
+ (IdpfOutputShare::Inner(mut self_value), IdpfOutputShare::Inner(other_value)) => {
+ self_value += other_value;
+ Ok(IdpfOutputShare::Inner(self_value))
+ }
+ (IdpfOutputShare::Leaf(mut self_value), IdpfOutputShare::Leaf(other_value)) => {
+ self_value += other_value;
+ Ok(IdpfOutputShare::Leaf(self_value))
+ }
+ (_, _) => Err(IdpfError::MismatchedLevel),
+ }
+ }
+}
+
+fn extend(seed: &[u8; 16], xof_fixed_key: &XofFixedKeyAes128Key) -> ([[u8; 16]; 2], [Choice; 2]) {
+ let mut seed_stream = xof_fixed_key.with_seed(seed);
+
+ let mut seeds = [[0u8; 16], [0u8; 16]];
+ seed_stream.fill_bytes(&mut seeds[0]);
+ seed_stream.fill_bytes(&mut seeds[1]);
+
+ let mut byte = [0u8];
+ seed_stream.fill_bytes(&mut byte);
+ let control_bits = [(byte[0] & 1).into(), ((byte[0] >> 1) & 1).into()];
+
+ (seeds, control_bits)
+}
+
+fn convert<V>(
+ seed: &[u8; 16],
+ xof_fixed_key: &XofFixedKeyAes128Key,
+ parameter: &V::ValueParameter,
+) -> ([u8; 16], V)
+where
+ V: IdpfValue,
+{
+ let mut seed_stream = xof_fixed_key.with_seed(seed);
+
+ let mut next_seed = [0u8; 16];
+ seed_stream.fill_bytes(&mut next_seed);
+
+ (next_seed, V::generate(&mut seed_stream, parameter))
+}
+
+/// Helper method to update seeds, update control bits, and output the correction word for one level
+/// of the IDPF key generation process.
+fn generate_correction_word<V>(
+ input_bit: Choice,
+ value: V,
+ parameter: &V::ValueParameter,
+ keys: &mut [[u8; 16]; 2],
+ control_bits: &mut [Choice; 2],
+ extend_xof_fixed_key: &XofFixedKeyAes128Key,
+ convert_xof_fixed_key: &XofFixedKeyAes128Key,
+) -> IdpfCorrectionWord<V>
+where
+ V: IdpfValue,
+{
+ // Expand both keys into two seeds and two control bits each.
+ let (seed_0, control_bits_0) = extend(&keys[0], extend_xof_fixed_key);
+ let (seed_1, control_bits_1) = extend(&keys[1], extend_xof_fixed_key);
+
+ let (keep, lose) = (input_bit, !input_bit);
+
+ let cw_seed = xor_seeds(
+ &conditional_select_seed(lose, &seed_0),
+ &conditional_select_seed(lose, &seed_1),
+ );
+ let cw_control_bits = [
+ control_bits_0[0] ^ control_bits_1[0] ^ input_bit ^ Choice::from(1),
+ control_bits_0[1] ^ control_bits_1[1] ^ input_bit,
+ ];
+ let cw_control_bits_keep =
+ Choice::conditional_select(&cw_control_bits[0], &cw_control_bits[1], keep);
+
+ let previous_control_bits = *control_bits;
+ let control_bits_0_keep =
+ Choice::conditional_select(&control_bits_0[0], &control_bits_0[1], keep);
+ let control_bits_1_keep =
+ Choice::conditional_select(&control_bits_1[0], &control_bits_1[1], keep);
+ control_bits[0] = control_bits_0_keep ^ (cw_control_bits_keep & previous_control_bits[0]);
+ control_bits[1] = control_bits_1_keep ^ (cw_control_bits_keep & previous_control_bits[1]);
+
+ let seed_0_keep = conditional_select_seed(keep, &seed_0);
+ let seed_1_keep = conditional_select_seed(keep, &seed_1);
+ let seeds_corrected = [
+ conditional_xor_seeds(&seed_0_keep, &cw_seed, previous_control_bits[0]),
+ conditional_xor_seeds(&seed_1_keep, &cw_seed, previous_control_bits[1]),
+ ];
+
+ let (new_key_0, elements_0) =
+ convert::<V>(&seeds_corrected[0], convert_xof_fixed_key, parameter);
+ let (new_key_1, elements_1) =
+ convert::<V>(&seeds_corrected[1], convert_xof_fixed_key, parameter);
+
+ keys[0] = new_key_0;
+ keys[1] = new_key_1;
+
+ let mut cw_value = value - elements_0 + elements_1;
+ cw_value.conditional_negate(control_bits[1]);
+
+ IdpfCorrectionWord {
+ seed: cw_seed,
+ control_bits: cw_control_bits,
+ value: cw_value,
+ }
+}
+
+/// Helper function to evaluate one level of an IDPF. This updates the seed and control bit
+/// arguments that are passed in.
+#[allow(clippy::too_many_arguments)]
+fn eval_next<V>(
+ is_leader: bool,
+ parameter: &V::ValueParameter,
+ key: &mut [u8; 16],
+ control_bit: &mut Choice,
+ correction_word: &IdpfCorrectionWord<V>,
+ input_bit: Choice,
+ extend_xof_fixed_key: &XofFixedKeyAes128Key,
+ convert_xof_fixed_key: &XofFixedKeyAes128Key,
+) -> V
+where
+ V: IdpfValue,
+{
+ let (mut seeds, mut control_bits) = extend(key, extend_xof_fixed_key);
+
+ seeds[0] = conditional_xor_seeds(&seeds[0], &correction_word.seed, *control_bit);
+ control_bits[0] ^= correction_word.control_bits[0] & *control_bit;
+ seeds[1] = conditional_xor_seeds(&seeds[1], &correction_word.seed, *control_bit);
+ control_bits[1] ^= correction_word.control_bits[1] & *control_bit;
+
+ let seed_corrected = conditional_select_seed(input_bit, &seeds);
+ *control_bit = Choice::conditional_select(&control_bits[0], &control_bits[1], input_bit);
+
+ let (new_key, elements) = convert::<V>(&seed_corrected, convert_xof_fixed_key, parameter);
+ *key = new_key;
+
+ let mut out =
+ elements + V::conditional_select(&V::zero(parameter), &correction_word.value, *control_bit);
+ out.conditional_negate(Choice::from((!is_leader) as u8));
+ out
+}
+
+/// This defines a family of IDPFs (incremental distributed point functions) with certain types of
+/// values at inner tree nodes and at leaf tree nodes.
+///
+/// IDPF keys can be generated by providing an input and programmed outputs for each tree level to
+/// [`Idpf::gen`].
+pub struct Idpf<VI, VL>
+where
+ VI: IdpfValue,
+ VL: IdpfValue,
+{
+ inner_node_value_parameter: VI::ValueParameter,
+ leaf_node_value_parameter: VL::ValueParameter,
+}
+
+impl<VI, VL> Idpf<VI, VL>
+where
+ VI: IdpfValue,
+ VL: IdpfValue,
+{
+ /// Construct an [`Idpf`] instance with the given run-time parameters needed for inner and leaf
+ /// values.
+ pub fn new(
+ inner_node_value_parameter: VI::ValueParameter,
+ leaf_node_value_parameter: VL::ValueParameter,
+ ) -> Self {
+ Self {
+ inner_node_value_parameter,
+ leaf_node_value_parameter,
+ }
+ }
+
+ pub(crate) fn gen_with_random<M: IntoIterator<Item = VI>>(
+ &self,
+ input: &IdpfInput,
+ inner_values: M,
+ leaf_value: VL,
+ binder: &[u8],
+ random: &[[u8; 16]; 2],
+ ) -> Result<(IdpfPublicShare<VI, VL>, [Seed<16>; 2]), VdafError> {
+ let bits = input.len();
+
+ let initial_keys: [Seed<16>; 2] =
+ [Seed::from_bytes(random[0]), Seed::from_bytes(random[1])];
+
+ let extend_dst = [
+ VERSION, 1, /* algorithm class */
+ 0, 0, 0, 0, /* algorithm ID */
+ 0, 0, /* usage */
+ ];
+ let convert_dst = [
+ VERSION, 1, /* algorithm class */
+ 0, 0, 0, 0, /* algorithm ID */
+ 0, 1, /* usage */
+ ];
+ let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder);
+ let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder);
+
+ let mut keys = [initial_keys[0].0, initial_keys[1].0];
+ let mut control_bits = [Choice::from(0u8), Choice::from(1u8)];
+ let mut inner_correction_words = Vec::with_capacity(bits - 1);
+
+ for (level, value) in inner_values.into_iter().enumerate() {
+ if level >= bits - 1 {
+ return Err(IdpfError::InvalidParameter(
+ "too many values were supplied".to_string(),
+ )
+ .into());
+ }
+ inner_correction_words.push(generate_correction_word::<VI>(
+ Choice::from(input[level] as u8),
+ value,
+ &self.inner_node_value_parameter,
+ &mut keys,
+ &mut control_bits,
+ &extend_xof_fixed_key,
+ &convert_xof_fixed_key,
+ ));
+ }
+ if inner_correction_words.len() != bits - 1 {
+ return Err(
+ IdpfError::InvalidParameter("too few values were supplied".to_string()).into(),
+ );
+ }
+ let leaf_correction_word = generate_correction_word::<VL>(
+ Choice::from(input[bits - 1] as u8),
+ leaf_value,
+ &self.leaf_node_value_parameter,
+ &mut keys,
+ &mut control_bits,
+ &extend_xof_fixed_key,
+ &convert_xof_fixed_key,
+ );
+ let public_share = IdpfPublicShare {
+ inner_correction_words,
+ leaf_correction_word,
+ };
+
+ Ok((public_share, initial_keys))
+ }
+
+ /// The IDPF key generation algorithm.
+ ///
+ /// Generate and return a sequence of IDPF shares for `input`. The parameters `inner_values`
+ /// and `leaf_value` provide the output values for each successive level of the prefix tree.
+ pub fn gen<M>(
+ &self,
+ input: &IdpfInput,
+ inner_values: M,
+ leaf_value: VL,
+ binder: &[u8],
+ ) -> Result<(IdpfPublicShare<VI, VL>, [Seed<16>; 2]), VdafError>
+ where
+ M: IntoIterator<Item = VI>,
+ {
+ if input.is_empty() {
+ return Err(
+ IdpfError::InvalidParameter("invalid number of bits: 0".to_string()).into(),
+ );
+ }
+ let mut random = [[0u8; 16]; 2];
+ for random_seed in random.iter_mut() {
+ getrandom::getrandom(random_seed)?;
+ }
+ self.gen_with_random(input, inner_values, leaf_value, binder, &random)
+ }
+
+ /// Evaluate an IDPF share on `prefix`, starting from a particular tree level with known
+ /// intermediate values.
+ #[allow(clippy::too_many_arguments)]
+ fn eval_from_node(
+ &self,
+ is_leader: bool,
+ public_share: &IdpfPublicShare<VI, VL>,
+ start_level: usize,
+ mut key: [u8; 16],
+ mut control_bit: Choice,
+ prefix: &IdpfInput,
+ binder: &[u8],
+ cache: &mut dyn IdpfCache,
+ ) -> Result<IdpfOutputShare<VI, VL>, IdpfError> {
+ let bits = public_share.inner_correction_words.len() + 1;
+
+ let extend_dst = [
+ VERSION, 1, /* algorithm class */
+ 0, 0, 0, 0, /* algorithm ID */
+ 0, 0, /* usage */
+ ];
+ let convert_dst = [
+ VERSION, 1, /* algorithm class */
+ 0, 0, 0, 0, /* algorithm ID */
+ 0, 1, /* usage */
+ ];
+ let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder);
+ let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder);
+
+ let mut last_inner_output = None;
+ for ((correction_word, input_bit), level) in public_share.inner_correction_words
+ [start_level..]
+ .iter()
+ .zip(prefix[start_level..].iter())
+ .zip(start_level..)
+ {
+ last_inner_output = Some(eval_next(
+ is_leader,
+ &self.inner_node_value_parameter,
+ &mut key,
+ &mut control_bit,
+ correction_word,
+ Choice::from(*input_bit as u8),
+ &extend_xof_fixed_key,
+ &convert_xof_fixed_key,
+ ));
+ let cache_key = &prefix[..=level];
+ cache.insert(cache_key, &(key, control_bit.unwrap_u8()));
+ }
+
+ if prefix.len() == bits {
+ let leaf_output = eval_next(
+ is_leader,
+ &self.leaf_node_value_parameter,
+ &mut key,
+ &mut control_bit,
+ &public_share.leaf_correction_word,
+ Choice::from(prefix[bits - 1] as u8),
+ &extend_xof_fixed_key,
+ &convert_xof_fixed_key,
+ );
+ // Note: there's no point caching this node's key, because we will always run the
+ // eval_next() call for the leaf level.
+ Ok(IdpfOutputShare::Leaf(leaf_output))
+ } else {
+ Ok(IdpfOutputShare::Inner(last_inner_output.unwrap()))
+ }
+ }
+
+ /// The IDPF key evaluation algorithm.
+ ///
+ /// Evaluate an IDPF share on `prefix`.
+ pub fn eval(
+ &self,
+ agg_id: usize,
+ public_share: &IdpfPublicShare<VI, VL>,
+ key: &Seed<16>,
+ prefix: &IdpfInput,
+ binder: &[u8],
+ cache: &mut dyn IdpfCache,
+ ) -> Result<IdpfOutputShare<VI, VL>, IdpfError> {
+ let bits = public_share.inner_correction_words.len() + 1;
+ if agg_id > 1 {
+ return Err(IdpfError::InvalidParameter(format!(
+ "invalid aggregator ID {agg_id}"
+ )));
+ }
+ let is_leader = agg_id == 0;
+ if prefix.is_empty() {
+ return Err(IdpfError::InvalidParameter("empty prefix".to_string()));
+ }
+ if prefix.len() > bits {
+ return Err(IdpfError::InvalidParameter(format!(
+ "prefix length ({}) exceeds configured number of bits ({})",
+ prefix.len(),
+ bits,
+ )));
+ }
+
+ // Check for cached keys first, starting from the end of our desired path down the tree, and
+ // walking back up. If we get a hit, stop there and evaluate the remainder of the tree path
+ // going forward.
+ if prefix.len() > 1 {
+ // Skip checking for `prefix` in the cache, because we don't store field element
+ // values along with keys and control bits. Instead, start looking one node higher
+ // up, so we can recompute everything for the last level of `prefix`.
+ let mut cache_key = &prefix[..prefix.len() - 1];
+ while !cache_key.is_empty() {
+ if let Some((key, control_bit)) = cache.get(cache_key) {
+ // Evaluate the IDPF starting from the cached data at a previously-computed
+ // node, and return the result.
+ return self.eval_from_node(
+ is_leader,
+ public_share,
+ /* start_level */ cache_key.len(),
+ key,
+ Choice::from(control_bit),
+ prefix,
+ binder,
+ cache,
+ );
+ }
+ cache_key = &cache_key[..cache_key.len() - 1];
+ }
+ }
+ // Evaluate starting from the root node.
+ self.eval_from_node(
+ is_leader,
+ public_share,
+ /* start_level */ 0,
+ key.0,
+ /* control_bit */ Choice::from((!is_leader) as u8),
+ prefix,
+ binder,
+ cache,
+ )
+ }
+}
+
+/// An IDPF public share. This contains the list of correction words used by all parties when
+/// evaluating the IDPF.
+#[derive(Debug, Clone)]
+pub struct IdpfPublicShare<VI, VL> {
+ /// Correction words for each inner node level.
+ inner_correction_words: Vec<IdpfCorrectionWord<VI>>,
+ /// Correction word for the leaf node level.
+ leaf_correction_word: IdpfCorrectionWord<VL>,
+}
+
+impl<VI, VL> ConstantTimeEq for IdpfPublicShare<VI, VL>
+where
+ VI: ConstantTimeEq,
+ VL: ConstantTimeEq,
+{
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.inner_correction_words
+ .ct_eq(&other.inner_correction_words)
+ & self.leaf_correction_word.ct_eq(&other.leaf_correction_word)
+ }
+}
+
+impl<VI, VL> PartialEq for IdpfPublicShare<VI, VL>
+where
+ VI: ConstantTimeEq,
+ VL: ConstantTimeEq,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<VI, VL> Eq for IdpfPublicShare<VI, VL>
+where
+ VI: ConstantTimeEq,
+ VL: ConstantTimeEq,
+{
+}
+
+impl<VI, VL> Encode for IdpfPublicShare<VI, VL>
+where
+ VI: Encode,
+ VL: Encode,
+{
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into
+ // bytes in big-endian order. Thus, the first four levels will have their control bits
+ // encoded in the last byte, and the last levels will have their control bits encoded in the
+ // first byte.
+ let mut control_bits: BitVec<u8, Lsb0> =
+ BitVec::with_capacity(self.inner_correction_words.len() * 2 + 2);
+ for correction_words in self.inner_correction_words.iter() {
+ control_bits.extend(correction_words.control_bits.iter().map(|x| bool::from(*x)));
+ }
+ control_bits.extend(
+ self.leaf_correction_word
+ .control_bits
+ .iter()
+ .map(|x| bool::from(*x)),
+ );
+ control_bits.set_uninitialized(false);
+ let mut packed_control = control_bits.into_vec();
+ bytes.append(&mut packed_control);
+
+ for correction_words in self.inner_correction_words.iter() {
+ Seed(correction_words.seed).encode(bytes);
+ correction_words.value.encode(bytes);
+ }
+ Seed(self.leaf_correction_word.seed).encode(bytes);
+ self.leaf_correction_word.value.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ let control_bits_count = (self.inner_correction_words.len() + 1) * 2;
+ let mut len = (control_bits_count + 7) / 8 + (self.inner_correction_words.len() + 1) * 16;
+ for correction_words in self.inner_correction_words.iter() {
+ len += correction_words.value.encoded_len()?;
+ }
+ len += self.leaf_correction_word.value.encoded_len()?;
+ Some(len)
+ }
+}
+
+impl<VI, VL> ParameterizedDecode<usize> for IdpfPublicShare<VI, VL>
+where
+ VI: Decode,
+ VL: Decode,
+{
+ fn decode_with_param(bits: &usize, bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let packed_control_len = (bits + 3) / 4;
+ let mut packed = vec![0u8; packed_control_len];
+ bytes.read_exact(&mut packed)?;
+ let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed);
+
+ let mut inner_correction_words = Vec::with_capacity(bits - 1);
+ for chunk in unpacked_control_bits[0..(bits - 1) * 2].chunks(2) {
+ let control_bits = [(chunk[0] as u8).into(), (chunk[1] as u8).into()];
+ let seed = Seed::decode(bytes)?.0;
+ let value = VI::decode(bytes)?;
+ inner_correction_words.push(IdpfCorrectionWord {
+ seed,
+ control_bits,
+ value,
+ })
+ }
+
+ let control_bits = [
+ (unpacked_control_bits[(bits - 1) * 2] as u8).into(),
+ (unpacked_control_bits[bits * 2 - 1] as u8).into(),
+ ];
+ let seed = Seed::decode(bytes)?.0;
+ let value = VL::decode(bytes)?;
+ let leaf_correction_word = IdpfCorrectionWord {
+ seed,
+ control_bits,
+ value,
+ };
+
+ // Check that unused packed bits are zero.
+ if unpacked_control_bits[bits * 2..].any() {
+ return Err(CodecError::UnexpectedValue);
+ }
+
+ Ok(IdpfPublicShare {
+ inner_correction_words,
+ leaf_correction_word,
+ })
+ }
+}
+
+#[derive(Debug, Clone)]
+struct IdpfCorrectionWord<V> {
+ seed: [u8; 16],
+ control_bits: [Choice; 2],
+ value: V,
+}
+
+impl<V> ConstantTimeEq for IdpfCorrectionWord<V>
+where
+ V: ConstantTimeEq,
+{
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.seed.ct_eq(&other.seed)
+ & self.control_bits.ct_eq(&other.control_bits)
+ & self.value.ct_eq(&other.value)
+ }
+}
+
+impl<V> PartialEq for IdpfCorrectionWord<V>
+where
+ V: ConstantTimeEq,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<V> Eq for IdpfCorrectionWord<V> where V: ConstantTimeEq {}
+
+fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] {
+ let mut seed = [0u8; 16];
+ for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) {
+ *c = a ^ b;
+ }
+ seed
+}
+
+fn and_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] {
+ let mut seed = [0u8; 16];
+ for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) {
+ *c = a & b;
+ }
+ seed
+}
+
+fn or_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] {
+ let mut seed = [0u8; 16];
+ for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) {
+ *c = a | b;
+ }
+ seed
+}
+
+/// Take a control bit, and fan it out into a byte array that can be used as a mask for XOF seeds,
+/// without branching. If the control bit input is 0, all bytes will be equal to 0, and if the
+/// control bit input is 1, all bytes will be equal to 255.
+fn control_bit_to_seed_mask(control: Choice) -> [u8; 16] {
+ let mask = -(control.unwrap_u8() as i8) as u8;
+ [mask; 16]
+}
+
+/// Take two seeds and a control bit, and return the first seed if the control bit is zero, or the
+/// XOR of the two seeds if the control bit is one. This does not branch on the control bit.
+fn conditional_xor_seeds(
+ normal_input: &[u8; 16],
+ switched_input: &[u8; 16],
+ control: Choice,
+) -> [u8; 16] {
+ xor_seeds(
+ normal_input,
+ &and_seeds(switched_input, &control_bit_to_seed_mask(control)),
+ )
+}
+
+/// Returns one of two seeds, depending on the value of a selector bit. Does not branch on the
+/// selector input or make selector-dependent memory accesses.
+fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] {
+ or_seeds(
+ &and_seeds(&control_bit_to_seed_mask(!select), &seeds[0]),
+ &and_seeds(&control_bit_to_seed_mask(select), &seeds[1]),
+ )
+}
+
+/// An interface that provides memoization of IDPF computations.
+///
+/// Each instance of a type implementing `IdpfCache` should only be used with one IDPF key and
+/// public share.
+///
+/// In typical use, IDPFs will be evaluated repeatedly on inputs of increasing length, as part of a
+/// protocol executed by multiple participants. Each IDPF evaluation computes keys and control
+/// bits corresponding to tree nodes along a path determined by the input to the IDPF. Thus, the
+/// values from nodes further up in the tree may be cached and reused in evaluations of subsequent
+/// longer inputs. If one IDPF input is a prefix of another input, then the first input's path down
+/// the tree is a prefix of the other input's path.
+pub trait IdpfCache {
+ /// Fetch cached values for the node identified by the IDPF input.
+ fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)>;
+
+ /// Store values corresponding to the node identified by the IDPF input.
+ fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8));
+}
+
+/// A no-op [`IdpfCache`] implementation that always reports a cache miss.
+#[derive(Default)]
+pub struct NoCache {}
+
+impl NoCache {
+ /// Construct a `NoCache` object.
+ pub fn new() -> NoCache {
+ NoCache::default()
+ }
+}
+
+impl IdpfCache for NoCache {
+ fn get(&self, _: &BitSlice) -> Option<([u8; 16], u8)> {
+ None
+ }
+
+ fn insert(&mut self, _: &BitSlice, _: &([u8; 16], u8)) {}
+}
+
+/// A simple [`IdpfCache`] implementation that caches intermediate results in an in-memory hash map,
+/// with no eviction.
+#[derive(Default)]
+pub struct HashMapCache {
+ map: HashMap<BitBox, ([u8; 16], u8)>,
+}
+
+impl HashMapCache {
+ /// Create a new unpopulated `HashMapCache`.
+ pub fn new() -> HashMapCache {
+ HashMapCache::default()
+ }
+
+ /// Create a new unpopulated `HashMapCache`, with a set pre-allocated capacity.
+ pub fn with_capacity(capacity: usize) -> HashMapCache {
+ Self {
+ map: HashMap::with_capacity(capacity),
+ }
+ }
+}
+
+impl IdpfCache for HashMapCache {
+ fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> {
+ self.map.get(input).cloned()
+ }
+
+ fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) {
+ if !self.map.contains_key(input) {
+ self.map
+ .insert(input.to_owned().into_boxed_bitslice(), *values);
+ }
+ }
+}
+
+/// A simple [`IdpfCache`] implementation that caches intermediate results in memory, with
+/// first-in-first-out eviction, and lookups via linear probing.
+pub struct RingBufferCache {
+ ring: VecDeque<(BitBox, [u8; 16], u8)>,
+}
+
+impl RingBufferCache {
+ /// Create a new unpopulated `RingBufferCache`.
+ pub fn new(capacity: usize) -> RingBufferCache {
+ Self {
+ ring: VecDeque::with_capacity(std::cmp::max(capacity, 1)),
+ }
+ }
+}
+
+impl IdpfCache for RingBufferCache {
+ fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> {
+ // iterate back-to-front, so that we check the most recently pushed entry first.
+ for entry in self.ring.iter().rev() {
+ if input == entry.0 {
+ return Some((entry.1, entry.2));
+ }
+ }
+ None
+ }
+
+ fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) {
+ // evict first (to avoid growing the storage)
+ if self.ring.len() == self.ring.capacity() {
+ self.ring.pop_front();
+ }
+ self.ring
+ .push_back((input.to_owned().into_boxed_bitslice(), values.0, values.1));
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{
+ collections::HashMap,
+ convert::{TryFrom, TryInto},
+ io::Cursor,
+ ops::{Add, AddAssign, Sub},
+ str::FromStr,
+ sync::Mutex,
+ };
+
+ use assert_matches::assert_matches;
+ use bitvec::{
+ bitbox,
+ prelude::{BitBox, Lsb0},
+ slice::BitSlice,
+ vec::BitVec,
+ };
+ use num_bigint::BigUint;
+ use rand::random;
+ use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable};
+
+ use super::{
+ HashMapCache, Idpf, IdpfCache, IdpfCorrectionWord, IdpfInput, IdpfOutputShare,
+ IdpfPublicShare, NoCache, RingBufferCache,
+ };
+ use crate::{
+ codec::{
+ decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode,
+ },
+ field::{Field128, Field255, Field64, FieldElement},
+ prng::Prng,
+ vdaf::{poplar1::Poplar1IdpfValue, xof::Seed},
+ };
+
+ #[test]
+ fn idpf_input_conversion() {
+ let input_1 = IdpfInput::from_bools(&[
+ false, true, false, false, false, false, false, true, false, true, false, false, false,
+ false, true, false,
+ ]);
+ let input_2 = IdpfInput::from_bytes(b"AB");
+ assert_eq!(input_1, input_2);
+ let bits = bitbox![0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0];
+ assert_eq!(input_1[..], bits);
+ }
+
+ /// A lossy IDPF cache, for testing purposes, that randomly returns cache misses.
+ #[derive(Default)]
+ struct LossyCache {
+ map: HashMap<BitBox, ([u8; 16], u8)>,
+ }
+
+ impl LossyCache {
+ /// Create a new unpopulated `LossyCache`.
+ fn new() -> LossyCache {
+ LossyCache::default()
+ }
+ }
+
+ impl IdpfCache for LossyCache {
+ fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> {
+ if random() {
+ self.map.get(input).cloned()
+ } else {
+ None
+ }
+ }
+
+ fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) {
+ if !self.map.contains_key(input) {
+ self.map
+ .insert(input.to_owned().into_boxed_bitslice(), *values);
+ }
+ }
+ }
+
+ /// A wrapper [`IdpfCache`] implementation that records `get()` calls, for testing purposes.
+ struct SnoopingCache<T> {
+ inner: T,
+ get_calls: Mutex<Vec<BitBox>>,
+ insert_calls: Mutex<Vec<(BitBox, [u8; 16], u8)>>,
+ }
+
+ impl<T> SnoopingCache<T> {
+ fn new(inner: T) -> SnoopingCache<T> {
+ SnoopingCache {
+ inner,
+ get_calls: Mutex::new(Vec::new()),
+ insert_calls: Mutex::new(Vec::new()),
+ }
+ }
+ }
+
+ impl<T> IdpfCache for SnoopingCache<T>
+ where
+ T: IdpfCache,
+ {
+ fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> {
+ self.get_calls
+ .lock()
+ .unwrap()
+ .push(input.to_owned().into_boxed_bitslice());
+ self.inner.get(input)
+ }
+
+ fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) {
+ self.insert_calls.lock().unwrap().push((
+ input.to_owned().into_boxed_bitslice(),
+ values.0,
+ values.1,
+ ));
+ self.inner.insert(input, values)
+ }
+ }
+
+ #[test]
+ fn test_idpf_poplar() {
+ let input = bitbox![0, 1, 1, 0, 1].into();
+ let nonce: [u8; 16] = random();
+ let idpf = Idpf::new((), ());
+ let (public_share, keys) = idpf
+ .gen(
+ &input,
+ Vec::from([Poplar1IdpfValue::new([Field64::one(), Field64::one()]); 4]),
+ Poplar1IdpfValue::new([Field255::one(), Field255::one()]),
+ &nonce,
+ )
+ .unwrap();
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![1].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![1, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![1, 1].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1, 1].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1, 1, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1, 1, 0, 1].into(),
+ &nonce,
+ &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::one(), Field255::one()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1, 1, 0, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![1, 0, 1, 0, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])),
+ &mut NoCache::new(),
+ &mut NoCache::new(),
+ );
+ }
+
+ fn check_idpf_poplar_evaluation(
+ public_share: &IdpfPublicShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>,
+ keys: &[Seed<16>; 2],
+ prefix: &IdpfInput,
+ binder: &[u8],
+ expected_output: &IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>,
+ cache_0: &mut dyn IdpfCache,
+ cache_1: &mut dyn IdpfCache,
+ ) {
+ let idpf = Idpf::new((), ());
+ let share_0 = idpf
+ .eval(0, public_share, &keys[0], prefix, binder, cache_0)
+ .unwrap();
+ let share_1 = idpf
+ .eval(1, public_share, &keys[1], prefix, binder, cache_1)
+ .unwrap();
+ let output = share_0.merge(share_1).unwrap();
+ assert_eq!(&output, expected_output);
+ }
+
+ #[test]
+ fn test_idpf_poplar_medium() {
+ // This test on 40 byte inputs takes about a second in debug mode. (and ten milliseconds in
+ // release mode)
+ const INPUT_LEN: usize = 320;
+ let mut bits = bitbox![0; INPUT_LEN];
+ for mut bit in bits.iter_mut() {
+ bit.set(random());
+ }
+ let input = bits.clone().into();
+
+ let mut inner_values = Vec::with_capacity(INPUT_LEN - 1);
+ let mut prng = Prng::new().unwrap();
+ for _ in 0..INPUT_LEN - 1 {
+ inner_values.push(Poplar1IdpfValue::new([
+ Field64::one(),
+ prng.next().unwrap(),
+ ]));
+ }
+ let leaf_values =
+ Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]);
+
+ let nonce: [u8; 16] = random();
+ let idpf = Idpf::new((), ());
+ let (public_share, keys) = idpf
+ .gen(&input, inner_values.clone(), leaf_values, &nonce)
+ .unwrap();
+ let mut cache_0 = RingBufferCache::new(3);
+ let mut cache_1 = RingBufferCache::new(3);
+
+ for (level, values) in inner_values.iter().enumerate() {
+ let mut prefix = BitBox::from_bitslice(&bits[..=level]).into();
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &prefix,
+ &nonce,
+ &IdpfOutputShare::Inner(*values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ let flipped_bit = !prefix[level];
+ prefix.index.set(level, flipped_bit);
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &prefix,
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ }
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &input,
+ &nonce,
+ &IdpfOutputShare::Leaf(leaf_values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ let mut modified_bits = bits.clone();
+ modified_bits.set(INPUT_LEN - 1, !bits[INPUT_LEN - 1]);
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &modified_bits.into(),
+ &nonce,
+ &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ }
+
+ #[test]
+ fn idpf_poplar_cache_behavior() {
+ let bits = bitbox![0, 1, 1, 1, 0, 1, 0, 0];
+ let input = bits.into();
+
+ let mut inner_values = Vec::with_capacity(7);
+ let mut prng = Prng::new().unwrap();
+ for _ in 0..7 {
+ inner_values.push(Poplar1IdpfValue::new([
+ Field64::one(),
+ prng.next().unwrap(),
+ ]));
+ }
+ let leaf_values =
+ Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]);
+
+ let nonce: [u8; 16] = random();
+ let idpf = Idpf::new((), ());
+ let (public_share, keys) = idpf
+ .gen(&input, inner_values.clone(), leaf_values, &nonce)
+ .unwrap();
+ let mut cache_0 = SnoopingCache::new(HashMapCache::new());
+ let mut cache_1 = HashMapCache::new();
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![1, 1, 0, 0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ assert_eq!(
+ cache_0
+ .get_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .collect::<Vec<_>>(),
+ vec![bitbox![1, 1, 0], bitbox![1, 1], bitbox![1]],
+ );
+ assert_eq!(
+ cache_0
+ .insert_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .map(|(input, _, _)| input)
+ .collect::<Vec<_>>(),
+ vec![
+ bitbox![1],
+ bitbox![1, 1],
+ bitbox![1, 1, 0],
+ bitbox![1, 1, 0, 0]
+ ],
+ );
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(inner_values[0]),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ assert_eq!(
+ cache_0
+ .get_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .collect::<Vec<BitBox>>(),
+ Vec::<BitBox>::new(),
+ );
+ assert_eq!(
+ cache_0
+ .insert_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .map(|(input, _, _)| input)
+ .collect::<Vec<_>>(),
+ vec![bitbox![0]],
+ );
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &bitbox![0, 1].into(),
+ &nonce,
+ &IdpfOutputShare::Inner(inner_values[1]),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ assert_eq!(
+ cache_0
+ .get_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .collect::<Vec<_>>(),
+ vec![bitbox![0]],
+ );
+ assert_eq!(
+ cache_0
+ .insert_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .map(|(input, _, _)| input)
+ .collect::<Vec<_>>(),
+ vec![bitbox![0, 1]],
+ );
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &input,
+ &nonce,
+ &IdpfOutputShare::Leaf(leaf_values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ assert_eq!(
+ cache_0
+ .get_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .collect::<Vec<_>>(),
+ vec![
+ bitbox![0, 1, 1, 1, 0, 1, 0],
+ bitbox![0, 1, 1, 1, 0, 1],
+ bitbox![0, 1, 1, 1, 0],
+ bitbox![0, 1, 1, 1],
+ bitbox![0, 1, 1],
+ bitbox![0, 1],
+ ],
+ );
+ assert_eq!(
+ cache_0
+ .insert_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .map(|(input, _, _)| input)
+ .collect::<Vec<_>>(),
+ vec![
+ bitbox![0, 1, 1],
+ bitbox![0, 1, 1, 1],
+ bitbox![0, 1, 1, 1, 0],
+ bitbox![0, 1, 1, 1, 0, 1],
+ bitbox![0, 1, 1, 1, 0, 1, 0],
+ ],
+ );
+
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &input,
+ &nonce,
+ &IdpfOutputShare::Leaf(leaf_values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ assert_eq!(
+ cache_0
+ .get_calls
+ .lock()
+ .unwrap()
+ .drain(..)
+ .collect::<Vec<_>>(),
+ vec![bitbox![0, 1, 1, 1, 0, 1, 0]],
+ );
+ assert!(cache_0.insert_calls.lock().unwrap().is_empty());
+ }
+
+ #[test]
+ fn idpf_poplar_lossy_cache() {
+ let bits = bitbox![1, 0, 0, 1, 1, 0, 1, 0];
+ let input = bits.into();
+
+ let mut inner_values = Vec::with_capacity(7);
+ let mut prng = Prng::new().unwrap();
+ for _ in 0..7 {
+ inner_values.push(Poplar1IdpfValue::new([
+ Field64::one(),
+ prng.next().unwrap(),
+ ]));
+ }
+ let leaf_values =
+ Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]);
+
+ let nonce: [u8; 16] = random();
+ let idpf = Idpf::new((), ());
+ let (public_share, keys) = idpf
+ .gen(&input, inner_values.clone(), leaf_values, &nonce)
+ .unwrap();
+ let mut cache_0 = LossyCache::new();
+ let mut cache_1 = LossyCache::new();
+
+ for (level, values) in inner_values.iter().enumerate() {
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &input[..=level].to_owned().into(),
+ &nonce,
+ &IdpfOutputShare::Inner(*values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ }
+ check_idpf_poplar_evaluation(
+ &public_share,
+ &keys,
+ &input,
+ &nonce,
+ &IdpfOutputShare::Leaf(leaf_values),
+ &mut cache_0,
+ &mut cache_1,
+ );
+ }
+
+ #[test]
+ fn test_idpf_poplar_error_cases() {
+ let nonce: [u8; 16] = random();
+ let idpf = Idpf::new((), ());
+ // Zero bits does not make sense.
+ idpf.gen(
+ &bitbox![].into(),
+ Vec::<Poplar1IdpfValue<Field64>>::new(),
+ Poplar1IdpfValue::new([Field255::zero(); 2]),
+ &nonce,
+ )
+ .unwrap_err();
+
+ let (public_share, keys) = idpf
+ .gen(
+ &bitbox![0;10].into(),
+ Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 9]),
+ Poplar1IdpfValue::new([Field255::zero(); 2]),
+ &nonce,
+ )
+ .unwrap();
+
+ // Wrong number of values.
+ idpf.gen(
+ &bitbox![0; 10].into(),
+ Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 8]),
+ Poplar1IdpfValue::new([Field255::zero(); 2]),
+ &nonce,
+ )
+ .unwrap_err();
+ idpf.gen(
+ &bitbox![0; 10].into(),
+ Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 10]),
+ Poplar1IdpfValue::new([Field255::zero(); 2]),
+ &nonce,
+ )
+ .unwrap_err();
+
+ // Evaluating with empty prefix.
+ assert!(idpf
+ .eval(
+ 0,
+ &public_share,
+ &keys[0],
+ &bitbox![].into(),
+ &nonce,
+ &mut NoCache::new(),
+ )
+ .is_err());
+ // Evaluating with too-long prefix.
+ assert!(idpf
+ .eval(
+ 0,
+ &public_share,
+ &keys[0],
+ &bitbox![0; 11].into(),
+ &nonce,
+ &mut NoCache::new(),
+ )
+ .is_err());
+ }
+
+ #[test]
+ fn idpf_poplar_public_share_round_trip() {
+ let public_share = IdpfPublicShare {
+ inner_correction_words: Vec::from([
+ IdpfCorrectionWord {
+ seed: [0xab; 16],
+ control_bits: [Choice::from(1), Choice::from(0)],
+ value: Poplar1IdpfValue::new([
+ Field64::try_from(83261u64).unwrap(),
+ Field64::try_from(125159u64).unwrap(),
+ ]),
+ },
+ IdpfCorrectionWord{
+ seed: [0xcd;16],
+ control_bits: [Choice::from(0), Choice::from(1)],
+ value: Poplar1IdpfValue::new([
+ Field64::try_from(17614120u64).unwrap(),
+ Field64::try_from(20674u64).unwrap(),
+ ]),
+ },
+ ]),
+ leaf_correction_word: IdpfCorrectionWord {
+ seed: [0xff; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([
+ Field255::one(),
+ Field255::get_decoded(
+ b"\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12", // field element correction word, continued
+ ).unwrap(),
+ ]),
+ },
+ };
+ let message = hex::decode(concat!(
+ "39", // packed control bit correction words (0b00111001)
+ "abababababababababababababababab", // seed correction word, first level
+ "3d45010000000000", // field element correction word
+ "e7e8010000000000", // field element correction word, continued
+ "cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd", // seed correction word, second level
+ "28c50c0100000000", // field element correction word
+ "c250000000000000", // field element correction word, continued
+ "ffffffffffffffffffffffffffffffff", // seed correction word, third level
+ "0100000000000000000000000000000000000000000000000000000000000000", // field element correction word, leaf field
+ "f0debc9a78563412f0debc9a78563412f0debc9a78563412f0debc9a78563412", // field element correction word, continued
+ ))
+ .unwrap();
+ let encoded = public_share.get_encoded();
+ let decoded = IdpfPublicShare::get_decoded_with_param(&3, &message).unwrap();
+ assert_eq!(public_share, decoded);
+ assert_eq!(message, encoded);
+ assert_eq!(public_share.encoded_len().unwrap(), encoded.len());
+
+ // check serialization of packed control bits when they span multiple bytes:
+ let public_share = IdpfPublicShare {
+ inner_correction_words: Vec::from([
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(0)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(0), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(1), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]),
+ },
+ ]),
+ leaf_correction_word: IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(0), Choice::from(1)],
+ value: Poplar1IdpfValue::new([Field255::zero(), Field255::zero()]),
+ },
+ };
+ let message = hex::decode(concat!(
+ "dffb02", // packed correction word control bits: 0b11011111, 0b11111011, 0b10
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000",
+ "0000000000000000",
+ "00000000000000000000000000000000",
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ ))
+ .unwrap();
+ let encoded = public_share.get_encoded();
+ let decoded = IdpfPublicShare::get_decoded_with_param(&9, &message).unwrap();
+ assert_eq!(public_share, decoded);
+ assert_eq!(message, encoded);
+ }
+
+ #[test]
+ fn idpf_poplar_public_share_control_bit_codec() {
+ let test_cases = [
+ (&[false, true][..], &[0b10][..]),
+ (
+ &[false, false, true, false, false, true][..],
+ &[0b10_0100u8][..],
+ ),
+ (
+ &[
+ true, true, false, true, false, false, false, false, true, true,
+ ][..],
+ &[0b0000_1011, 0b11][..],
+ ),
+ (
+ &[
+ true, true, false, true, false, true, true, true, false, true, false, true,
+ false, false, true, false,
+ ][..],
+ &[0b1110_1011, 0b0100_1010][..],
+ ),
+ (
+ &[
+ true, true, true, true, true, false, true, true, false, true, true, true,
+ false, true, false, true, false, false, true, false, true, true,
+ ][..],
+ &[0b1101_1111, 0b1010_1110, 0b11_0100][..],
+ ),
+ ];
+
+ for (control_bits, serialized_control_bits) in test_cases {
+ let public_share = IdpfPublicShare::<
+ Poplar1IdpfValue<Field64>,
+ Poplar1IdpfValue<Field255>,
+ > {
+ inner_correction_words: control_bits[..control_bits.len() - 2]
+ .chunks(2)
+ .map(|chunk| IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [Choice::from(chunk[0] as u8), Choice::from(chunk[1] as u8)],
+ value: Poplar1IdpfValue::new([Field64::zero(); 2]),
+ })
+ .collect(),
+ leaf_correction_word: IdpfCorrectionWord {
+ seed: [0; 16],
+ control_bits: [
+ Choice::from(control_bits[control_bits.len() - 2] as u8),
+ Choice::from(control_bits[control_bits.len() - 1] as u8),
+ ],
+ value: Poplar1IdpfValue::new([Field255::zero(); 2]),
+ },
+ };
+
+ let mut serialized_public_share = serialized_control_bits.to_owned();
+ let idpf_bits = control_bits.len() / 2;
+ let size_seeds = 16 * idpf_bits;
+ let size_field_vecs =
+ Field64::ENCODED_SIZE * 2 * (idpf_bits - 1) + Field255::ENCODED_SIZE * 2;
+ serialized_public_share.resize(
+ serialized_control_bits.len() + size_seeds + size_field_vecs,
+ 0,
+ );
+
+ assert_eq!(public_share.get_encoded(), serialized_public_share);
+ assert_eq!(
+ IdpfPublicShare::get_decoded_with_param(&idpf_bits, &serialized_public_share)
+ .unwrap(),
+ public_share
+ );
+ }
+ }
+
+ #[test]
+ fn idpf_poplar_public_share_unused_bits() {
+ let mut buf = vec![0u8; 4096];
+
+ buf[0] = 1 << 2;
+ let err =
+ IdpfPublicShare::<Field64, Field255>::decode_with_param(&1, &mut Cursor::new(&buf))
+ .unwrap_err();
+ assert_matches!(err, CodecError::UnexpectedValue);
+
+ buf[0] = 1 << 4;
+ let err =
+ IdpfPublicShare::<Field64, Field255>::decode_with_param(&2, &mut Cursor::new(&buf))
+ .unwrap_err();
+ assert_matches!(err, CodecError::UnexpectedValue);
+
+ buf[0] = 1 << 6;
+ let err =
+ IdpfPublicShare::<Field64, Field255>::decode_with_param(&3, &mut Cursor::new(&buf))
+ .unwrap_err();
+ assert_matches!(err, CodecError::UnexpectedValue);
+
+ buf[0] = 0;
+ buf[1] = 1 << 2;
+ let err =
+ IdpfPublicShare::<Field64, Field255>::decode_with_param(&5, &mut Cursor::new(&buf))
+ .unwrap_err();
+ assert_matches!(err, CodecError::UnexpectedValue);
+ }
+
+ /// Stores a test vector for the IDPF key generation algorithm.
+ struct IdpfTestVector {
+ /// The number of bits in IDPF inputs.
+ bits: usize,
+ /// The binder string used when generating and evaluating keys.
+ binder: Vec<u8>,
+ /// The IDPF input provided to the key generation algorithm.
+ alpha: IdpfInput,
+ /// The IDPF output values, at each inner level, provided to the key generation algorithm.
+ beta_inner: Vec<Poplar1IdpfValue<Field64>>,
+ /// The IDPF output values for the leaf level, provided to the key generation algorithm.
+ beta_leaf: Poplar1IdpfValue<Field255>,
+ /// The two keys returned by the key generation algorithm.
+ keys: [[u8; 16]; 2],
+ /// The public share returned by the key generation algorithm.
+ public_share: Vec<u8>,
+ }
+
+ /// Load a test vector for Idpf key generation.
+ fn load_idpfpoplar_test_vector() -> IdpfTestVector {
+ let test_vec: serde_json::Value =
+ serde_json::from_str(include_str!("vdaf/test_vec/07/IdpfPoplar_0.json")).unwrap();
+ let test_vec_obj = test_vec.as_object().unwrap();
+
+ let bits = test_vec_obj
+ .get("bits")
+ .unwrap()
+ .as_u64()
+ .unwrap()
+ .try_into()
+ .unwrap();
+
+ let alpha_str = test_vec_obj.get("alpha").unwrap().as_str().unwrap();
+ let alpha_bignum = BigUint::from_str(alpha_str).unwrap();
+ let zero_bignum = BigUint::from(0u8);
+ let one_bignum = BigUint::from(1u8);
+ let alpha_bits = (0..bits)
+ .map(|level| (&alpha_bignum >> (bits - level - 1)) & &one_bignum != zero_bignum)
+ .collect::<BitVec>();
+ let alpha = alpha_bits.into();
+
+ let beta_inner_level_array = test_vec_obj.get("beta_inner").unwrap().as_array().unwrap();
+ let beta_inner = beta_inner_level_array
+ .iter()
+ .map(|array| {
+ Poplar1IdpfValue::new([
+ Field64::from(array[0].as_str().unwrap().parse::<u64>().unwrap()),
+ Field64::from(array[1].as_str().unwrap().parse::<u64>().unwrap()),
+ ])
+ })
+ .collect::<Vec<_>>();
+
+ let beta_leaf_array = test_vec_obj.get("beta_leaf").unwrap().as_array().unwrap();
+ let beta_leaf = Poplar1IdpfValue::new([
+ Field255::from(
+ beta_leaf_array[0]
+ .as_str()
+ .unwrap()
+ .parse::<BigUint>()
+ .unwrap(),
+ ),
+ Field255::from(
+ beta_leaf_array[1]
+ .as_str()
+ .unwrap()
+ .parse::<BigUint>()
+ .unwrap(),
+ ),
+ ]);
+
+ let keys_array = test_vec_obj.get("keys").unwrap().as_array().unwrap();
+ let keys = [
+ hex::decode(keys_array[0].as_str().unwrap())
+ .unwrap()
+ .try_into()
+ .unwrap(),
+ hex::decode(keys_array[1].as_str().unwrap())
+ .unwrap()
+ .try_into()
+ .unwrap(),
+ ];
+
+ let public_share_hex = test_vec_obj.get("public_share").unwrap();
+ let public_share = hex::decode(public_share_hex.as_str().unwrap()).unwrap();
+
+ let binder_hex = test_vec_obj.get("binder").unwrap();
+ let binder = hex::decode(binder_hex.as_str().unwrap()).unwrap();
+
+ IdpfTestVector {
+ bits,
+ binder,
+ alpha,
+ beta_inner,
+ beta_leaf,
+ keys,
+ public_share,
+ }
+ }
+
+ #[test]
+ fn idpf_poplar_generate_test_vector() {
+ let test_vector = load_idpfpoplar_test_vector();
+ let idpf = Idpf::new((), ());
+ let (public_share, keys) = idpf
+ .gen_with_random(
+ &test_vector.alpha,
+ test_vector.beta_inner,
+ test_vector.beta_leaf,
+ &test_vector.binder,
+ &test_vector.keys,
+ )
+ .unwrap();
+
+ assert_eq!(keys[0].0, test_vector.keys[0]);
+ assert_eq!(keys[1].0, test_vector.keys[1]);
+
+ let expected_public_share =
+ IdpfPublicShare::get_decoded_with_param(&test_vector.bits, &test_vector.public_share)
+ .unwrap();
+ for (level, (correction_words, expected_correction_words)) in public_share
+ .inner_correction_words
+ .iter()
+ .zip(expected_public_share.inner_correction_words.iter())
+ .enumerate()
+ {
+ assert_eq!(
+ correction_words, expected_correction_words,
+ "layer {level} did not match\n{correction_words:#x?}\n{expected_correction_words:#x?}"
+ );
+ }
+ assert_eq!(
+ public_share.leaf_correction_word,
+ expected_public_share.leaf_correction_word
+ );
+
+ assert_eq!(
+ public_share, expected_public_share,
+ "public share did not match\n{public_share:#x?}\n{expected_public_share:#x?}"
+ );
+ let encoded_public_share = public_share.get_encoded();
+ assert_eq!(encoded_public_share, test_vector.public_share);
+ }
+
+ #[test]
+ fn idpf_input_from_bytes_to_bytes() {
+ let test_cases: &[&[u8]] = &[b"hello", b"banana", &[1], &[127], &[1, 2, 3, 4], &[]];
+ for test_case in test_cases {
+ assert_eq!(&IdpfInput::from_bytes(test_case).to_bytes(), test_case);
+ }
+ }
+
+ #[test]
+ fn idpf_input_from_bools_to_bytes() {
+ let input = IdpfInput::from_bools(&[true; 7]);
+ assert_eq!(input.to_bytes(), &[254]);
+ let input = IdpfInput::from_bools(&[true; 9]);
+ assert_eq!(input.to_bytes(), &[255, 128]);
+ }
+
+ /// Demonstrate use of an IDPF with values that need run-time parameters for random generation.
+ #[test]
+ fn idpf_with_value_parameters() {
+ use super::IdpfValue;
+
+ /// A test-only type for use as an [`IdpfValue`].
+ #[derive(Debug, Clone, Copy)]
+ struct MyUnit;
+
+ impl IdpfValue for MyUnit {
+ type ValueParameter = ();
+
+ fn generate<S>(_: &mut S, _: &Self::ValueParameter) -> Self
+ where
+ S: rand_core::RngCore,
+ {
+ MyUnit
+ }
+
+ fn zero(_: &()) -> Self {
+ MyUnit
+ }
+
+ fn conditional_select(_: &Self, _: &Self, _: Choice) -> Self {
+ MyUnit
+ }
+ }
+
+ impl Encode for MyUnit {
+ fn encode(&self, _: &mut Vec<u8>) {}
+ }
+
+ impl Decode for MyUnit {
+ fn decode(_: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(MyUnit)
+ }
+ }
+
+ impl ConditionallySelectable for MyUnit {
+ fn conditional_select(_: &Self, _: &Self, _: Choice) -> Self {
+ MyUnit
+ }
+ }
+
+ impl ConditionallyNegatable for MyUnit {
+ fn conditional_negate(&mut self, _: Choice) {}
+ }
+
+ impl Add for MyUnit {
+ type Output = Self;
+
+ fn add(self, _: Self) -> Self::Output {
+ MyUnit
+ }
+ }
+
+ impl AddAssign for MyUnit {
+ fn add_assign(&mut self, _: Self) {}
+ }
+
+ impl Sub for MyUnit {
+ type Output = Self;
+
+ fn sub(self, _: Self) -> Self::Output {
+ MyUnit
+ }
+ }
+
+ /// A test-only type for use as an [`IdpfValue`], representing a variable-length vector of
+ /// field elements. The length must be fixed before generating IDPF keys, but we assume it
+ /// is not known at compile time.
+ #[derive(Debug, Clone)]
+ struct MyVector(Vec<Field128>);
+
+ impl IdpfValue for MyVector {
+ type ValueParameter = usize;
+
+ fn generate<S>(seed_stream: &mut S, length: &Self::ValueParameter) -> Self
+ where
+ S: rand_core::RngCore,
+ {
+ let mut output = vec![<Field128 as FieldElement>::zero(); *length];
+ for element in output.iter_mut() {
+ *element = <Field128 as IdpfValue>::generate(seed_stream, &());
+ }
+ MyVector(output)
+ }
+
+ fn zero(length: &usize) -> Self {
+ MyVector(vec![<Field128 as FieldElement>::zero(); *length])
+ }
+
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+ debug_assert_eq!(a.0.len(), b.0.len());
+ let mut output = vec![<Field128 as FieldElement>::zero(); a.0.len()];
+ for ((a_elem, b_elem), output_elem) in
+ a.0.iter().zip(b.0.iter()).zip(output.iter_mut())
+ {
+ *output_elem = <Field128 as ConditionallySelectable>::conditional_select(
+ a_elem, b_elem, choice,
+ );
+ }
+ MyVector(output)
+ }
+ }
+
+ impl Encode for MyVector {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ encode_u32_items(bytes, &(), &self.0);
+ }
+ }
+
+ impl Decode for MyVector {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ decode_u32_items(&(), bytes).map(MyVector)
+ }
+ }
+
+ impl ConditionallyNegatable for MyVector {
+ fn conditional_negate(&mut self, choice: Choice) {
+ for element in self.0.iter_mut() {
+ element.conditional_negate(choice);
+ }
+ }
+ }
+
+ impl Add for MyVector {
+ type Output = Self;
+
+ fn add(self, rhs: Self) -> Self::Output {
+ debug_assert_eq!(self.0.len(), rhs.0.len());
+ let mut output = vec![<Field128 as FieldElement>::zero(); self.0.len()];
+ for ((left_elem, right_elem), output_elem) in
+ self.0.iter().zip(rhs.0.iter()).zip(output.iter_mut())
+ {
+ *output_elem = left_elem + right_elem;
+ }
+ MyVector(output)
+ }
+ }
+
+ impl AddAssign for MyVector {
+ fn add_assign(&mut self, rhs: Self) {
+ debug_assert_eq!(self.0.len(), rhs.0.len());
+ for (self_elem, right_elem) in self.0.iter_mut().zip(rhs.0.iter()) {
+ *self_elem += *right_elem;
+ }
+ }
+ }
+
+ impl Sub for MyVector {
+ type Output = Self;
+
+ fn sub(self, rhs: Self) -> Self::Output {
+ debug_assert_eq!(self.0.len(), rhs.0.len());
+ let mut output = vec![<Field128 as FieldElement>::zero(); self.0.len()];
+ for ((left_elem, right_elem), output_elem) in
+ self.0.iter().zip(rhs.0.iter()).zip(output.iter_mut())
+ {
+ *output_elem = left_elem - right_elem;
+ }
+ MyVector(output)
+ }
+ }
+
+ // Use a unit type for inner nodes, thus emulating a DPF. Use a newtype around a `Vec` for
+ // the leaf nodes, to test out values that require runtime parameters.
+ let idpf = Idpf::new((), 3);
+ let binder = b"binder";
+ let (public_share, [key_0, key_1]) = idpf
+ .gen(
+ &IdpfInput::from_bytes(b"ae"),
+ [MyUnit; 15],
+ MyVector(Vec::from([
+ Field128::from(1),
+ Field128::from(2),
+ Field128::from(3),
+ ])),
+ binder,
+ )
+ .unwrap();
+
+ let zero_share_0 = idpf
+ .eval(
+ 0,
+ &public_share,
+ &key_0,
+ &IdpfInput::from_bytes(b"ou"),
+ binder,
+ &mut NoCache::new(),
+ )
+ .unwrap();
+ let zero_share_1 = idpf
+ .eval(
+ 1,
+ &public_share,
+ &key_1,
+ &IdpfInput::from_bytes(b"ou"),
+ binder,
+ &mut NoCache::new(),
+ )
+ .unwrap();
+ let zero_output = zero_share_0.merge(zero_share_1).unwrap();
+ assert_matches!(zero_output, IdpfOutputShare::Leaf(value) => {
+ assert_eq!(value.0.len(), 3);
+ assert_eq!(value.0[0], <Field128 as FieldElement>::zero());
+ assert_eq!(value.0[1], <Field128 as FieldElement>::zero());
+ assert_eq!(value.0[2], <Field128 as FieldElement>::zero());
+ });
+
+ let programmed_share_0 = idpf
+ .eval(
+ 0,
+ &public_share,
+ &key_0,
+ &IdpfInput::from_bytes(b"ae"),
+ binder,
+ &mut NoCache::new(),
+ )
+ .unwrap();
+ let programmed_share_1 = idpf
+ .eval(
+ 1,
+ &public_share,
+ &key_1,
+ &IdpfInput::from_bytes(b"ae"),
+ binder,
+ &mut NoCache::new(),
+ )
+ .unwrap();
+ let programmed_output = programmed_share_0.merge(programmed_share_1).unwrap();
+ assert_matches!(programmed_output, IdpfOutputShare::Leaf(value) => {
+ assert_eq!(value.0.len(), 3);
+ assert_eq!(value.0[0], Field128::from(1));
+ assert_eq!(value.0[1], Field128::from(2));
+ assert_eq!(value.0[2], Field128::from(3));
+ });
+ }
+}
diff --git a/third_party/rust/prio/src/lib.rs b/third_party/rust/prio/src/lib.rs
new file mode 100644
index 0000000000..c9d4e22c49
--- /dev/null
+++ b/third_party/rust/prio/src/lib.rs
@@ -0,0 +1,34 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+#![warn(missing_docs)]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+
+//! # libprio-rs
+//!
+//! Implementation of the [Prio](https://crypto.stanford.edu/prio/) private data aggregation
+//! protocol.
+//!
+//! Prio3 is available in the `vdaf` module as part of an implementation of [Verifiable Distributed
+//! Aggregation Functions][vdaf], along with an experimental implementation of Poplar1.
+//!
+//! [vdaf]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/05/
+
+pub mod benchmarked;
+pub mod codec;
+#[cfg(feature = "experimental")]
+pub mod dp;
+mod fft;
+pub mod field;
+pub mod flp;
+mod fp;
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
+)]
+pub mod idpf;
+mod polynomial;
+mod prng;
+pub mod topology;
+pub mod vdaf;
diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs
new file mode 100644
index 0000000000..89d8a91404
--- /dev/null
+++ b/third_party/rust/prio/src/polynomial.rs
@@ -0,0 +1,383 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Functions for polynomial interpolation and evaluation
+
+#[cfg(feature = "prio2")]
+use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
+use crate::field::FftFriendlyFieldElement;
+
+use std::convert::TryFrom;
+
+/// Temporary memory used for FFT
+#[derive(Clone, Debug)]
+pub struct PolyFFTTempMemory<F> {
+ fft_tmp: Vec<F>,
+ fft_y_sub: Vec<F>,
+ fft_roots_sub: Vec<F>,
+}
+
+impl<F: FftFriendlyFieldElement> PolyFFTTempMemory<F> {
+ fn new(length: usize) -> Self {
+ PolyFFTTempMemory {
+ fft_tmp: vec![F::zero(); length],
+ fft_y_sub: vec![F::zero(); length],
+ fft_roots_sub: vec![F::zero(); length],
+ }
+ }
+}
+
+/// Auxiliary memory for polynomial interpolation and evaluation
+#[derive(Clone, Debug)]
+pub struct PolyAuxMemory<F> {
+ pub roots_2n: Vec<F>,
+ pub roots_2n_inverted: Vec<F>,
+ pub roots_n: Vec<F>,
+ pub roots_n_inverted: Vec<F>,
+ pub coeffs: Vec<F>,
+ pub fft_memory: PolyFFTTempMemory<F>,
+}
+
+impl<F: FftFriendlyFieldElement> PolyAuxMemory<F> {
+ pub fn new(n: usize) -> Self {
+ PolyAuxMemory {
+ roots_2n: fft_get_roots(2 * n, false),
+ roots_2n_inverted: fft_get_roots(2 * n, true),
+ roots_n: fft_get_roots(n, false),
+ roots_n_inverted: fft_get_roots(n, true),
+ coeffs: vec![F::zero(); 2 * n],
+ fft_memory: PolyFFTTempMemory::new(2 * n),
+ }
+ }
+}
+
+fn fft_recurse<F: FftFriendlyFieldElement>(
+ out: &mut [F],
+ n: usize,
+ roots: &[F],
+ ys: &[F],
+ tmp: &mut [F],
+ y_sub: &mut [F],
+ roots_sub: &mut [F],
+) {
+ if n == 1 {
+ out[0] = ys[0];
+ return;
+ }
+
+ let half_n = n / 2;
+
+ let (tmp_first, tmp_second) = tmp.split_at_mut(half_n);
+ let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n);
+ let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n);
+
+ // Recurse on the first half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] + ys[i + half_n];
+ roots_sub_first[i] = roots[2 * i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i] = tmp_first[i];
+ }
+
+ // Recurse on the second half
+ for i in 0..half_n {
+ y_sub_first[i] = ys[i] - ys[i + half_n];
+ y_sub_first[i] *= roots[i];
+ }
+ fft_recurse(
+ tmp_first,
+ half_n,
+ roots_sub_first,
+ y_sub_first,
+ tmp_second,
+ y_sub_second,
+ roots_sub_second,
+ );
+ for i in 0..half_n {
+ out[2 * i + 1] = tmp[i];
+ }
+}
+
+/// Calculate `count` number of roots of unity of order `count`
+fn fft_get_roots<F: FftFriendlyFieldElement>(count: usize, invert: bool) -> Vec<F> {
+ let mut roots = vec![F::zero(); count];
+ let mut gen = F::generator();
+ if invert {
+ gen = gen.inv();
+ }
+
+ roots[0] = F::one();
+ let step_size = F::generator_order() / F::Integer::try_from(count).unwrap();
+ // generator for subgroup of order count
+ gen = gen.pow(step_size);
+
+ roots[1] = gen;
+
+ for i in 2..count {
+ roots[i] = gen * roots[i - 1];
+ }
+
+ roots
+}
+
+fn fft_interpolate_raw<F: FftFriendlyFieldElement>(
+ out: &mut [F],
+ ys: &[F],
+ n_points: usize,
+ roots: &[F],
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_recurse(
+ out,
+ n_points,
+ roots,
+ ys,
+ &mut mem.fft_tmp,
+ &mut mem.fft_y_sub,
+ &mut mem.fft_roots_sub,
+ );
+ if invert {
+ let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv();
+ for out_val in out[0..n_points].iter_mut() {
+ *out_val *= n_inverse;
+ }
+ }
+}
+
+pub fn poly_fft<F: FftFriendlyFieldElement>(
+ points_out: &mut [F],
+ points_in: &[F],
+ scaled_roots: &[F],
+ n_points: usize,
+ invert: bool,
+ mem: &mut PolyFFTTempMemory<F>,
+) {
+ fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem)
+}
+
+// Evaluate a polynomial using Horner's method.
+pub fn poly_eval<F: FftFriendlyFieldElement>(poly: &[F], eval_at: F) -> F {
+ if poly.is_empty() {
+ return F::zero();
+ }
+
+ let mut result = poly[poly.len() - 1];
+ for i in (0..poly.len() - 1).rev() {
+ result *= eval_at;
+ result += poly[i];
+ }
+
+ result
+}
+
+// Returns the degree of polynomial `p`.
+pub fn poly_deg<F: FftFriendlyFieldElement>(p: &[F]) -> usize {
+ let mut d = p.len();
+ while d > 0 && p[d - 1] == F::zero() {
+ d -= 1;
+ }
+ d.saturating_sub(1)
+}
+
+// Multiplies polynomials `p` and `q` and returns the result.
+pub fn poly_mul<F: FftFriendlyFieldElement>(p: &[F], q: &[F]) -> Vec<F> {
+ let p_size = poly_deg(p) + 1;
+ let q_size = poly_deg(q) + 1;
+ let mut out = vec![F::zero(); p_size + q_size];
+ for i in 0..p_size {
+ for j in 0..q_size {
+ out[i + j] += p[i] * q[j];
+ }
+ }
+ out.truncate(poly_deg(&out) + 1);
+ out
+}
+
+#[cfg(feature = "prio2")]
+#[inline]
+pub fn poly_interpret_eval<F: FftFriendlyFieldElement>(
+ points: &[F],
+ eval_at: F,
+ tmp_coeffs: &mut [F],
+) -> F {
+ let size_inv = F::from(F::Integer::try_from(points.len()).unwrap()).inv();
+ discrete_fourier_transform(tmp_coeffs, points, points.len()).unwrap();
+ discrete_fourier_transform_inv_finish(tmp_coeffs, points.len(), size_inv);
+ poly_eval(&tmp_coeffs[..points.len()], eval_at)
+}
+
+// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise,
+// the output is not `0`.
+pub(crate) fn poly_range_check<F: FftFriendlyFieldElement>(start: usize, end: usize) -> Vec<F> {
+ let mut p = vec![F::one()];
+ let mut q = [F::zero(), F::one()];
+ for i in start..end {
+ q[0] = -F::from(F::Integer::try_from(i).unwrap());
+ p = poly_mul(&p, &q);
+ }
+ p
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ field::{
+ FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2,
+ },
+ polynomial::{
+ fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, PolyAuxMemory,
+ },
+ };
+ use rand::prelude::*;
+ use std::convert::TryFrom;
+
+ #[test]
+ fn test_roots() {
+ let count = 128;
+ let roots = fft_get_roots::<FieldPrio2>(count, false);
+ let roots_inv = fft_get_roots::<FieldPrio2>(count, true);
+
+ for i in 0..count {
+ assert_eq!(roots[i] * roots_inv[i], 1);
+ assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1);
+ assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1);
+ }
+ }
+
+ #[test]
+ fn test_eval() {
+ let mut poly = [FieldPrio2::from(0); 4];
+ poly[0] = 2.into();
+ poly[1] = 1.into();
+ poly[2] = 5.into();
+ // 5*3^2 + 3 + 2 = 50
+ assert_eq!(poly_eval(&poly[..3], 3.into()), 50);
+ poly[3] = 4.into();
+ // 4*3^3 + 5*3^2 + 3 + 2 = 158
+ assert_eq!(poly_eval(&poly[..4], 3.into()), 158);
+ }
+
+ #[test]
+ fn test_poly_deg() {
+ let zero = FieldPrio2::zero();
+ let one = FieldPrio2::root(0).unwrap();
+ assert_eq!(poly_deg(&[zero]), 0);
+ assert_eq!(poly_deg(&[one]), 0);
+ assert_eq!(poly_deg(&[zero, one]), 1);
+ assert_eq!(poly_deg(&[zero, zero, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one]), 2);
+ assert_eq!(poly_deg(&[zero, one, one, one]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3);
+ assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3);
+ }
+
+ #[test]
+ fn test_poly_mul() {
+ let p = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ ];
+
+ let q = [
+ Field64::one(),
+ Field64::zero(),
+ Field64::from(u64::try_from(5).unwrap()),
+ ];
+
+ let want = [
+ Field64::from(u64::try_from(2).unwrap()),
+ Field64::from(u64::try_from(3).unwrap()),
+ Field64::from(u64::try_from(10).unwrap()),
+ Field64::from(u64::try_from(15).unwrap()),
+ ];
+
+ let got = poly_mul(&p, &q);
+ assert_eq!(&got, &want);
+ }
+
+ #[test]
+ fn test_poly_range_check() {
+ let start = 74;
+ let end = 112;
+ let p = poly_range_check(start, end);
+
+ // Check each number in the range.
+ for i in start..end {
+ let x = Field64::from(i as u64);
+ let y = poly_eval(&p, x);
+ assert_eq!(y, Field64::zero(), "range check failed for {i}");
+ }
+
+ // Check the number below the range.
+ let x = Field64::from((start - 1) as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+
+ // Check a number above the range.
+ let x = Field64::from(end as u64);
+ let y = poly_eval(&p, x);
+ assert_ne!(y, Field64::zero());
+ }
+
+ #[test]
+ fn test_fft() {
+ let count = 128;
+ let mut mem = PolyAuxMemory::new(count / 2);
+
+ let mut poly = vec![FieldPrio2::from(0); count];
+ let mut points2 = vec![FieldPrio2::from(0); count];
+
+ let points = (0..count)
+ .map(|_| FieldPrio2::from(random::<u32>()))
+ .collect::<Vec<FieldPrio2>>();
+
+ // From points to coeffs and back
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+ poly_fft(
+ &mut points2,
+ &poly,
+ &mem.roots_2n_inverted,
+ count,
+ true,
+ &mut mem.fft_memory,
+ );
+
+ assert_eq!(points, points2);
+
+ // interpolation
+ poly_fft(
+ &mut poly,
+ &points,
+ &mem.roots_2n,
+ count,
+ false,
+ &mut mem.fft_memory,
+ );
+
+ for (poly_coeff, root) in poly[..count].iter().zip(mem.roots_2n[..count].iter()) {
+ let mut should_be = FieldPrio2::from(0);
+ for (j, point_j) in points[..count].iter().enumerate() {
+ should_be = root.pow(u32::try_from(j).unwrap()) * *point_j + should_be;
+ }
+ assert_eq!(should_be, *poly_coeff);
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/prng.rs b/third_party/rust/prio/src/prng.rs
new file mode 100644
index 0000000000..cb7d3a54c8
--- /dev/null
+++ b/third_party/rust/prio/src/prng.rs
@@ -0,0 +1,278 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Tool for generating pseudorandom field elements.
+//!
+//! NOTE: The public API for this module is a work in progress.
+
+use crate::field::{FieldElement, FieldElementExt};
+#[cfg(feature = "crypto-dependencies")]
+use crate::vdaf::xof::SeedStreamAes128;
+#[cfg(feature = "crypto-dependencies")]
+use getrandom::getrandom;
+use rand_core::RngCore;
+
+use std::marker::PhantomData;
+use std::ops::ControlFlow;
+
+const BUFFER_SIZE_IN_ELEMENTS: usize = 32;
+
+/// Errors propagated by methods in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum PrngError {
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+}
+
+/// This type implements an iterator that generates a pseudorandom sequence of field elements. The
+/// sequence is derived from a XOF's key stream.
+#[derive(Debug)]
+pub(crate) struct Prng<F, S> {
+ phantom: PhantomData<F>,
+ seed_stream: S,
+ buffer: Vec<u8>,
+ buffer_index: usize,
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl<F: FieldElement> Prng<F, SeedStreamAes128> {
+ /// Create a [`Prng`] from a seed for Prio 2. The first 16 bytes of the seed and the last 16
+ /// bytes of the seed are used, respectively, for the key and initialization vector for AES128
+ /// in CTR mode.
+ pub(crate) fn from_prio2_seed(seed: &[u8; 32]) -> Self {
+ let seed_stream = SeedStreamAes128::new(&seed[..16], &seed[16..]);
+ Self::from_seed_stream(seed_stream)
+ }
+
+ /// Create a [`Prng`] from a randomly generated seed.
+ pub(crate) fn new() -> Result<Self, PrngError> {
+ let mut seed = [0; 32];
+ getrandom(&mut seed)?;
+ Ok(Self::from_prio2_seed(&seed))
+ }
+}
+
+impl<F, S> Prng<F, S>
+where
+ F: FieldElement,
+ S: RngCore,
+{
+ pub(crate) fn from_seed_stream(mut seed_stream: S) -> Self {
+ let mut buffer = vec![0; BUFFER_SIZE_IN_ELEMENTS * F::ENCODED_SIZE];
+ seed_stream.fill_bytes(&mut buffer);
+
+ Self {
+ phantom: PhantomData::<F>,
+ seed_stream,
+ buffer,
+ buffer_index: 0,
+ }
+ }
+
+ pub(crate) fn get(&mut self) -> F {
+ loop {
+ // Seek to the next chunk of the buffer that encodes an element of F.
+ for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) {
+ let j = i + F::ENCODED_SIZE;
+
+ if j > self.buffer.len() {
+ break;
+ }
+
+ self.buffer_index = j;
+
+ match F::from_random_rejection(&self.buffer[i..j]) {
+ ControlFlow::Break(x) => return x,
+ ControlFlow::Continue(()) => continue, // reject this sample
+ }
+ }
+
+ // Refresh buffer with the next chunk of XOF output, filling the front of the buffer
+ // with the leftovers. This ensures continuity of the seed stream after converting the
+ // `Prng` to a new field type via `into_new_field()`.
+ let left_over = self.buffer.len() - self.buffer_index;
+ self.buffer.copy_within(self.buffer_index.., 0);
+ self.seed_stream.fill_bytes(&mut self.buffer[left_over..]);
+ self.buffer_index = 0;
+ }
+ }
+
+ /// Convert this object into a field element generator for a different field.
+ #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+ pub(crate) fn into_new_field<F1: FieldElement>(self) -> Prng<F1, S> {
+ Prng {
+ phantom: PhantomData,
+ seed_stream: self.seed_stream,
+ buffer: self.buffer,
+ buffer_index: self.buffer_index,
+ }
+ }
+}
+
+impl<F, S> Iterator for Prng<F, S>
+where
+ F: FieldElement,
+ S: RngCore,
+{
+ type Item = F;
+
+ fn next(&mut self) -> Option<F> {
+ Some(self.get())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ codec::Decode,
+ field::{Field64, FieldPrio2},
+ vdaf::xof::{Seed, SeedStreamSha3, Xof, XofShake128},
+ };
+ #[cfg(feature = "prio2")]
+ use base64::{engine::Engine, prelude::BASE64_STANDARD};
+ #[cfg(feature = "prio2")]
+ use sha2::{Digest, Sha256};
+ use std::convert::TryInto;
+
+ #[test]
+ fn secret_sharing_interop() {
+ let seed = [
+ 0xcd, 0x85, 0x5b, 0xd4, 0x86, 0x48, 0xa4, 0xce, 0x52, 0x5c, 0x36, 0xee, 0x5a, 0x71,
+ 0xf3, 0x0f, 0x66, 0x80, 0xd3, 0x67, 0x53, 0x9a, 0x39, 0x6f, 0x12, 0x2f, 0xad, 0x94,
+ 0x4d, 0x34, 0xcb, 0x58,
+ ];
+
+ let reference = [
+ 0xd0056ec5, 0xe23f9c52, 0x47e4ddb4, 0xbe5dacf6, 0x4b130aba, 0x530c7a90, 0xe8fc4ee5,
+ 0xb0569cb7, 0x7774cd3c, 0x7f24e6a5, 0xcc82355d, 0xc41f4f13, 0x67fe193c, 0xc94d63a4,
+ 0x5d7b474c, 0xcc5c9f5f, 0xe368e1d5, 0x020fa0cf, 0x9e96aa2a, 0xe924137d, 0xfa026ab9,
+ 0x8ebca0cc, 0x26fc58a5, 0x10a7b173, 0xb9c97291, 0x53ef0e28, 0x069cfb8e, 0xe9383cae,
+ 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58,
+ ];
+
+ let share2 = extract_share_from_seed::<FieldPrio2>(reference.len(), &seed);
+
+ assert_eq!(share2, reference);
+ }
+
+ /// takes a seed and hash as base64 encoded strings
+ #[cfg(feature = "prio2")]
+ fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) {
+ let seed = BASE64_STANDARD.decode(seed_base64).unwrap();
+ let random_data = extract_share_from_seed::<FieldPrio2>(len, &seed);
+
+ let random_bytes = FieldPrio2::slice_into_byte_vec(&random_data);
+
+ let mut hasher = Sha256::new();
+ hasher.update(&random_bytes);
+ let digest = hasher.finalize();
+ assert_eq!(BASE64_STANDARD.encode(digest), hash_base64);
+ }
+
+ #[test]
+ #[cfg(feature = "prio2")]
+ fn test_hash_interop() {
+ random_data_interop(
+ "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
+ "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
+ 100_000,
+ );
+
+ // zero seed
+ random_data_interop(
+ "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
+ "3wHQbSwAn9GPfoNkKe1qSzWdKnu/R+hPPyRwwz6Di+w=",
+ 100_000,
+ );
+ // 0, 1, 2 ... seed
+ random_data_interop(
+ "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
+ "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
+ 100_000,
+ );
+ // one arbirtary fixed seed
+ random_data_interop(
+ "rkLrnVcU8ULaiuXTvR3OKrfpMX0kQidqVzta1pleKKg=",
+ "b1fMXYrGUNR3wOZ/7vmUMmY51QHoPDBzwok0fz6xC0I=",
+ 100_000,
+ );
+ // all bits set seed
+ random_data_interop(
+ "//////////////////////////////////////////8=",
+ "iBiDaqLrv7/rX/+vs6akPiprGgYfULdh/XhoD61HQXA=",
+ 100_000,
+ );
+ }
+
+ fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> {
+ assert_eq!(seed.len(), 32);
+ Prng::from_prio2_seed(seed.try_into().unwrap())
+ .take(length)
+ .collect()
+ }
+
+ #[test]
+ fn rejection_sampling_test_vector() {
+ // These constants were found in a brute-force search, and they test that the XOF performs
+ // rejection sampling correctly when the raw output exceeds the prime modulus.
+ let seed = Seed::get_decoded(&[
+ 0x29, 0xb2, 0x98, 0x64, 0xb4, 0xaa, 0x4e, 0x07, 0x2a, 0x44, 0x49, 0x24, 0xf6, 0x74,
+ 0x0a, 0x3d,
+ ])
+ .unwrap();
+ let expected = Field64::from(2035552711764301796);
+
+ let seed_stream = XofShake128::seed_stream(&seed, b"", b"");
+ let mut prng = Prng::<Field64, _>::from_seed_stream(seed_stream);
+ let actual = prng.nth(33236).unwrap();
+ assert_eq!(actual, expected);
+
+ #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+ {
+ let mut seed_stream = XofShake128::seed_stream(&seed, b"", b"");
+ let mut actual = <Field64 as FieldElement>::zero();
+ for _ in 0..=33236 {
+ actual = <Field64 as crate::idpf::IdpfValue>::generate(&mut seed_stream, &());
+ }
+ assert_eq!(actual, expected);
+ }
+ }
+
+ // Test that the `Prng`'s internal buffer properly copies the end of the buffer to the front
+ // once it reaches the end.
+ #[test]
+ fn left_over_buffer_back_fill() {
+ let seed = Seed::generate().unwrap();
+
+ let mut prng: Prng<Field64, SeedStreamSha3> =
+ Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b""));
+
+ // Construct a `Prng` with a longer-than-usual buffer.
+ let mut prng_weird_buffer_size: Prng<Field64, SeedStreamSha3> =
+ Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b""));
+ let mut extra = [0; 7];
+ prng_weird_buffer_size.seed_stream.fill_bytes(&mut extra);
+ prng_weird_buffer_size.buffer.extend_from_slice(&extra);
+
+ // Check that the next several outputs match. We need to check enough outputs to ensure
+ // that we have to refill the buffer.
+ for _ in 0..BUFFER_SIZE_IN_ELEMENTS * 2 {
+ assert_eq!(prng.next().unwrap(), prng_weird_buffer_size.next().unwrap());
+ }
+ }
+
+ #[cfg(feature = "experimental")]
+ #[test]
+ fn into_new_field() {
+ let seed = Seed::generate().unwrap();
+ let want: Prng<Field64, SeedStreamSha3> =
+ Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b""));
+ let want_buffer = want.buffer.clone();
+
+ let got: Prng<FieldPrio2, _> = want.into_new_field();
+ assert_eq!(got.buffer_index, 0);
+ assert_eq!(got.buffer, want_buffer);
+ }
+}
diff --git a/third_party/rust/prio/src/topology/mod.rs b/third_party/rust/prio/src/topology/mod.rs
new file mode 100644
index 0000000000..fdce6d722a
--- /dev/null
+++ b/third_party/rust/prio/src/topology/mod.rs
@@ -0,0 +1,7 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementations of some aggregator communication topologies specified in [VDAF].
+//!
+//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-06#section-5.7
+
+pub mod ping_pong;
diff --git a/third_party/rust/prio/src/topology/ping_pong.rs b/third_party/rust/prio/src/topology/ping_pong.rs
new file mode 100644
index 0000000000..c55d4f638d
--- /dev/null
+++ b/third_party/rust/prio/src/topology/ping_pong.rs
@@ -0,0 +1,968 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implements the Ping-Pong Topology described in [VDAF]. This topology assumes there are exactly
+//! two aggregators, designated "Leader" and "Helper". This topology is required for implementing
+//! the [Distributed Aggregation Protocol][DAP].
+//!
+//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+//! [DAP]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap
+
+use crate::{
+ codec::{decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode},
+ vdaf::{Aggregator, PrepareTransition, VdafError},
+};
+use std::fmt::Debug;
+
+/// Errors emitted by this module.
+#[derive(Debug, thiserror::Error)]
+pub enum PingPongError {
+ /// Error running prepare_init
+ #[error("vdaf.prepare_init: {0}")]
+ VdafPrepareInit(VdafError),
+
+ /// Error running prepare_shares_to_prepare_message
+ #[error("vdaf.prepare_shares_to_prepare_message {0}")]
+ VdafPrepareSharesToPrepareMessage(VdafError),
+
+ /// Error running prepare_next
+ #[error("vdaf.prepare_next {0}")]
+ VdafPrepareNext(VdafError),
+
+ /// Error decoding a prepare share
+ #[error("decode prep share {0}")]
+ CodecPrepShare(CodecError),
+
+ /// Error decoding a prepare message
+ #[error("decode prep message {0}")]
+ CodecPrepMessage(CodecError),
+
+ /// Host is in an unexpected state
+ #[error("host state mismatch: in {found} expected {expected}")]
+ HostStateMismatch {
+ /// The state the host is in.
+ found: &'static str,
+ /// The state the host expected to be in.
+ expected: &'static str,
+ },
+
+ /// Message from peer indicates it is in an unexpected state
+ #[error("peer message mismatch: message is {found} expected {expected}")]
+ PeerMessageMismatch {
+ /// The state in the message from the peer.
+ found: &'static str,
+ /// The message expected from the peer.
+ expected: &'static str,
+ },
+
+ /// Internal error
+ #[error("internal error: {0}")]
+ InternalError(&'static str),
+}
+
+/// Corresponds to `struct Message` in [VDAF's Ping-Pong Topology][VDAF]. All of the fields of the
+/// variants are opaque byte buffers. This is because the ping-pong routines take responsibility for
+/// decoding preparation shares and messages, which usually requires having the preparation state.
+///
+/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+#[derive(Clone, PartialEq, Eq)]
+pub enum PingPongMessage {
+ /// Corresponds to MessageType.initialize.
+ Initialize {
+ /// The leader's initial preparation share.
+ prep_share: Vec<u8>,
+ },
+ /// Corresponds to MessageType.continue.
+ Continue {
+ /// The current round's preparation message.
+ prep_msg: Vec<u8>,
+ /// The next round's preparation share.
+ prep_share: Vec<u8>,
+ },
+ /// Corresponds to MessageType.finish.
+ Finish {
+ /// The current round's preparation message.
+ prep_msg: Vec<u8>,
+ },
+}
+
+impl PingPongMessage {
+ fn variant(&self) -> &'static str {
+ match self {
+ Self::Initialize { .. } => "Initialize",
+ Self::Continue { .. } => "Continue",
+ Self::Finish { .. } => "Finish",
+ }
+ }
+}
+
+impl Debug for PingPongMessage {
+ // We want `PingPongMessage` to implement `Debug`, but we don't want that impl to print out
+ // prepare shares or messages, because (1) their contents are sensitive and (2) their contents
+ // are long and not intelligible to humans. For both reasons they generally shouldn't get
+ // logged. Normally, we'd use the `derivative` crate to customize a derived `Debug`, but that
+ // crate has not been audited (in the `cargo vet` sense) so we can't use it here unless we audit
+ // 8,000+ lines of proc macros.
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_tuple(self.variant()).finish()
+ }
+}
+
+impl Encode for PingPongMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // The encoding includes an implicit discriminator byte, called MessageType in the VDAF
+ // spec.
+ match self {
+ Self::Initialize { prep_share } => {
+ 0u8.encode(bytes);
+ encode_u32_items(bytes, &(), prep_share);
+ }
+ Self::Continue {
+ prep_msg,
+ prep_share,
+ } => {
+ 1u8.encode(bytes);
+ encode_u32_items(bytes, &(), prep_msg);
+ encode_u32_items(bytes, &(), prep_share);
+ }
+ Self::Finish { prep_msg } => {
+ 2u8.encode(bytes);
+ encode_u32_items(bytes, &(), prep_msg);
+ }
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ match self {
+ Self::Initialize { prep_share } => Some(1 + 4 + prep_share.len()),
+ Self::Continue {
+ prep_msg,
+ prep_share,
+ } => Some(1 + 4 + prep_msg.len() + 4 + prep_share.len()),
+ Self::Finish { prep_msg } => Some(1 + 4 + prep_msg.len()),
+ }
+ }
+}
+
+impl Decode for PingPongMessage {
+ fn decode(bytes: &mut std::io::Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let message_type = u8::decode(bytes)?;
+ Ok(match message_type {
+ 0 => {
+ let prep_share = decode_u32_items(&(), bytes)?;
+ Self::Initialize { prep_share }
+ }
+ 1 => {
+ let prep_msg = decode_u32_items(&(), bytes)?;
+ let prep_share = decode_u32_items(&(), bytes)?;
+ Self::Continue {
+ prep_msg,
+ prep_share,
+ }
+ }
+ 2 => {
+ let prep_msg = decode_u32_items(&(), bytes)?;
+ Self::Finish { prep_msg }
+ }
+ _ => return Err(CodecError::UnexpectedValue),
+ })
+ }
+}
+
+/// A transition in the pong-pong topology. This represents the `ping_pong_transition` function
+/// defined in [VDAF].
+///
+/// # Discussion
+///
+/// The obvious implementation of `ping_pong_transition` would be a method on trait
+/// [`PingPongTopology`] that returns `(State, Message)`, and then `ContinuedValue::WithMessage`
+/// would contain those values. But then DAP implementations would have to store relatively large
+/// VDAF prepare shares between rounds of input preparation.
+///
+/// Instead, this structure stores just the previous round's prepare state and the current round's
+/// preprocessed prepare message. Their encoding is much smaller than the `(State, Message)` tuple,
+/// which can always be recomputed with [`Self::evaluate`].
+///
+/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+#[derive(Clone, Debug, Eq)]
+pub struct PingPongTransition<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+> {
+ previous_prepare_state: A::PrepareState,
+ current_prepare_message: A::PrepareMessage,
+}
+
+impl<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+ > PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
+{
+ /// Evaluate this transition to obtain a new [`PingPongState`] and a [`PingPongMessage`] which
+ /// should be transmitted to the peer.
+ #[allow(clippy::type_complexity)]
+ pub fn evaluate(
+ &self,
+ vdaf: &A,
+ ) -> Result<
+ (
+ PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
+ PingPongMessage,
+ ),
+ PingPongError,
+ > {
+ let prep_msg = self.current_prepare_message.get_encoded();
+
+ vdaf.prepare_next(
+ self.previous_prepare_state.clone(),
+ self.current_prepare_message.clone(),
+ )
+ .map(|transition| match transition {
+ PrepareTransition::Continue(prep_state, prep_share) => (
+ PingPongState::Continued(prep_state),
+ PingPongMessage::Continue {
+ prep_msg,
+ prep_share: prep_share.get_encoded(),
+ },
+ ),
+ PrepareTransition::Finish(output_share) => (
+ PingPongState::Finished(output_share),
+ PingPongMessage::Finish { prep_msg },
+ ),
+ })
+ .map_err(PingPongError::VdafPrepareNext)
+ }
+}
+
+impl<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+ > PartialEq for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.previous_prepare_state == other.previous_prepare_state
+ && self.current_prepare_message == other.current_prepare_message
+ }
+}
+
+impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> Encode
+ for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
+where
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+ A::PrepareState: Encode,
+{
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.previous_prepare_state.encode(bytes);
+ self.current_prepare_message.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(
+ self.previous_prepare_state.encoded_len()?
+ + self.current_prepare_message.encoded_len()?,
+ )
+ }
+}
+
+impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A, PrepareStateDecode>
+ ParameterizedDecode<PrepareStateDecode> for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
+where
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+ A::PrepareState: ParameterizedDecode<PrepareStateDecode> + PartialEq,
+ A::PrepareMessage: PartialEq,
+{
+ fn decode_with_param(
+ decoding_param: &PrepareStateDecode,
+ bytes: &mut std::io::Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let previous_prepare_state = A::PrepareState::decode_with_param(decoding_param, bytes)?;
+ let current_prepare_message =
+ A::PrepareMessage::decode_with_param(&previous_prepare_state, bytes)?;
+
+ Ok(Self {
+ previous_prepare_state,
+ current_prepare_message,
+ })
+ }
+}
+
+/// Corresponds to the `State` enumeration implicitly defined in [VDAF's Ping-Pong Topology][VDAF].
+/// VDAF describes `Start` and `Rejected` states, but the `Start` state is never instantiated in
+/// code, and the `Rejected` state is represented as `std::result::Result::Err`, so this enum does
+/// not include those variants.
+///
+/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum PingPongState<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+> {
+ /// Preparation of the report will continue with the enclosed state.
+ Continued(A::PrepareState),
+ /// Preparation of the report is finished and has yielded the enclosed output share.
+ Finished(A::OutputShare),
+}
+
+/// Values returned by [`PingPongTopology::leader_continued`] or
+/// [`PingPongTopology::helper_continued`].
+#[derive(Clone, Debug)]
+pub enum PingPongContinuedValue<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+> {
+ /// The operation resulted in a new state and a message to transmit to the peer.
+ WithMessage {
+ /// The transition that will be executed. Call `PingPongTransition::evaluate` to obtain the
+ /// next
+ /// [`PingPongState`] and a [`PingPongMessage`] to transmit to the peer.
+ transition: PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
+ },
+ /// The operation caused the host to finish preparation of the input share, yielding an output
+ /// share and no message for the peer.
+ FinishedNoMessage {
+ /// The output share which may now be accumulated.
+ output_share: A::OutputShare,
+ },
+}
+
+/// Extension trait on [`crate::vdaf::Aggregator`] which adds the [VDAF Ping-Pong Topology][VDAF].
+///
+/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+pub trait PingPongTopology<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
+ Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>
+{
+ /// Specialization of [`PingPongState`] for this VDAF.
+ type State;
+ /// Specialization of [`PingPongContinuedValue`] for this VDAF.
+ type ContinuedValue;
+ /// Specializaton of [`PingPongTransition`] for this VDAF.
+ type Transition;
+
+ /// Initialize leader state using the leader's input share. Corresponds to
+ /// `ping_pong_leader_init` in [VDAF].
+ ///
+ /// If successful, the returned [`PingPongMessage`] (which will always be
+ /// `PingPongMessage::Initialize`) should be transmitted to the helper. The returned
+ /// [`PingPongState`] (which will always be `PingPongState::Continued`) should be used by the
+ /// leader along with the next [`PingPongMessage`] received from the helper as input to
+ /// [`Self::leader_continued`] to advance to the next round.
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+ fn leader_initialized(
+ &self,
+ verify_key: &[u8; VERIFY_KEY_SIZE],
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8; NONCE_SIZE],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<(Self::State, PingPongMessage), PingPongError>;
+
+ /// Initialize helper state using the helper's input share and the leader's first prepare share.
+ /// Corresponds to `ping_pong_helper_init` in the forthcoming `draft-irtf-cfrg-vdaf-07`.
+ ///
+ /// If successful, the returned [`PingPongTransition`] should be evaluated, yielding a
+ /// [`PingPongMessage`], which should be transmitted to the leader, and a [`PingPongState`].
+ ///
+ /// If the state is `PingPongState::Continued`, then it should be used by the helper along with
+ /// the next `PingPongMessage` received from the leader as input to [`Self::helper_continued`]
+ /// to advance to the next round. The helper may store the `PingPongTransition` between rounds
+ /// of preparation instead of the `PingPongState` and `PingPongMessage`.
+ ///
+ /// If the state is `PingPongState::Finished`, then preparation is finished and the output share
+ /// may be accumulated.
+ ///
+ /// # Errors
+ ///
+ /// `inbound` must be `PingPongMessage::Initialize` or the function will fail.
+ fn helper_initialized(
+ &self,
+ verify_key: &[u8; VERIFY_KEY_SIZE],
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8; NONCE_SIZE],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ inbound: &PingPongMessage,
+ ) -> Result<PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>, PingPongError>;
+
+ /// Continue preparation based on the leader's current state and an incoming [`PingPongMessage`]
+ /// from the helper. Corresponds to `ping_pong_leader_continued` in [VDAF].
+ ///
+ /// If successful, the returned [`PingPongContinuedValue`] will either be:
+ ///
+ /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated,
+ /// yielding a [`PingPongMessage`], which should be transmitted to the helper, and a
+ /// [`PingPongState`].
+ ///
+ /// If the state is `PingPongState::Continued`, then it should be used by the leader along
+ /// with the next `PingPongMessage` received from the helper as input to
+ /// [`Self::leader_continued`] to advance to the next round. The leader may store the
+ /// `PingPongTransition` between rounds of preparation instead of of the `PingPongState` and
+ /// `PingPongMessage`.
+ ///
+ /// If the state is `PingPongState::Finished`, then preparation is finished and the output
+ /// share may be accumulated.
+ ///
+ /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share
+ /// may be accumulated. No message needs to be sent to the helper.
+ ///
+ /// # Errors
+ ///
+ /// `leader_state` must be `PingPongState::Continued` or the function will fail.
+ ///
+ /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail.
+ ///
+ /// # Notes
+ ///
+ /// The specification of this function in [VDAF] takes the aggregation parameter. This version
+ /// does not, because [`crate::vdaf::Aggregator::prepare_preprocess`] does not take the
+ /// aggregation parameter. This may change in the future if/when [#670][issue] is addressed.
+ ///
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+ /// [issue]: https://github.com/divviup/libprio-rs/issues/670
+ fn leader_continued(
+ &self,
+ leader_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError>;
+
+ /// PingPongContinue preparation based on the helper's current state and an incoming
+ /// [`PingPongMessage`] from the leader. Corresponds to `ping_pong_helper_contnued` in [VDAF].
+ ///
+ /// If successful, the returned [`PingPongContinuedValue`] will either be:
+ ///
+ /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated,
+ /// yielding a [`PingPongMessage`], which should be transmitted to the leader, and a
+ /// [`PingPongState`].
+ ///
+ /// If the state is `PingPongState::Continued`, then it should be used by the helper along
+ /// with the next `PingPongMessage` received from the leader as input to
+ /// [`Self::helper_continued`] to advance to the next round. The helper may store the
+ /// `PingPongTransition` between rounds of preparation instead of the `PingPongState` and
+ /// `PingPongMessage`.
+ ///
+ /// If the state is `PingPongState::Finished`, then preparation is finished and the output
+ /// share may be accumulated.
+ ///
+ /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share
+ /// may be accumulated. No message needs to be sent to the leader.
+ ///
+ /// # Errors
+ ///
+ /// `helper_state` must be `PingPongState::Continued` or the function will fail.
+ ///
+ /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail.
+ ///
+ /// # Notes
+ ///
+ /// The specification of this function in [VDAF] takes the aggregation parameter. This version
+ /// does not, because [`crate::vdaf::Aggregator::prepare_preprocess`] does not take the
+ /// aggregation parameter. This may change in the future if/when [#670][issue] is addressed.
+ ///
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8
+ /// [issue]: https://github.com/divviup/libprio-rs/issues/670
+ fn helper_continued(
+ &self,
+ helper_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError>;
+}
+
+/// Private interfaces for implementing ping-pong
+trait PingPongTopologyPrivate<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
+ PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE>
+{
+ fn continued(
+ &self,
+ is_leader: bool,
+ host_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError>;
+}
+
+impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
+ PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE> for A
+where
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+{
+ type State = PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
+ type ContinuedValue = PingPongContinuedValue<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
+ type Transition = PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
+
+ fn leader_initialized(
+ &self,
+ verify_key: &[u8; VERIFY_KEY_SIZE],
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8; NONCE_SIZE],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<(Self::State, PingPongMessage), PingPongError> {
+ self.prepare_init(
+ verify_key,
+ /* Leader */ 0,
+ agg_param,
+ nonce,
+ public_share,
+ input_share,
+ )
+ .map(|(prep_state, prep_share)| {
+ (
+ PingPongState::Continued(prep_state),
+ PingPongMessage::Initialize {
+ prep_share: prep_share.get_encoded(),
+ },
+ )
+ })
+ .map_err(PingPongError::VdafPrepareInit)
+ }
+
+ fn helper_initialized(
+ &self,
+ verify_key: &[u8; VERIFY_KEY_SIZE],
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8; NONCE_SIZE],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::Transition, PingPongError> {
+ let (prep_state, prep_share) = self
+ .prepare_init(
+ verify_key,
+ /* Helper */ 1,
+ agg_param,
+ nonce,
+ public_share,
+ input_share,
+ )
+ .map_err(PingPongError::VdafPrepareInit)?;
+
+ let inbound_prep_share = if let PingPongMessage::Initialize { prep_share } = inbound {
+ Self::PrepareShare::get_decoded_with_param(&prep_state, prep_share)
+ .map_err(PingPongError::CodecPrepShare)?
+ } else {
+ return Err(PingPongError::PeerMessageMismatch {
+ found: inbound.variant(),
+ expected: "initialize",
+ });
+ };
+
+ let current_prepare_message = self
+ .prepare_shares_to_prepare_message(agg_param, [inbound_prep_share, prep_share])
+ .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;
+
+ Ok(PingPongTransition {
+ previous_prepare_state: prep_state,
+ current_prepare_message,
+ })
+ }
+
+ fn leader_continued(
+ &self,
+ leader_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError> {
+ self.continued(true, leader_state, agg_param, inbound)
+ }
+
+ fn helper_continued(
+ &self,
+ helper_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError> {
+ self.continued(false, helper_state, agg_param, inbound)
+ }
+}
+
+impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
+ PingPongTopologyPrivate<VERIFY_KEY_SIZE, NONCE_SIZE> for A
+where
+ A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+{
+ fn continued(
+ &self,
+ is_leader: bool,
+ host_state: Self::State,
+ agg_param: &Self::AggregationParam,
+ inbound: &PingPongMessage,
+ ) -> Result<Self::ContinuedValue, PingPongError> {
+ let host_prep_state = if let PingPongState::Continued(state) = host_state {
+ state
+ } else {
+ return Err(PingPongError::HostStateMismatch {
+ found: "finished",
+ expected: "continue",
+ });
+ };
+
+ let (prep_msg, next_peer_prep_share) = match inbound {
+ PingPongMessage::Initialize { .. } => {
+ return Err(PingPongError::PeerMessageMismatch {
+ found: inbound.variant(),
+ expected: "continue",
+ });
+ }
+ PingPongMessage::Continue {
+ prep_msg,
+ prep_share,
+ } => (prep_msg, Some(prep_share)),
+ PingPongMessage::Finish { prep_msg } => (prep_msg, None),
+ };
+
+ let prep_msg = Self::PrepareMessage::get_decoded_with_param(&host_prep_state, prep_msg)
+ .map_err(PingPongError::CodecPrepMessage)?;
+ let host_prep_transition = self
+ .prepare_next(host_prep_state, prep_msg)
+ .map_err(PingPongError::VdafPrepareNext)?;
+
+ match (host_prep_transition, next_peer_prep_share) {
+ (
+ PrepareTransition::Continue(next_prep_state, next_host_prep_share),
+ Some(next_peer_prep_share),
+ ) => {
+ let next_peer_prep_share = Self::PrepareShare::get_decoded_with_param(
+ &next_prep_state,
+ next_peer_prep_share,
+ )
+ .map_err(PingPongError::CodecPrepShare)?;
+ let mut prep_shares = [next_peer_prep_share, next_host_prep_share];
+ if is_leader {
+ prep_shares.reverse();
+ }
+ let current_prepare_message = self
+ .prepare_shares_to_prepare_message(agg_param, prep_shares)
+ .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;
+
+ Ok(PingPongContinuedValue::WithMessage {
+ transition: PingPongTransition {
+ previous_prepare_state: next_prep_state,
+ current_prepare_message,
+ },
+ })
+ }
+ (PrepareTransition::Finish(output_share), None) => {
+ Ok(PingPongContinuedValue::FinishedNoMessage { output_share })
+ }
+ (PrepareTransition::Continue(_, _), None) => {
+ return Err(PingPongError::PeerMessageMismatch {
+ found: inbound.variant(),
+ expected: "continue",
+ })
+ }
+ (PrepareTransition::Finish(_), Some(_)) => {
+ return Err(PingPongError::PeerMessageMismatch {
+ found: inbound.variant(),
+ expected: "finish",
+ })
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::io::Cursor;
+
+ use super::*;
+ use crate::vdaf::dummy;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn ping_pong_one_round() {
+ let verify_key = [];
+ let aggregation_param = dummy::AggregationParam(0);
+ let nonce = [0; 16];
+ #[allow(clippy::let_unit_value)]
+ let public_share = ();
+ let input_share = dummy::InputShare(0);
+
+ let leader = dummy::Vdaf::new(1);
+ let helper = dummy::Vdaf::new(1);
+
+ // Leader inits into round 0
+ let (leader_state, leader_message) = leader
+ .leader_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ )
+ .unwrap();
+
+ // Helper inits into round 1
+ let (helper_state, helper_message) = helper
+ .helper_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ &leader_message,
+ )
+ .unwrap()
+ .evaluate(&helper)
+ .unwrap();
+
+ // 1 round VDAF: helper should finish immediately.
+ assert_matches!(helper_state, PingPongState::Finished(_));
+
+ let leader_state = leader
+ .leader_continued(leader_state, &aggregation_param, &helper_message)
+ .unwrap();
+ // 1 round VDAF: leader should finish when it gets helper message and emit no message.
+ assert_matches!(
+ leader_state,
+ PingPongContinuedValue::FinishedNoMessage { .. }
+ );
+ }
+
+ #[test]
+ fn ping_pong_two_rounds() {
+ let verify_key = [];
+ let aggregation_param = dummy::AggregationParam(0);
+ let nonce = [0; 16];
+ #[allow(clippy::let_unit_value)]
+ let public_share = ();
+ let input_share = dummy::InputShare(0);
+
+ let leader = dummy::Vdaf::new(2);
+ let helper = dummy::Vdaf::new(2);
+
+ // Leader inits into round 0
+ let (leader_state, leader_message) = leader
+ .leader_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ )
+ .unwrap();
+
+ // Helper inits into round 1
+ let (helper_state, helper_message) = helper
+ .helper_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ &leader_message,
+ )
+ .unwrap()
+ .evaluate(&helper)
+ .unwrap();
+
+ // 2 round VDAF, round 1: helper should continue.
+ assert_matches!(helper_state, PingPongState::Continued(_));
+
+ let leader_state = leader
+ .leader_continued(leader_state, &aggregation_param, &helper_message)
+ .unwrap();
+ // 2 round VDAF, round 1: leader should finish and emit a finish message.
+ let leader_message = assert_matches!(
+ leader_state, PingPongContinuedValue::WithMessage { transition } => {
+ let (state, message) = transition.evaluate(&leader).unwrap();
+ assert_matches!(state, PingPongState::Finished(_));
+ message
+ }
+ );
+
+ let helper_state = helper
+ .helper_continued(helper_state, &aggregation_param, &leader_message)
+ .unwrap();
+ // 2 round vdaf, round 1: helper should finish and emit no message.
+ assert_matches!(
+ helper_state,
+ PingPongContinuedValue::FinishedNoMessage { .. }
+ );
+ }
+
+ #[test]
+ fn ping_pong_three_rounds() {
+ let verify_key = [];
+ let aggregation_param = dummy::AggregationParam(0);
+ let nonce = [0; 16];
+ #[allow(clippy::let_unit_value)]
+ let public_share = ();
+ let input_share = dummy::InputShare(0);
+
+ let leader = dummy::Vdaf::new(3);
+ let helper = dummy::Vdaf::new(3);
+
+ // Leader inits into round 0
+ let (leader_state, leader_message) = leader
+ .leader_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ )
+ .unwrap();
+
+ // Helper inits into round 1
+ let (helper_state, helper_message) = helper
+ .helper_initialized(
+ &verify_key,
+ &aggregation_param,
+ &nonce,
+ &public_share,
+ &input_share,
+ &leader_message,
+ )
+ .unwrap()
+ .evaluate(&helper)
+ .unwrap();
+
+ // 3 round VDAF, round 1: helper should continue.
+ assert_matches!(helper_state, PingPongState::Continued(_));
+
+ let leader_state = leader
+ .leader_continued(leader_state, &aggregation_param, &helper_message)
+ .unwrap();
+ // 3 round VDAF, round 1: leader should continue and emit a continue message.
+ let (leader_state, leader_message) = assert_matches!(
+ leader_state, PingPongContinuedValue::WithMessage { transition } => {
+ let (state, message) = transition.evaluate(&leader).unwrap();
+ assert_matches!(state, PingPongState::Continued(_));
+ (state, message)
+ }
+ );
+
+ let helper_state = helper
+ .helper_continued(helper_state, &aggregation_param, &leader_message)
+ .unwrap();
+ // 3 round vdaf, round 2: helper should finish and emit a finish message.
+ let helper_message = assert_matches!(
+ helper_state, PingPongContinuedValue::WithMessage { transition } => {
+ let (state, message) = transition.evaluate(&helper).unwrap();
+ assert_matches!(state, PingPongState::Finished(_));
+ message
+ }
+ );
+
+ let leader_state = leader
+ .leader_continued(leader_state, &aggregation_param, &helper_message)
+ .unwrap();
+ // 3 round VDAF, round 2: leader should finish and emit no message.
+ assert_matches!(
+ leader_state,
+ PingPongContinuedValue::FinishedNoMessage { .. }
+ );
+ }
+
+ #[test]
+ fn roundtrip_message() {
+ let messages = [
+ (
+ PingPongMessage::Initialize {
+ prep_share: Vec::from("prepare share"),
+ },
+ concat!(
+ "00", // enum discriminant
+ concat!(
+ // prep_share
+ "0000000d", // length
+ "70726570617265207368617265", // contents
+ ),
+ ),
+ ),
+ (
+ PingPongMessage::Continue {
+ prep_msg: Vec::from("prepare message"),
+ prep_share: Vec::from("prepare share"),
+ },
+ concat!(
+ "01", // enum discriminant
+ concat!(
+ // prep_msg
+ "0000000f", // length
+ "70726570617265206d657373616765", // contents
+ ),
+ concat!(
+ // prep_share
+ "0000000d", // length
+ "70726570617265207368617265", // contents
+ ),
+ ),
+ ),
+ (
+ PingPongMessage::Finish {
+ prep_msg: Vec::from("prepare message"),
+ },
+ concat!(
+ "02", // enum discriminant
+ concat!(
+ // prep_msg
+ "0000000f", // length
+ "70726570617265206d657373616765", // contents
+ ),
+ ),
+ ),
+ ];
+
+ for (message, expected_hex) in messages {
+ let mut encoded_val = Vec::new();
+ message.encode(&mut encoded_val);
+ let got_hex = hex::encode(&encoded_val);
+ assert_eq!(
+ &got_hex, expected_hex,
+ "Couldn't roundtrip (encoded value differs): {message:?}",
+ );
+ let decoded_val = PingPongMessage::decode(&mut Cursor::new(&encoded_val)).unwrap();
+ assert_eq!(
+ decoded_val, message,
+ "Couldn't roundtrip (decoded value differs): {message:?}"
+ );
+ assert_eq!(
+ encoded_val.len(),
+ message.encoded_len().expect("No encoded length hint"),
+ "Encoded length hint is incorrect: {message:?}"
+ )
+ }
+ }
+
+ #[test]
+ fn roundtrip_transition() {
+ // VDAF implementations have tests for encoding/decoding their respective PrepareShare and
+ // PrepareMessage types, so we test here using the dummy VDAF.
+ let transition = PingPongTransition::<0, 16, dummy::Vdaf> {
+ previous_prepare_state: dummy::PrepareState::default(),
+ current_prepare_message: (),
+ };
+
+ let encoded = transition.get_encoded();
+ let hex_encoded = hex::encode(&encoded);
+
+ assert_eq!(
+ hex_encoded,
+ concat!(
+ concat!(
+ // previous_prepare_state
+ "00", // input_share
+ "00000000", // current_round
+ ),
+ // current_prepare_message (0 length encoding)
+ )
+ );
+
+ let decoded = PingPongTransition::get_decoded_with_param(&(), &encoded).unwrap();
+ assert_eq!(transition, decoded);
+
+ assert_eq!(
+ encoded.len(),
+ transition.encoded_len().expect("No encoded length hint"),
+ );
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf.rs b/third_party/rust/prio/src/vdaf.rs
new file mode 100644
index 0000000000..1a6c5f0315
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf.rs
@@ -0,0 +1,757 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Verifiable Distributed Aggregation Functions (VDAFs) as described in
+//! [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+#[cfg(feature = "experimental")]
+use crate::dp::DifferentialPrivacyStrategy;
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+use crate::idpf::IdpfError;
+use crate::{
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{encode_fieldvec, merge_vector, FieldElement, FieldError},
+ flp::FlpError,
+ prng::PrngError,
+ vdaf::xof::Seed,
+};
+use serde::{Deserialize, Serialize};
+use std::{fmt::Debug, io::Cursor};
+use subtle::{Choice, ConstantTimeEq};
+
+/// A component of the domain-separation tag, used to bind the VDAF operations to the document
+/// version. This will be revised with each draft with breaking changes.
+pub(crate) const VERSION: u8 = 7;
+
+/// Errors emitted by this module.
+#[derive(Debug, thiserror::Error)]
+pub enum VdafError {
+ /// An error occurred.
+ #[error("vdaf error: {0}")]
+ Uncategorized(String),
+
+ /// Field error.
+ #[error("field error: {0}")]
+ Field(#[from] FieldError),
+
+ /// An error occured while parsing a message.
+ #[error("io error: {0}")]
+ IoError(#[from] std::io::Error),
+
+ /// FLP error.
+ #[error("flp error: {0}")]
+ Flp(#[from] FlpError),
+
+ /// PRNG error.
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+
+ /// IDPF error.
+ #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+ #[error("idpf error: {0}")]
+ Idpf(#[from] IdpfError),
+}
+
+/// An additive share of a vector of field elements.
+#[derive(Clone, Debug)]
+pub enum Share<F, const SEED_SIZE: usize> {
+ /// An uncompressed share, typically sent to the leader.
+ Leader(Vec<F>),
+
+ /// A compressed share, typically sent to the helper.
+ Helper(Seed<SEED_SIZE>),
+}
+
+impl<F: Clone, const SEED_SIZE: usize> Share<F, SEED_SIZE> {
+ /// Truncate the Leader's share to the given length. If this is the Helper's share, then this
+ /// method clones the input without modifying it.
+ #[cfg(feature = "prio2")]
+ pub(crate) fn truncated(&self, len: usize) -> Self {
+ match self {
+ Self::Leader(ref data) => Self::Leader(data[..len].to_vec()),
+ Self::Helper(ref seed) => Self::Helper(seed.clone()),
+ }
+ }
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Share<F, SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Share<F, SEED_SIZE> {}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Share<F, SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> subtle::Choice {
+ // We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types'
+ // contents.
+ match (self, other) {
+ (Share::Leader(self_val), Share::Leader(other_val)) => self_val.ct_eq(other_val),
+ (Share::Helper(self_val), Share::Helper(other_val)) => self_val.ct_eq(other_val),
+ _ => Choice::from(0),
+ }
+ }
+}
+
+/// Parameters needed to decode a [`Share`]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub(crate) enum ShareDecodingParameter<const SEED_SIZE: usize> {
+ Leader(usize),
+ Helper,
+}
+
+impl<F: FieldElement, const SEED_SIZE: usize> ParameterizedDecode<ShareDecodingParameter<SEED_SIZE>>
+ for Share<F, SEED_SIZE>
+{
+ fn decode_with_param(
+ decoding_parameter: &ShareDecodingParameter<SEED_SIZE>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match decoding_parameter {
+ ShareDecodingParameter::Leader(share_length) => {
+ let mut data = Vec::with_capacity(*share_length);
+ for _ in 0..*share_length {
+ data.push(F::decode(bytes)?)
+ }
+ Ok(Self::Leader(data))
+ }
+ ShareDecodingParameter::Helper => {
+ let seed = Seed::decode(bytes)?;
+ Ok(Self::Helper(seed))
+ }
+ }
+ }
+}
+
+impl<F: FieldElement, const SEED_SIZE: usize> Encode for Share<F, SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self {
+ Share::Leader(share_data) => {
+ for x in share_data {
+ x.encode(bytes);
+ }
+ }
+ Share::Helper(share_seed) => {
+ share_seed.encode(bytes);
+ }
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ match self {
+ Share::Leader(share_data) => {
+ // Each element of the data vector has the same size.
+ Some(share_data.len() * F::ENCODED_SIZE)
+ }
+ Share::Helper(share_seed) => share_seed.encoded_len(),
+ }
+ }
+}
+
+/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`],
+/// and [`Collector`], which define the roles of the various parties involved in the execution of
+/// the VDAF.
+pub trait Vdaf: Clone + Debug {
+ /// Algorithm identifier for this VDAF.
+ const ID: u32;
+
+ /// The type of Client measurement to be aggregated.
+ type Measurement: Clone + Debug;
+
+ /// The aggregate result of the VDAF execution.
+ type AggregateResult: Clone + Debug;
+
+ /// The aggregation parameter, used by the Aggregators to map their input shares to output
+ /// shares.
+ type AggregationParam: Clone + Debug + Decode + Encode;
+
+ /// A public share sent by a Client.
+ type PublicShare: Clone + Debug + ParameterizedDecode<Self> + Encode;
+
+ /// An input share sent by a Client.
+ type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode;
+
+ /// An output share recovered from an input share by an Aggregator.
+ type OutputShare: Clone
+ + Debug
+ + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)>
+ + Encode;
+
+ /// An Aggregator's share of the aggregate result.
+ type AggregateShare: Aggregatable<OutputShare = Self::OutputShare>
+ + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)>
+ + Encode;
+
+ /// The number of Aggregators. The Client generates as many input shares as there are
+ /// Aggregators.
+ fn num_aggregators(&self) -> usize;
+
+ /// Generate the domain separation tag for this VDAF. The output is used for domain separation
+ /// by the XOF.
+ fn domain_separation_tag(usage: u16) -> [u8; 8] {
+ let mut dst = [0_u8; 8];
+ dst[0] = VERSION;
+ dst[1] = 0; // algorithm class
+ dst[2..6].copy_from_slice(&(Self::ID).to_be_bytes());
+ dst[6..8].copy_from_slice(&usage.to_be_bytes());
+ dst
+ }
+}
+
+/// The Client's role in the execution of a VDAF.
+pub trait Client<const NONCE_SIZE: usize>: Vdaf {
+ /// Shards a measurement into a public share and a sequence of input shares, one for each
+ /// Aggregator.
+ ///
+ /// Implements `Vdaf::shard` from [VDAF].
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.1
+ fn shard(
+ &self,
+ measurement: &Self::Measurement,
+ nonce: &[u8; NONCE_SIZE],
+ ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>;
+}
+
+/// The Aggregator's role in the execution of a VDAF.
+pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vdaf {
+ /// State of the Aggregator during the Prepare process.
+ type PrepareState: Clone + Debug + PartialEq + Eq;
+
+ /// The type of messages sent by each aggregator at each round of the Prepare Process.
+ ///
+ /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
+ /// associated with any aggregator involved in the execution of the VDAF.
+ type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;
+
+ /// Result of preprocessing a round of preparation shares. This is used by all aggregators as an
+ /// input to the next round of the Prepare Process.
+ ///
+ /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
+ /// associated with any aggregator involved in the execution of the VDAF.
+ type PrepareMessage: Clone
+ + Debug
+ + PartialEq
+ + Eq
+ + ParameterizedDecode<Self::PrepareState>
+ + Encode;
+
+ /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned
+ /// is passed to [`Self::prepare_next`] to get this aggregator's first-round prepare message.
+ ///
+ /// Implements `Vdaf.prep_init` from [VDAF].
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; VERIFY_KEY_SIZE],
+ agg_id: usize,
+ agg_param: &Self::AggregationParam,
+ nonce: &[u8; NONCE_SIZE],
+ public_share: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;
+
+ /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`].
+ ///
+ /// Implements `Vdaf.prep_shares_to_prep` from [VDAF].
+ ///
+ /// # Notes
+ ///
+ /// [`Self::prepare_shares_to_prepare_message`] is preferable since its name better matches the
+ /// specification.
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
+ #[deprecated(
+ since = "0.15.0",
+ note = "Use Vdaf::prepare_shares_to_prepare_message instead"
+ )]
+ fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ inputs: M,
+ ) -> Result<Self::PrepareMessage, VdafError> {
+ self.prepare_shares_to_prepare_message(agg_param, inputs)
+ }
+
+ /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`].
+ ///
+ /// Implements `Vdaf.prep_shares_to_prep` from [VDAF].
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
+ fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ inputs: M,
+ ) -> Result<Self::PrepareMessage, VdafError>;
+
+ /// Compute the next state transition from the current state and the previous round of input
+ /// messages. If this returns [`PrepareTransition::Continue`], then the returned
+ /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from
+ /// this round and passed into another call to this method. This continues until this method
+ /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
+ /// aggregated. If the method returns an error, the aggregator should consider its input share
+ /// invalid and not attempt to process it any further.
+ ///
+ /// Implements `Vdaf.prep_next` from [VDAF].
+ ///
+ /// # Notes
+ ///
+ /// [`Self::prepare_next`] is preferable since its name better matches the specification.
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
+ #[deprecated(since = "0.15.0", note = "Use Vdaf::prepare_next")]
+ fn prepare_step(
+ &self,
+ state: Self::PrepareState,
+ input: Self::PrepareMessage,
+ ) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError> {
+ self.prepare_next(state, input)
+ }
+
+ /// Compute the next state transition from the current state and the previous round of input
+ /// messages. If this returns [`PrepareTransition::Continue`], then the returned
+ /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from
+ /// this round and passed into another call to this method. This continues until this method
+ /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
+ /// aggregated. If the method returns an error, the aggregator should consider its input share
+ /// invalid and not attempt to process it any further.
+ ///
+ /// Implements `Vdaf.prep_next` from [VDAF].
+ ///
+ /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
+ fn prepare_next(
+ &self,
+ state: Self::PrepareState,
+ input: Self::PrepareMessage,
+ ) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError>;
+
+ /// Aggregates a sequence of output shares into an aggregate share.
+ fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ output_shares: M,
+ ) -> Result<Self::AggregateShare, VdafError>;
+}
+
+/// Aggregator that implements differential privacy with Aggregator-side noise addition.
+#[cfg(feature = "experimental")]
+pub trait AggregatorWithNoise<
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+ DPStrategy: DifferentialPrivacyStrategy,
+>: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>
+{
+ /// Adds noise to an aggregate share such that the aggregate result is differentially private
+ /// as long as one Aggregator is honest.
+ fn add_noise_to_agg_share(
+ &self,
+ dp_strategy: &DPStrategy,
+ agg_param: &Self::AggregationParam,
+ agg_share: &mut Self::AggregateShare,
+ num_measurements: usize,
+ ) -> Result<(), VdafError>;
+}
+
+/// The Collector's role in the execution of a VDAF.
+pub trait Collector: Vdaf {
+ /// Combines aggregate shares into the aggregate result.
+ fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
+ &self,
+ agg_param: &Self::AggregationParam,
+ agg_shares: M,
+ num_measurements: usize,
+ ) -> Result<Self::AggregateResult, VdafError>;
+}
+
+/// A state transition of an Aggregator during the Prepare process.
+#[derive(Clone, Debug)]
+pub enum PrepareTransition<
+ V: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
+ const VERIFY_KEY_SIZE: usize,
+ const NONCE_SIZE: usize,
+> {
+ /// Continue processing.
+ Continue(V::PrepareState, V::PrepareShare),
+
+ /// Finish processing and return the output share.
+ Finish(V::OutputShare),
+}
+
+/// An aggregate share resulting from aggregating output shares together that
+/// can merged with aggregate shares of the same type.
+pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
+ /// Type of output shares that can be accumulated into an aggregate share.
+ type OutputShare;
+
+ /// Update an aggregate share by merging it with another (`agg_share`).
+ fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>;
+
+ /// Update an aggregate share by adding `output_share`.
+ fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>;
+}
+
+/// An output share comprised of a vector of field elements.
+#[derive(Clone)]
+pub struct OutputShare<F>(Vec<F>);
+
+impl<F: ConstantTimeEq> PartialEq for OutputShare<F> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq> Eq for OutputShare<F> {}
+
+impl<F: ConstantTimeEq> ConstantTimeEq for OutputShare<F> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl<F> AsRef<[F]> for OutputShare<F> {
+ fn as_ref(&self) -> &[F] {
+ &self.0
+ }
+}
+
+impl<F> From<Vec<F>> for OutputShare<F> {
+ fn from(other: Vec<F>) -> Self {
+ Self(other)
+ }
+}
+
+impl<F: FieldElement> Encode for OutputShare<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ encode_fieldvec(&self.0, bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(F::ENCODED_SIZE * self.0.len())
+ }
+}
+
+impl<F> Debug for OutputShare<F> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_tuple("OutputShare").finish()
+ }
+}
+
+/// An aggregate share comprised of a vector of field elements.
+///
+/// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field
+/// elements, and output shares need no special transformation to be merged into an aggregate share.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+
+pub struct AggregateShare<F>(Vec<F>);
+
+impl<F: ConstantTimeEq> PartialEq for AggregateShare<F> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq> Eq for AggregateShare<F> {}
+
+impl<F: ConstantTimeEq> ConstantTimeEq for AggregateShare<F> {
+ fn ct_eq(&self, other: &Self) -> subtle::Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl<F: FieldElement> AsRef<[F]> for AggregateShare<F> {
+ fn as_ref(&self) -> &[F] {
+ &self.0
+ }
+}
+
+impl<F> From<OutputShare<F>> for AggregateShare<F> {
+ fn from(other: OutputShare<F>) -> Self {
+ Self(other.0)
+ }
+}
+
+impl<F: FieldElement> Aggregatable for AggregateShare<F> {
+ type OutputShare = OutputShare<F>;
+
+ fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
+ self.sum(agg_share.as_ref())
+ }
+
+ fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> {
+ // For Poplar1, Prio2, and Prio3, no conversion is needed between output shares and
+ // aggregate shares.
+ self.sum(output_share.as_ref())
+ }
+}
+
+impl<F: FieldElement> AggregateShare<F> {
+ fn sum(&mut self, other: &[F]) -> Result<(), VdafError> {
+ merge_vector(&mut self.0, other).map_err(Into::into)
+ }
+}
+
+impl<F: FieldElement> Encode for AggregateShare<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ encode_fieldvec(&self.0, bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(F::ENCODED_SIZE * self.0.len())
+ }
+}
+
+#[cfg(test)]
+pub(crate) fn run_vdaf<V, M, const SEED_SIZE: usize>(
+ vdaf: &V,
+ agg_param: &V::AggregationParam,
+ measurements: M,
+) -> Result<V::AggregateResult, VdafError>
+where
+ V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
+ M: IntoIterator<Item = V::Measurement>,
+{
+ use rand::prelude::*;
+ let mut rng = thread_rng();
+ let mut verify_key = [0; SEED_SIZE];
+ rng.fill(&mut verify_key[..]);
+
+ let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
+ let mut num_measurements: usize = 0;
+ for measurement in measurements.into_iter() {
+ num_measurements += 1;
+ let nonce = rng.gen();
+ let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?;
+ let out_shares = run_vdaf_prepare(
+ vdaf,
+ &verify_key,
+ agg_param,
+ &nonce,
+ public_share,
+ input_shares,
+ )?;
+ for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
+ // Check serialization of output shares
+ let encoded_out_share = out_share.get_encoded();
+ let round_trip_out_share =
+ V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share)
+ .unwrap();
+ assert_eq!(round_trip_out_share.get_encoded(), encoded_out_share);
+
+ let this_agg_share = V::AggregateShare::from(out_share);
+ if let Some(ref mut inner) = agg_share {
+ inner.merge(&this_agg_share)?;
+ } else {
+ *agg_share = Some(this_agg_share);
+ }
+ }
+ }
+
+ for agg_share in agg_shares.iter() {
+ // Check serialization of aggregate shares
+ let encoded_agg_share = agg_share.as_ref().unwrap().get_encoded();
+ let round_trip_agg_share =
+ V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share)
+ .unwrap();
+ assert_eq!(round_trip_agg_share.get_encoded(), encoded_agg_share);
+ }
+
+ let res = vdaf.unshard(
+ agg_param,
+ agg_shares.into_iter().map(|option| option.unwrap()),
+ num_measurements,
+ )?;
+ Ok(res)
+}
+
+#[cfg(test)]
+pub(crate) fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>(
+ vdaf: &V,
+ verify_key: &[u8; SEED_SIZE],
+ agg_param: &V::AggregationParam,
+ nonce: &[u8; 16],
+ public_share: V::PublicShare,
+ input_shares: M,
+) -> Result<Vec<V::OutputShare>, VdafError>
+where
+ V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
+ M: IntoIterator<Item = V::InputShare>,
+{
+ let input_shares = input_shares
+ .into_iter()
+ .map(|input_share| input_share.get_encoded());
+
+ let mut states = Vec::new();
+ let mut outbound = Vec::new();
+ for (agg_id, input_share) in input_shares.enumerate() {
+ let (state, msg) = vdaf.prepare_init(
+ verify_key,
+ agg_id,
+ agg_param,
+ nonce,
+ &public_share,
+ &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
+ .expect("failed to decode input share"),
+ )?;
+ states.push(state);
+ outbound.push(msg.get_encoded());
+ }
+
+ let mut inbound = vdaf
+ .prepare_shares_to_prepare_message(
+ agg_param,
+ outbound.iter().map(|encoded| {
+ V::PrepareShare::get_decoded_with_param(&states[0], encoded)
+ .expect("failed to decode prep share")
+ }),
+ )?
+ .get_encoded();
+
+ let mut out_shares = Vec::new();
+ loop {
+ let mut outbound = Vec::new();
+ for state in states.iter_mut() {
+ match vdaf.prepare_next(
+ state.clone(),
+ V::PrepareMessage::get_decoded_with_param(state, &inbound)
+ .expect("failed to decode prep message"),
+ )? {
+ PrepareTransition::Continue(new_state, msg) => {
+ outbound.push(msg.get_encoded());
+ *state = new_state
+ }
+ PrepareTransition::Finish(out_share) => {
+ out_shares.push(out_share);
+ }
+ }
+ }
+
+ if outbound.len() == vdaf.num_aggregators() {
+ // Another round is required before output shares are computed.
+ inbound = vdaf
+ .prepare_shares_to_prepare_message(
+ agg_param,
+ outbound.iter().map(|encoded| {
+ V::PrepareShare::get_decoded_with_param(&states[0], encoded)
+ .expect("failed to decode prep share")
+ }),
+ )?
+ .get_encoded();
+ } else if outbound.is_empty() {
+ // Each Aggregator recovered an output share.
+ break;
+ } else {
+ panic!("Aggregators did not finish the prepare phase at the same time");
+ }
+ }
+
+ Ok(out_shares)
+}
+
+#[cfg(test)]
+fn fieldvec_roundtrip_test<F, V, T>(vdaf: &V, agg_param: &V::AggregationParam, length: usize)
+where
+ F: FieldElement,
+ V: Vdaf,
+ T: Encode,
+ for<'a> T: ParameterizedDecode<(&'a V, &'a V::AggregationParam)>,
+{
+ // Generate an arbitrary vector of field elements.
+ let g = F::one() + F::one();
+ let vec: Vec<F> = itertools::iterate(F::one(), |&v| g * v)
+ .take(length)
+ .collect();
+
+ // Serialize the field element vector into a vector of bytes.
+ let mut bytes = Vec::with_capacity(vec.len() * F::ENCODED_SIZE);
+ encode_fieldvec(&vec, &mut bytes);
+
+ // Deserialize the type of interest from those bytes.
+ let value = T::get_decoded_with_param(&(vdaf, agg_param), &bytes).unwrap();
+
+ // Round-trip the value back to a vector of bytes.
+ let encoded = value.get_encoded();
+
+ assert_eq!(encoded, bytes);
+}
+
+#[cfg(test)]
+fn equality_comparison_test<T>(values: &[T])
+where
+ T: Debug + PartialEq,
+{
+ use std::ptr;
+
+ // This function expects that every value passed in `values` is distinct, i.e. should not
+ // compare as equal to any other element. We test both (i, j) and (j, i) to gain confidence that
+ // equality implementations are symmetric.
+ for (i, i_val) in values.iter().enumerate() {
+ for (j, j_val) in values.iter().enumerate() {
+ if i == j {
+ assert!(ptr::eq(i_val, j_val)); // sanity
+ assert_eq!(
+ i_val, j_val,
+ "Expected element at index {i} to be equal to itself, but it was not"
+ );
+ } else {
+ assert_ne!(
+ i_val, j_val,
+ "Expected elements at indices {i} & {j} to not be equal, but they were"
+ )
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::vdaf::{equality_comparison_test, xof::Seed, AggregateShare, OutputShare, Share};
+
+ #[test]
+ fn share_equality_test() {
+ equality_comparison_test(&[
+ Share::Leader(Vec::from([1, 2, 3])),
+ Share::Leader(Vec::from([3, 2, 1])),
+ Share::Helper(Seed([1, 2, 3])),
+ Share::Helper(Seed([3, 2, 1])),
+ ])
+ }
+
+ #[test]
+ fn output_share_equality_test() {
+ equality_comparison_test(&[
+ OutputShare(Vec::from([1, 2, 3])),
+ OutputShare(Vec::from([3, 2, 1])),
+ ])
+ }
+
+ #[test]
+ fn aggregate_share_equality_test() {
+ equality_comparison_test(&[
+ AggregateShare(Vec::from([1, 2, 3])),
+ AggregateShare(Vec::from([3, 2, 1])),
+ ])
+ }
+}
+
+#[cfg(feature = "test-util")]
+pub mod dummy;
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
+)]
+pub mod poplar1;
+#[cfg(feature = "prio2")]
+#[cfg_attr(docsrs, doc(cfg(feature = "prio2")))]
+pub mod prio2;
+pub mod prio3;
+#[cfg(test)]
+mod prio3_test;
+pub mod xof;
diff --git a/third_party/rust/prio/src/vdaf/dummy.rs b/third_party/rust/prio/src/vdaf/dummy.rs
new file mode 100644
index 0000000000..507e7916bb
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/dummy.rs
@@ -0,0 +1,316 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of a dummy VDAF which conforms to the specification in [draft-irtf-cfrg-vdaf-06]
+//! but does nothing. Useful for testing.
+//!
+//! [draft-irtf-cfrg-vdaf-06]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/06/
+
+use crate::{
+ codec::{CodecError, Decode, Encode},
+ vdaf::{self, Aggregatable, PrepareTransition, VdafError},
+};
+use rand::random;
+use std::{fmt::Debug, io::Cursor, sync::Arc};
+
+type ArcPrepInitFn =
+ Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>;
+type ArcPrepStepFn = Arc<
+ dyn Fn(&PrepareState) -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError>
+ + 'static
+ + Send
+ + Sync,
+>;
+
+/// Dummy VDAF that does nothing.
+#[derive(Clone)]
+pub struct Vdaf {
+ prep_init_fn: ArcPrepInitFn,
+ prep_step_fn: ArcPrepStepFn,
+}
+
+impl Debug for Vdaf {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Vdaf")
+ .field("prep_init_fn", &"[redacted]")
+ .field("prep_step_fn", &"[redacted]")
+ .finish()
+ }
+}
+
+impl Vdaf {
+ /// The length of the verify key parameter for fake VDAF instantiations.
+ pub const VERIFY_KEY_LEN: usize = 0;
+
+ /// Construct a new instance of the dummy VDAF.
+ pub fn new(rounds: u32) -> Self {
+ Self {
+ prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }),
+ prep_step_fn: Arc::new(
+ move |state| -> Result<PrepareTransition<Self, 0, 16>, VdafError> {
+ let new_round = state.current_round + 1;
+ if new_round == rounds {
+ Ok(PrepareTransition::Finish(OutputShare(state.input_share)))
+ } else {
+ Ok(PrepareTransition::Continue(
+ PrepareState {
+ current_round: new_round,
+ ..*state
+ },
+ (),
+ ))
+ }
+ },
+ ),
+ }
+ }
+
+ /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_init`].
+ pub fn with_prep_init_fn<F: Fn(&AggregationParam) -> Result<(), VdafError>>(
+ mut self,
+ f: F,
+ ) -> Self
+ where
+ F: 'static + Send + Sync,
+ {
+ self.prep_init_fn = Arc::new(f);
+ self
+ }
+
+ /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_step`].
+ pub fn with_prep_step_fn<
+ F: Fn(&PrepareState) -> Result<PrepareTransition<Self, 0, 16>, VdafError>,
+ >(
+ mut self,
+ f: F,
+ ) -> Self
+ where
+ F: 'static + Send + Sync,
+ {
+ self.prep_step_fn = Arc::new(f);
+ self
+ }
+}
+
+impl Default for Vdaf {
+ fn default() -> Self {
+ Self::new(1)
+ }
+}
+
+impl vdaf::Vdaf for Vdaf {
+ const ID: u32 = 0xFFFF0000;
+
+ type Measurement = u8;
+ type AggregateResult = u8;
+ type AggregationParam = AggregationParam;
+ type PublicShare = ();
+ type InputShare = InputShare;
+ type OutputShare = OutputShare;
+ type AggregateShare = AggregateShare;
+
+ fn num_aggregators(&self) -> usize {
+ 2
+ }
+}
+
+impl vdaf::Aggregator<0, 16> for Vdaf {
+ type PrepareState = PrepareState;
+ type PrepareShare = ();
+ type PrepareMessage = ();
+
+ fn prepare_init(
+ &self,
+ _verify_key: &[u8; 0],
+ _: usize,
+ aggregation_param: &Self::AggregationParam,
+ _nonce: &[u8; 16],
+ _: &Self::PublicShare,
+ input_share: &Self::InputShare,
+ ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError> {
+ (self.prep_init_fn)(aggregation_param)?;
+ Ok((
+ PrepareState {
+ input_share: input_share.0,
+ current_round: 0,
+ },
+ (),
+ ))
+ }
+
+ fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>(
+ &self,
+ _: &Self::AggregationParam,
+ _: M,
+ ) -> Result<Self::PrepareMessage, VdafError> {
+ Ok(())
+ }
+
+ fn prepare_next(
+ &self,
+ state: Self::PrepareState,
+ _: Self::PrepareMessage,
+ ) -> Result<PrepareTransition<Self, 0, 16>, VdafError> {
+ (self.prep_step_fn)(&state)
+ }
+
+ fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
+ &self,
+ _: &Self::AggregationParam,
+ output_shares: M,
+ ) -> Result<Self::AggregateShare, VdafError> {
+ let mut aggregate_share = AggregateShare(0);
+ for output_share in output_shares {
+ aggregate_share.accumulate(&output_share)?;
+ }
+ Ok(aggregate_share)
+ }
+}
+
+impl vdaf::Client<16> for Vdaf {
+ fn shard(
+ &self,
+ measurement: &Self::Measurement,
+ _nonce: &[u8; 16],
+ ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError> {
+ let first_input_share = random();
+ let (second_input_share, _) = measurement.overflowing_sub(first_input_share);
+ Ok((
+ (),
+ Vec::from([
+ InputShare(first_input_share),
+ InputShare(second_input_share),
+ ]),
+ ))
+ }
+}
+
+/// A dummy input share.
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
+pub struct InputShare(pub u8);
+
+impl Encode for InputShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
+
+impl Decode for InputShare {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(Self(u8::decode(bytes)?))
+ }
+}
+
+/// Dummy aggregation parameter.
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct AggregationParam(pub u8);
+
+impl Encode for AggregationParam {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
+
+impl Decode for AggregationParam {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(Self(u8::decode(bytes)?))
+ }
+}
+
+/// Dummy output share.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct OutputShare(pub u8);
+
+impl Decode for OutputShare {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(Self(u8::decode(bytes)?))
+ }
+}
+
+impl Encode for OutputShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
+
+/// Dummy prepare state.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+pub struct PrepareState {
+ input_share: u8,
+ current_round: u32,
+}
+
+impl Encode for PrepareState {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.input_share.encode(bytes);
+ self.current_round.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(self.input_share.encoded_len()? + self.current_round.encoded_len()?)
+ }
+}
+
+impl Decode for PrepareState {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let input_share = u8::decode(bytes)?;
+ let current_round = u32::decode(bytes)?;
+
+ Ok(Self {
+ input_share,
+ current_round,
+ })
+ }
+}
+
+/// Dummy aggregate share.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct AggregateShare(pub u64);
+
+impl Aggregatable for AggregateShare {
+ type OutputShare = OutputShare;
+
+ fn merge(&mut self, other: &Self) -> Result<(), VdafError> {
+ self.0 += other.0;
+ Ok(())
+ }
+
+ fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> {
+ self.0 += u64::from(out_share.0);
+ Ok(())
+ }
+}
+
+impl From<OutputShare> for AggregateShare {
+ fn from(out_share: OutputShare) -> Self {
+ Self(u64::from(out_share.0))
+ }
+}
+
+impl Decode for AggregateShare {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let val = u64::decode(bytes)?;
+ Ok(Self(val))
+ }
+}
+
+impl Encode for AggregateShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/poplar1.rs b/third_party/rust/prio/src/vdaf/poplar1.rs
new file mode 100644
index 0000000000..e8591f2049
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/poplar1.rs
@@ -0,0 +1,2465 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of Poplar1 as specified in [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+use crate::{
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{decode_fieldvec, merge_vector, Field255, Field64, FieldElement},
+ idpf::{Idpf, IdpfInput, IdpfOutputShare, IdpfPublicShare, IdpfValue, RingBufferCache},
+ prng::Prng,
+ vdaf::{
+ xof::{Seed, Xof, XofShake128},
+ Aggregatable, Aggregator, Client, Collector, PrepareTransition, Vdaf, VdafError,
+ },
+};
+use bitvec::{prelude::Lsb0, vec::BitVec};
+use rand_core::RngCore;
+use std::{
+ convert::TryFrom,
+ fmt::Debug,
+ io::{Cursor, Read},
+ iter,
+ marker::PhantomData,
+ num::TryFromIntError,
+ ops::{Add, AddAssign, Sub},
+};
+use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
+
+const DST_SHARD_RANDOMNESS: u16 = 1;
+const DST_CORR_INNER: u16 = 2;
+const DST_CORR_LEAF: u16 = 3;
+const DST_VERIFY_RANDOMNESS: u16 = 4;
+
+impl<P, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> {
+ /// Create an instance of [`Poplar1`]. The caller provides the bit length of each
+ /// measurement (`BITS` as defined in the [[draft-irtf-cfrg-vdaf-07]]).
+ ///
+ /// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+ pub fn new(bits: usize) -> Self {
+ Self {
+ bits,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl Poplar1<XofShake128, 16> {
+ /// Create an instance of [`Poplar1`] using [`XofShake128`]. The caller provides the bit length of
+ /// each measurement (`BITS` as defined in the [[draft-irtf-cfrg-vdaf-07]]).
+ ///
+ /// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+ pub fn new_shake128(bits: usize) -> Self {
+ Poplar1::new(bits)
+ }
+}
+
+/// The Poplar1 VDAF.
+#[derive(Debug)]
+pub struct Poplar1<P, const SEED_SIZE: usize> {
+ bits: usize,
+ phantom: PhantomData<P>,
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> {
+ /// Construct a `Prng` with the given seed and info-string suffix.
+ fn init_prng<I, B, F>(
+ seed: &[u8; SEED_SIZE],
+ usage: u16,
+ binder_chunks: I,
+ ) -> Prng<F, P::SeedStream>
+ where
+ I: IntoIterator<Item = B>,
+ B: AsRef<[u8]>,
+ P: Xof<SEED_SIZE>,
+ F: FieldElement,
+ {
+ let mut xof = P::init(seed, &Self::domain_separation_tag(usage));
+ for binder_chunk in binder_chunks.into_iter() {
+ xof.update(binder_chunk.as_ref());
+ }
+ Prng::from_seed_stream(xof.into_seed_stream())
+ }
+}
+
+impl<P, const SEED_SIZE: usize> Clone for Poplar1<P, SEED_SIZE> {
+ fn clone(&self) -> Self {
+ Self {
+ bits: self.bits,
+ phantom: PhantomData,
+ }
+ }
+}
+
+/// Poplar1 public share.
+///
+/// This is comprised of the correction words generated for the IDPF.
+pub type Poplar1PublicShare =
+ IdpfPublicShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>;
+
+impl<P, const SEED_SIZE: usize> ParameterizedDecode<Poplar1<P, SEED_SIZE>> for Poplar1PublicShare {
+ fn decode_with_param(
+ poplar1: &Poplar1<P, SEED_SIZE>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Self::decode_with_param(&poplar1.bits, bytes)
+ }
+}
+
+/// Poplar1 input share.
+///
+/// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch
+/// during preparation.
+#[derive(Debug, Clone)]
+pub struct Poplar1InputShare<const SEED_SIZE: usize> {
+ /// IDPF key share.
+ idpf_key: Seed<16>,
+
+ /// Seed used to generate the Aggregator's share of the correlated randomness used in the first
+ /// part of the sketch.
+ corr_seed: Seed<SEED_SIZE>,
+
+ /// Aggregator's share of the correlated randomness used in the second part of the sketch. Used
+ /// for inner nodes of the IDPF tree.
+ corr_inner: Vec<[Field64; 2]>,
+
+ /// Aggregator's share of the correlated randomness used in the second part of the sketch. Used
+ /// for leaf nodes of the IDPF tree.
+ corr_leaf: [Field255; 2],
+}
+
+impl<const SEED_SIZE: usize> PartialEq for Poplar1InputShare<SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<const SEED_SIZE: usize> Eq for Poplar1InputShare<SEED_SIZE> {}
+
+impl<const SEED_SIZE: usize> ConstantTimeEq for Poplar1InputShare<SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We short-circuit on the length of corr_inner being different. Only the content is
+ // protected.
+ if self.corr_inner.len() != other.corr_inner.len() {
+ return Choice::from(0);
+ }
+
+ let mut res = self.idpf_key.ct_eq(&other.idpf_key)
+ & self.corr_seed.ct_eq(&other.corr_seed)
+ & self.corr_leaf.ct_eq(&other.corr_leaf);
+ for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) {
+ res &= x.ct_eq(y);
+ }
+ res
+ }
+}
+
+impl<const SEED_SIZE: usize> Encode for Poplar1InputShare<SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.idpf_key.encode(bytes);
+ self.corr_seed.encode(bytes);
+ for corr in self.corr_inner.iter() {
+ corr[0].encode(bytes);
+ corr[1].encode(bytes);
+ }
+ self.corr_leaf[0].encode(bytes);
+ self.corr_leaf[1].encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ let mut len = 0;
+ len += SEED_SIZE; // idpf_key
+ len += SEED_SIZE; // corr_seed
+ len += self.corr_inner.len() * 2 * Field64::ENCODED_SIZE; // corr_inner
+ len += 2 * Field255::ENCODED_SIZE; // corr_leaf
+ Some(len)
+ }
+}
+
+impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)>
+ for Poplar1InputShare<SEED_SIZE>
+{
+ fn decode_with_param(
+ (poplar1, _agg_id): &(&'a Poplar1<P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let idpf_key = Seed::decode(bytes)?;
+ let corr_seed = Seed::decode(bytes)?;
+ let mut corr_inner = Vec::with_capacity(poplar1.bits - 1);
+ for _ in 0..poplar1.bits - 1 {
+ corr_inner.push([Field64::decode(bytes)?, Field64::decode(bytes)?]);
+ }
+ let corr_leaf = [Field255::decode(bytes)?, Field255::decode(bytes)?];
+ Ok(Self {
+ idpf_key,
+ corr_seed,
+ corr_inner,
+ corr_leaf,
+ })
+ }
+}
+
+/// Poplar1 preparation state.
+#[derive(Clone, Debug)]
+pub struct Poplar1PrepareState(PrepareStateVariant);
+
+impl PartialEq for Poplar1PrepareState {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl Eq for Poplar1PrepareState {}
+
+impl ConstantTimeEq for Poplar1PrepareState {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl Encode for Poplar1PrepareState {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes)
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
+
+impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)>
+ for Poplar1PrepareState
+{
+ fn decode_with_param(
+ decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Ok(Self(PrepareStateVariant::decode_with_param(
+ decoding_parameter,
+ bytes,
+ )?))
+ }
+}
+
+#[derive(Clone, Debug)]
+enum PrepareStateVariant {
+ Inner(PrepareState<Field64>),
+ Leaf(PrepareState<Field255>),
+}
+
+impl PartialEq for PrepareStateVariant {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl Eq for PrepareStateVariant {}
+
+impl ConstantTimeEq for PrepareStateVariant {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the type (Inner vs Leaf).
+ match (self, other) {
+ (Self::Inner(self_val), Self::Inner(other_val)) => self_val.ct_eq(other_val),
+ (Self::Leaf(self_val), Self::Leaf(other_val)) => self_val.ct_eq(other_val),
+ _ => Choice::from(0),
+ }
+ }
+}
+
+impl Encode for PrepareStateVariant {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self {
+ PrepareStateVariant::Inner(prep_state) => {
+ 0u8.encode(bytes);
+ prep_state.encode(bytes);
+ }
+ PrepareStateVariant::Leaf(prep_state) => {
+ 1u8.encode(bytes);
+ prep_state.encode(bytes);
+ }
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(
+ 1 + match self {
+ PrepareStateVariant::Inner(prep_state) => prep_state.encoded_len()?,
+ PrepareStateVariant::Leaf(prep_state) => prep_state.encoded_len()?,
+ },
+ )
+ }
+}
+
+impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)>
+ for PrepareStateVariant
+{
+ fn decode_with_param(
+ decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match u8::decode(bytes)? {
+ 0 => {
+ let prep_state = PrepareState::decode_with_param(decoding_parameter, bytes)?;
+ Ok(Self::Inner(prep_state))
+ }
+ 1 => {
+ let prep_state = PrepareState::decode_with_param(decoding_parameter, bytes)?;
+ Ok(Self::Leaf(prep_state))
+ }
+ _ => Err(CodecError::UnexpectedValue),
+ }
+ }
+}
+
+#[derive(Clone)]
+struct PrepareState<F> {
+ sketch: SketchState<F>,
+ output_share: Vec<F>,
+}
+
+impl<F: ConstantTimeEq> PartialEq for PrepareState<F> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq> Eq for PrepareState<F> {}
+
+impl<F: ConstantTimeEq> ConstantTimeEq for PrepareState<F> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.sketch.ct_eq(&other.sketch) & self.output_share.ct_eq(&other.output_share)
+ }
+}
+
+impl<F> Debug for PrepareState<F> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PrepareState")
+ .field("sketch", &"[redacted]")
+ .field("output_share", &"[redacted]")
+ .finish()
+ }
+}
+
+impl<F: FieldElement> Encode for PrepareState<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.sketch.encode(bytes);
+ // `expect` safety: output_share's length is the same as the number of prefixes; the number
+ // of prefixes is capped at 2^32-1.
+ u32::try_from(self.output_share.len())
+ .expect("Couldn't convert output_share length to u32")
+ .encode(bytes);
+ for elem in &self.output_share {
+ elem.encode(bytes);
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(self.sketch.encoded_len()? + 4 + self.output_share.len() * F::ENCODED_SIZE)
+ }
+}
+
+impl<'a, P, F: FieldElement, const SEED_SIZE: usize>
+ ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> for PrepareState<F>
+{
+ fn decode_with_param(
+ decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let sketch = SketchState::<F>::decode_with_param(decoding_parameter, bytes)?;
+ let output_share_len = u32::decode(bytes)?
+ .try_into()
+ .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
+ let output_share = iter::repeat_with(|| F::decode(bytes))
+ .take(output_share_len)
+ .collect::<Result<_, _>>()?;
+ Ok(Self {
+ sketch,
+ output_share,
+ })
+ }
+}
+
+#[derive(Clone, Debug)]
+enum SketchState<F> {
+ #[allow(non_snake_case)]
+ RoundOne {
+ A_share: F,
+ B_share: F,
+ is_leader: bool,
+ },
+ RoundTwo,
+}
+
+impl<F: ConstantTimeEq> PartialEq for SketchState<F> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq> Eq for SketchState<F> {}
+
+impl<F: ConstantTimeEq> ConstantTimeEq for SketchState<F> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the round (RoundOne vs RoundTwo), as well as is_leader for
+ // RoundOne comparisons.
+ match (self, other) {
+ (
+ SketchState::RoundOne {
+ A_share: self_a_share,
+ B_share: self_b_share,
+ is_leader: self_is_leader,
+ },
+ SketchState::RoundOne {
+ A_share: other_a_share,
+ B_share: other_b_share,
+ is_leader: other_is_leader,
+ },
+ ) => {
+ if self_is_leader != other_is_leader {
+ return Choice::from(0);
+ }
+
+ self_a_share.ct_eq(other_a_share) & self_b_share.ct_eq(other_b_share)
+ }
+
+ (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1),
+ _ => Choice::from(0),
+ }
+ }
+}
+
+impl<F: FieldElement> Encode for SketchState<F> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self {
+ SketchState::RoundOne {
+ A_share, B_share, ..
+ } => {
+ 0u8.encode(bytes);
+ A_share.encode(bytes);
+ B_share.encode(bytes);
+ }
+ SketchState::RoundTwo => 1u8.encode(bytes),
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(
+ 1 + match self {
+ SketchState::RoundOne { .. } => 2 * F::ENCODED_SIZE,
+ SketchState::RoundTwo => 0,
+ },
+ )
+ }
+}
+
+impl<'a, P, F: FieldElement, const SEED_SIZE: usize>
+ ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> for SketchState<F>
+{
+ #[allow(non_snake_case)]
+ fn decode_with_param(
+ (_, agg_id): &(&'a Poplar1<P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match u8::decode(bytes)? {
+ 0 => {
+ let A_share = F::decode(bytes)?;
+ let B_share = F::decode(bytes)?;
+ let is_leader = agg_id == &0;
+ Ok(Self::RoundOne {
+ A_share,
+ B_share,
+ is_leader,
+ })
+ }
+ 1 => Ok(Self::RoundTwo),
+ _ => Err(CodecError::UnexpectedValue),
+ }
+ }
+}
+
+impl<F: FieldElement> SketchState<F> {
+ fn decode_sketch_share(&self, bytes: &mut Cursor<&[u8]>) -> Result<Vec<F>, CodecError> {
+ match self {
+ // The sketch share is three field elements.
+ Self::RoundOne { .. } => Ok(vec![
+ F::decode(bytes)?,
+ F::decode(bytes)?,
+ F::decode(bytes)?,
+ ]),
+ // The sketch verifier share is one field element.
+ Self::RoundTwo => Ok(vec![F::decode(bytes)?]),
+ }
+ }
+
+ fn decode_sketch(&self, bytes: &mut Cursor<&[u8]>) -> Result<Option<[F; 3]>, CodecError> {
+ match self {
+ // The sketch is three field elements.
+ Self::RoundOne { .. } => Ok(Some([
+ F::decode(bytes)?,
+ F::decode(bytes)?,
+ F::decode(bytes)?,
+ ])),
+ // The sketch verifier should be zero if the sketch if valid. Instead of transmitting
+ // this zero over the wire, we just expect an empty message.
+ Self::RoundTwo => Ok(None),
+ }
+ }
+}
+
+/// Poplar1 preparation message.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Poplar1PrepareMessage(PrepareMessageVariant);
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum PrepareMessageVariant {
+ SketchInner([Field64; 3]),
+ SketchLeaf([Field255; 3]),
+ Done,
+}
+
+impl Encode for Poplar1PrepareMessage {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self.0 {
+ PrepareMessageVariant::SketchInner(vec) => {
+ vec[0].encode(bytes);
+ vec[1].encode(bytes);
+ vec[2].encode(bytes);
+ }
+ PrepareMessageVariant::SketchLeaf(vec) => {
+ vec[0].encode(bytes);
+ vec[1].encode(bytes);
+ vec[2].encode(bytes);
+ }
+ PrepareMessageVariant::Done => (),
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ match self.0 {
+ PrepareMessageVariant::SketchInner(..) => Some(3 * Field64::ENCODED_SIZE),
+ PrepareMessageVariant::SketchLeaf(..) => Some(3 * Field255::ENCODED_SIZE),
+ PrepareMessageVariant::Done => Some(0),
+ }
+ }
+}
+
+impl ParameterizedDecode<Poplar1PrepareState> for Poplar1PrepareMessage {
+ fn decode_with_param(
+ state: &Poplar1PrepareState,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match state.0 {
+ PrepareStateVariant::Inner(ref state_variant) => Ok(Self(
+ state_variant
+ .sketch
+ .decode_sketch(bytes)?
+ .map_or(PrepareMessageVariant::Done, |sketch| {
+ PrepareMessageVariant::SketchInner(sketch)
+ }),
+ )),
+ PrepareStateVariant::Leaf(ref state_variant) => Ok(Self(
+ state_variant
+ .sketch
+ .decode_sketch(bytes)?
+ .map_or(PrepareMessageVariant::Done, |sketch| {
+ PrepareMessageVariant::SketchLeaf(sketch)
+ }),
+ )),
+ }
+ }
+}
+
+/// A vector of field elements transmitted while evaluating Poplar1.
+#[derive(Clone, Debug)]
+pub enum Poplar1FieldVec {
+ /// Field type for inner nodes of the IDPF tree.
+ Inner(Vec<Field64>),
+
+ /// Field type for leaf nodes of the IDPF tree.
+ Leaf(Vec<Field255>),
+}
+
+impl Poplar1FieldVec {
+ fn zero(is_leaf: bool, len: usize) -> Self {
+ if is_leaf {
+ Self::Leaf(vec![<Field255 as FieldElement>::zero(); len])
+ } else {
+ Self::Inner(vec![<Field64 as FieldElement>::zero(); len])
+ }
+ }
+}
+
+impl PartialEq for Poplar1FieldVec {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl Eq for Poplar1FieldVec {}
+
+impl ConstantTimeEq for Poplar1FieldVec {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the type (Inner vs Leaf).
+ match (self, other) {
+ (Poplar1FieldVec::Inner(self_val), Poplar1FieldVec::Inner(other_val)) => {
+ self_val.ct_eq(other_val)
+ }
+ (Poplar1FieldVec::Leaf(self_val), Poplar1FieldVec::Leaf(other_val)) => {
+ self_val.ct_eq(other_val)
+ }
+ _ => Choice::from(0),
+ }
+ }
+}
+
+impl Encode for Poplar1FieldVec {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ match self {
+ Self::Inner(ref data) => {
+ for elem in data {
+ elem.encode(bytes);
+ }
+ }
+ Self::Leaf(ref data) => {
+ for elem in data {
+ elem.encode(bytes);
+ }
+ }
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ match self {
+ Self::Inner(ref data) => Some(Field64::ENCODED_SIZE * data.len()),
+ Self::Leaf(ref data) => Some(Field255::ENCODED_SIZE * data.len()),
+ }
+ }
+}
+
+impl<'a, P: Xof<SEED_SIZE>, const SEED_SIZE: usize>
+ ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, &'a Poplar1AggregationParam)>
+ for Poplar1FieldVec
+{
+ fn decode_with_param(
+ (poplar1, agg_param): &(&'a Poplar1<P, SEED_SIZE>, &'a Poplar1AggregationParam),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ if agg_param.level() == poplar1.bits - 1 {
+ decode_fieldvec(agg_param.prefixes().len(), bytes).map(Poplar1FieldVec::Leaf)
+ } else {
+ decode_fieldvec(agg_param.prefixes().len(), bytes).map(Poplar1FieldVec::Inner)
+ }
+ }
+}
+
+impl ParameterizedDecode<Poplar1PrepareState> for Poplar1FieldVec {
+ fn decode_with_param(
+ state: &Poplar1PrepareState,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ match state.0 {
+ PrepareStateVariant::Inner(ref state_variant) => Ok(Poplar1FieldVec::Inner(
+ state_variant.sketch.decode_sketch_share(bytes)?,
+ )),
+ PrepareStateVariant::Leaf(ref state_variant) => Ok(Poplar1FieldVec::Leaf(
+ state_variant.sketch.decode_sketch_share(bytes)?,
+ )),
+ }
+ }
+}
+
+impl Aggregatable for Poplar1FieldVec {
+ type OutputShare = Self;
+
+ fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
+ match (self, agg_share) {
+ (Self::Inner(ref mut left), Self::Inner(right)) => Ok(merge_vector(left, right)?),
+ (Self::Leaf(ref mut left), Self::Leaf(right)) => Ok(merge_vector(left, right)?),
+ _ => Err(VdafError::Uncategorized(
+ "cannot merge leaf nodes wiith inner nodes".into(),
+ )),
+ }
+ }
+
+ fn accumulate(&mut self, output_share: &Self) -> Result<(), VdafError> {
+ match (self, output_share) {
+ (Self::Inner(ref mut left), Self::Inner(right)) => Ok(merge_vector(left, right)?),
+ (Self::Leaf(ref mut left), Self::Leaf(right)) => Ok(merge_vector(left, right)?),
+ _ => Err(VdafError::Uncategorized(
+ "cannot accumulate leaf nodes with inner nodes".into(),
+ )),
+ }
+ }
+}
+
+/// Poplar1 aggregation parameter.
+///
+/// This includes an indication of what level of the IDPF tree is being evaluated and the set of
+/// prefixes to evaluate at that level.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub struct Poplar1AggregationParam {
+ level: u16,
+ prefixes: Vec<IdpfInput>,
+}
+
+impl Poplar1AggregationParam {
+ /// Construct an aggregation parameter from a set of candidate prefixes.
+ ///
+ /// # Errors
+ ///
+ /// * The list of prefixes is empty.
+ /// * The prefixes have different lengths (they must all be the same).
+ /// * The prefixes have length 0, or length longer than 2^16 bits.
+ /// * There are more than 2^32 - 1 prefixes.
+ /// * The prefixes are not unique.
+ /// * The prefixes are not in lexicographic order.
+ pub fn try_from_prefixes(prefixes: Vec<IdpfInput>) -> Result<Self, VdafError> {
+ if prefixes.is_empty() {
+ return Err(VdafError::Uncategorized(
+ "at least one prefix is required".into(),
+ ));
+ }
+ if u32::try_from(prefixes.len()).is_err() {
+ return Err(VdafError::Uncategorized("too many prefixes".into()));
+ }
+
+ let len = prefixes[0].len();
+ let mut last_prefix = None;
+ for prefix in prefixes.iter() {
+ if prefix.len() != len {
+ return Err(VdafError::Uncategorized(
+ "all prefixes must have the same length".into(),
+ ));
+ }
+ if let Some(last_prefix) = last_prefix {
+ if prefix <= last_prefix {
+ if prefix == last_prefix {
+ return Err(VdafError::Uncategorized(
+ "prefixes must be nonrepeating".into(),
+ ));
+ } else {
+ return Err(VdafError::Uncategorized(
+ "prefixes must be in lexicographic order".into(),
+ ));
+ }
+ }
+ }
+ last_prefix = Some(prefix);
+ }
+
+ let level = len
+ .checked_sub(1)
+ .ok_or_else(|| VdafError::Uncategorized("prefixes are too short".into()))?;
+ let level = u16::try_from(level)
+ .map_err(|_| VdafError::Uncategorized("prefixes are too long".into()))?;
+
+ Ok(Self { level, prefixes })
+ }
+
+ /// Return the level of the IDPF tree.
+ pub fn level(&self) -> usize {
+ usize::from(self.level)
+ }
+
+ /// Return the prefixes.
+ pub fn prefixes(&self) -> &[IdpfInput] {
+ self.prefixes.as_ref()
+ }
+}
+
+impl Encode for Poplar1AggregationParam {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ // Okay to unwrap because `try_from_prefixes()` checks this conversion succeeds.
+ let prefix_count = u32::try_from(self.prefixes.len()).unwrap();
+ self.level.encode(bytes);
+ prefix_count.encode(bytes);
+
+ // The encoding of the prefixes is defined by treating the IDPF indices as integers,
+ // shifting and ORing them together, and encoding the resulting arbitrary precision integer
+ // in big endian byte order. Thus, the first prefix will appear in the last encoded byte,
+ // aligned to its least significant bit. The last prefix will appear in the first encoded
+ // byte, not necessarily aligned to a byte boundary. If the highest bits in the first byte
+ // are unused, they will be set to zero.
+
+ // When an IDPF index is treated as an integer, the first bit is the integer's most
+ // significant bit, and bits are subsequently processed in order of decreasing significance.
+ // Thus, setting aside the order of bytes, bits within each byte are ordered with the
+ // [`Msb0`](bitvec::prelude::Msb0) convention, not [`Lsb0`](bitvec::prelude::Msb0). Yet,
+ // the entire integer is aligned to the least significant bit of the last byte, so we
+ // could not use `Msb0` directly without padding adjustments. Instead, we use `Lsb0`
+ // throughout and reverse the bit order of each prefix.
+
+ let mut packed = self
+ .prefixes
+ .iter()
+ .flat_map(|input| input.iter().rev())
+ .collect::<BitVec<u8, Lsb0>>();
+ packed.set_uninitialized(false);
+ let mut packed = packed.into_vec();
+ packed.reverse();
+ bytes.append(&mut packed);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ let packed_bit_count = (usize::from(self.level) + 1) * self.prefixes.len();
+ // 4 bytes for the number of prefixes, 2 bytes for the level, and a variable number of bytes
+ // for the packed prefixes themselves.
+ Some(6 + (packed_bit_count + 7) / 8)
+ }
+}
+
+impl Decode for Poplar1AggregationParam {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let level = u16::decode(bytes)?;
+ let prefix_count =
+ usize::try_from(u32::decode(bytes)?).map_err(|e| CodecError::Other(e.into()))?;
+
+ let packed_bit_count = (usize::from(level) + 1) * prefix_count;
+ let mut packed = vec![0u8; (packed_bit_count + 7) / 8];
+ bytes.read_exact(&mut packed)?;
+ if packed_bit_count % 8 != 0 {
+ let unused_bits = packed[0] >> (packed_bit_count % 8);
+ if unused_bits != 0 {
+ return Err(CodecError::UnexpectedValue);
+ }
+ }
+ packed.reverse();
+ let bits = BitVec::<u8, Lsb0>::from_vec(packed);
+
+ let prefixes = bits
+ .chunks_exact(usize::from(level) + 1)
+ .take(prefix_count)
+ .map(|chunk| IdpfInput::from(chunk.iter().rev().collect::<BitVec>()))
+ .collect::<Vec<IdpfInput>>();
+
+ Poplar1AggregationParam::try_from_prefixes(prefixes)
+ .map_err(|e| CodecError::Other(e.into()))
+ }
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Vdaf for Poplar1<P, SEED_SIZE> {
+ const ID: u32 = 0x00001000;
+ type Measurement = IdpfInput;
+ type AggregateResult = Vec<u64>;
+ type AggregationParam = Poplar1AggregationParam;
+ type PublicShare = Poplar1PublicShare;
+ type InputShare = Poplar1InputShare<SEED_SIZE>;
+ type OutputShare = Poplar1FieldVec;
+ type AggregateShare = Poplar1FieldVec;
+
+ fn num_aggregators(&self) -> usize {
+ 2
+ }
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> {
+ fn shard_with_random(
+ &self,
+ input: &IdpfInput,
+ nonce: &[u8; 16],
+ idpf_random: &[[u8; 16]; 2],
+ poplar_random: &[[u8; SEED_SIZE]; 3],
+ ) -> Result<(Poplar1PublicShare, Vec<Poplar1InputShare<SEED_SIZE>>), VdafError> {
+ if input.len() != self.bits {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected input length ({})",
+ input.len()
+ )));
+ }
+
+ // Generate the authenticator for each inner level of the IDPF tree.
+ let mut prng =
+ Self::init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, [&[]]);
+ let auth_inner: Vec<Field64> = (0..self.bits - 1).map(|_| prng.get()).collect();
+
+ // Generate the authenticator for the last level of the IDPF tree (i.e., the leaves).
+ //
+ // TODO(cjpatton) spec: Consider using a different XOF for the leaf and inner nodes.
+ // "Switching" the XOF between field types is awkward.
+ let mut prng = prng.into_new_field::<Field255>();
+ let auth_leaf = prng.get();
+
+ // Generate the IDPF shares.
+ let idpf = Idpf::new((), ());
+ let (public_share, [idpf_key_0, idpf_key_1]) = idpf.gen_with_random(
+ input,
+ auth_inner
+ .iter()
+ .map(|auth| Poplar1IdpfValue([Field64::one(), *auth])),
+ Poplar1IdpfValue([Field255::one(), auth_leaf]),
+ nonce,
+ idpf_random,
+ )?;
+
+ // Generate the correlated randomness for the inner nodes. This includes additive shares of
+ // the random offsets `a, b, c` and additive shares of `A := -2*a + auth` and `B := a^2 + b
+ // - a*auth + c`, where `auth` is the authenticator for the level of the tree. These values
+ // are used, respectively, to compute and verify the sketch during the preparation phase.
+ // (See Section 4.2 of [BBCG+21].)
+ let corr_seed_0 = &poplar_random[0];
+ let corr_seed_1 = &poplar_random[1];
+ let mut prng = prng.into_new_field::<Field64>();
+ let mut corr_prng_0 = Self::init_prng::<_, _, Field64>(
+ corr_seed_0,
+ DST_CORR_INNER,
+ [[0].as_slice(), nonce.as_slice()],
+ );
+ let mut corr_prng_1 = Self::init_prng::<_, _, Field64>(
+ corr_seed_1,
+ DST_CORR_INNER,
+ [[1].as_slice(), nonce.as_slice()],
+ );
+ let mut corr_inner_0 = Vec::with_capacity(self.bits - 1);
+ let mut corr_inner_1 = Vec::with_capacity(self.bits - 1);
+ for auth in auth_inner.into_iter() {
+ let (next_corr_inner_0, next_corr_inner_1) =
+ compute_next_corr_shares(&mut prng, &mut corr_prng_0, &mut corr_prng_1, auth);
+ corr_inner_0.push(next_corr_inner_0);
+ corr_inner_1.push(next_corr_inner_1);
+ }
+
+ // Generate the correlated randomness for the leaf nodes.
+ let mut prng = prng.into_new_field::<Field255>();
+ let mut corr_prng_0 = Self::init_prng::<_, _, Field255>(
+ corr_seed_0,
+ DST_CORR_LEAF,
+ [[0].as_slice(), nonce.as_slice()],
+ );
+ let mut corr_prng_1 = Self::init_prng::<_, _, Field255>(
+ corr_seed_1,
+ DST_CORR_LEAF,
+ [[1].as_slice(), nonce.as_slice()],
+ );
+ let (corr_leaf_0, corr_leaf_1) =
+ compute_next_corr_shares(&mut prng, &mut corr_prng_0, &mut corr_prng_1, auth_leaf);
+
+ Ok((
+ public_share,
+ vec![
+ Poplar1InputShare {
+ idpf_key: idpf_key_0,
+ corr_seed: Seed::from_bytes(*corr_seed_0),
+ corr_inner: corr_inner_0,
+ corr_leaf: corr_leaf_0,
+ },
+ Poplar1InputShare {
+ idpf_key: idpf_key_1,
+ corr_seed: Seed::from_bytes(*corr_seed_1),
+ corr_inner: corr_inner_1,
+ corr_leaf: corr_leaf_1,
+ },
+ ],
+ ))
+ }
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Client<16> for Poplar1<P, SEED_SIZE> {
+ fn shard(
+ &self,
+ input: &IdpfInput,
+ nonce: &[u8; 16],
+ ) -> Result<(Self::PublicShare, Vec<Poplar1InputShare<SEED_SIZE>>), VdafError> {
+ let mut idpf_random = [[0u8; 16]; 2];
+ let mut poplar_random = [[0u8; SEED_SIZE]; 3];
+ for random_seed in idpf_random.iter_mut() {
+ getrandom::getrandom(random_seed)?;
+ }
+ for random_seed in poplar_random.iter_mut() {
+ getrandom::getrandom(random_seed)?;
+ }
+ self.shard_with_random(input, nonce, &idpf_random, &poplar_random)
+ }
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16>
+ for Poplar1<P, SEED_SIZE>
+{
+ type PrepareState = Poplar1PrepareState;
+ type PrepareShare = Poplar1FieldVec;
+ type PrepareMessage = Poplar1PrepareMessage;
+
+ #[allow(clippy::type_complexity)]
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; SEED_SIZE],
+ agg_id: usize,
+ agg_param: &Poplar1AggregationParam,
+ nonce: &[u8; 16],
+ public_share: &Poplar1PublicShare,
+ input_share: &Poplar1InputShare<SEED_SIZE>,
+ ) -> Result<(Poplar1PrepareState, Poplar1FieldVec), VdafError> {
+ let is_leader = match agg_id {
+ 0 => true,
+ 1 => false,
+ _ => {
+ return Err(VdafError::Uncategorized(format!(
+ "invalid aggregator ID ({agg_id})"
+ )))
+ }
+ };
+
+ if usize::from(agg_param.level) < self.bits - 1 {
+ let mut corr_prng = Self::init_prng::<_, _, Field64>(
+ input_share.corr_seed.as_ref(),
+ DST_CORR_INNER,
+ [[agg_id as u8].as_slice(), nonce.as_slice()],
+ );
+ // Fast-forward the correlated randomness XOF to the level of the tree that we are
+ // aggregating.
+ for _ in 0..3 * agg_param.level {
+ corr_prng.get();
+ }
+
+ let (output_share, sketch_share) = eval_and_sketch::<P, Field64, SEED_SIZE>(
+ verify_key,
+ agg_id,
+ nonce,
+ agg_param,
+ public_share,
+ &input_share.idpf_key,
+ &mut corr_prng,
+ )?;
+
+ Ok((
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: input_share.corr_inner[usize::from(agg_param.level)][0],
+ B_share: input_share.corr_inner[usize::from(agg_param.level)][1],
+ is_leader,
+ },
+ output_share,
+ })),
+ Poplar1FieldVec::Inner(sketch_share),
+ ))
+ } else {
+ let corr_prng = Self::init_prng::<_, _, Field255>(
+ input_share.corr_seed.as_ref(),
+ DST_CORR_LEAF,
+ [[agg_id as u8].as_slice(), nonce.as_slice()],
+ );
+
+ let (output_share, sketch_share) = eval_and_sketch::<P, Field255, SEED_SIZE>(
+ verify_key,
+ agg_id,
+ nonce,
+ agg_param,
+ public_share,
+ &input_share.idpf_key,
+ &mut corr_prng.into_new_field(),
+ )?;
+
+ Ok((
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: input_share.corr_leaf[0],
+ B_share: input_share.corr_leaf[1],
+ is_leader,
+ },
+ output_share,
+ })),
+ Poplar1FieldVec::Leaf(sketch_share),
+ ))
+ }
+ }
+
+ fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Poplar1FieldVec>>(
+ &self,
+ _: &Poplar1AggregationParam,
+ inputs: M,
+ ) -> Result<Poplar1PrepareMessage, VdafError> {
+ let mut inputs = inputs.into_iter();
+ let prep_share_0 = inputs
+ .next()
+ .ok_or_else(|| VdafError::Uncategorized("insufficient number of prep shares".into()))?;
+ let prep_share_1 = inputs
+ .next()
+ .ok_or_else(|| VdafError::Uncategorized("insufficient number of prep shares".into()))?;
+ if inputs.next().is_some() {
+ return Err(VdafError::Uncategorized(
+ "more prep shares than expected".into(),
+ ));
+ }
+
+ match (prep_share_0, prep_share_1) {
+ (Poplar1FieldVec::Inner(share_0), Poplar1FieldVec::Inner(share_1)) => {
+ Ok(Poplar1PrepareMessage(
+ next_message(share_0, share_1)?.map_or(PrepareMessageVariant::Done, |sketch| {
+ PrepareMessageVariant::SketchInner(sketch)
+ }),
+ ))
+ }
+ (Poplar1FieldVec::Leaf(share_0), Poplar1FieldVec::Leaf(share_1)) => {
+ Ok(Poplar1PrepareMessage(
+ next_message(share_0, share_1)?.map_or(PrepareMessageVariant::Done, |sketch| {
+ PrepareMessageVariant::SketchLeaf(sketch)
+ }),
+ ))
+ }
+ _ => Err(VdafError::Uncategorized(
+ "received prep shares with mismatched field types".into(),
+ )),
+ }
+ }
+
+ fn prepare_next(
+ &self,
+ state: Poplar1PrepareState,
+ msg: Poplar1PrepareMessage,
+ ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> {
+ match (state.0, msg.0) {
+ // Round one
+ (
+ PrepareStateVariant::Inner(PrepareState {
+ sketch:
+ SketchState::RoundOne {
+ A_share,
+ B_share,
+ is_leader,
+ },
+ output_share,
+ }),
+ PrepareMessageVariant::SketchInner(sketch),
+ ) => Ok(PrepareTransition::Continue(
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share,
+ })),
+ Poplar1FieldVec::Inner(finish_sketch(sketch, A_share, B_share, is_leader)),
+ )),
+ (
+ PrepareStateVariant::Leaf(PrepareState {
+ sketch:
+ SketchState::RoundOne {
+ A_share,
+ B_share,
+ is_leader,
+ },
+ output_share,
+ }),
+ PrepareMessageVariant::SketchLeaf(sketch),
+ ) => Ok(PrepareTransition::Continue(
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share,
+ })),
+ Poplar1FieldVec::Leaf(finish_sketch(sketch, A_share, B_share, is_leader)),
+ )),
+
+ // Round two
+ (
+ PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share,
+ }),
+ PrepareMessageVariant::Done,
+ ) => Ok(PrepareTransition::Finish(Poplar1FieldVec::Inner(
+ output_share,
+ ))),
+ (
+ PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share,
+ }),
+ PrepareMessageVariant::Done,
+ ) => Ok(PrepareTransition::Finish(Poplar1FieldVec::Leaf(
+ output_share,
+ ))),
+
+ _ => Err(VdafError::Uncategorized(
+ "prep message field type does not match state".into(),
+ )),
+ }
+ }
+
+ fn aggregate<M: IntoIterator<Item = Poplar1FieldVec>>(
+ &self,
+ agg_param: &Poplar1AggregationParam,
+ output_shares: M,
+ ) -> Result<Poplar1FieldVec, VdafError> {
+ aggregate(
+ usize::from(agg_param.level) == self.bits - 1,
+ agg_param.prefixes.len(),
+ output_shares,
+ )
+ }
+}
+
+impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Collector for Poplar1<P, SEED_SIZE> {
+ fn unshard<M: IntoIterator<Item = Poplar1FieldVec>>(
+ &self,
+ agg_param: &Poplar1AggregationParam,
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<Vec<u64>, VdafError> {
+ let result = aggregate(
+ usize::from(agg_param.level) == self.bits - 1,
+ agg_param.prefixes.len(),
+ agg_shares,
+ )?;
+
+ match result {
+ Poplar1FieldVec::Inner(vec) => Ok(vec.into_iter().map(u64::from).collect()),
+ Poplar1FieldVec::Leaf(vec) => Ok(vec
+ .into_iter()
+ .map(u64::try_from)
+ .collect::<Result<Vec<_>, _>>()?),
+ }
+ }
+}
+
+impl From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>>
+ for Poplar1IdpfValue<Field64>
+{
+ fn from(
+ out_share: IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>,
+ ) -> Poplar1IdpfValue<Field64> {
+ match out_share {
+ IdpfOutputShare::Inner(array) => array,
+ IdpfOutputShare::Leaf(..) => panic!("tried to convert leaf share into inner field"),
+ }
+ }
+}
+
+impl From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>>
+ for Poplar1IdpfValue<Field255>
+{
+ fn from(
+ out_share: IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>,
+ ) -> Poplar1IdpfValue<Field255> {
+ match out_share {
+ IdpfOutputShare::Inner(..) => panic!("tried to convert inner share into leaf field"),
+ IdpfOutputShare::Leaf(array) => array,
+ }
+ }
+}
+
+/// Derive shares of the correlated randomness for the next level of the IDPF tree.
+//
+// TODO(cjpatton) spec: Consider deriving the shares of a, b, c for each level directly from the
+// seed, rather than iteratively, as we do in Doplar. This would be more efficient for the
+// Aggregators. As long as the Client isn't significantly slower, this should be a win.
+#[allow(non_snake_case)]
+fn compute_next_corr_shares<F: FieldElement + From<u64>, S: RngCore>(
+ prng: &mut Prng<F, S>,
+ corr_prng_0: &mut Prng<F, S>,
+ corr_prng_1: &mut Prng<F, S>,
+ auth: F,
+) -> ([F; 2], [F; 2]) {
+ let two = F::from(2);
+ let a = corr_prng_0.get() + corr_prng_1.get();
+ let b = corr_prng_0.get() + corr_prng_1.get();
+ let c = corr_prng_0.get() + corr_prng_1.get();
+ let A = -two * a + auth;
+ let B = a * a + b - a * auth + c;
+ let corr_1 = [prng.get(), prng.get()];
+ let corr_0 = [A - corr_1[0], B - corr_1[1]];
+ (corr_0, corr_1)
+}
+
+/// Evaluate the IDPF at the given prefixes and compute the Aggregator's share of the sketch.
+fn eval_and_sketch<P, F, const SEED_SIZE: usize>(
+ verify_key: &[u8; SEED_SIZE],
+ agg_id: usize,
+ nonce: &[u8; 16],
+ agg_param: &Poplar1AggregationParam,
+ public_share: &Poplar1PublicShare,
+ idpf_key: &Seed<16>,
+ corr_prng: &mut Prng<F, P::SeedStream>,
+) -> Result<(Vec<F>, Vec<F>), VdafError>
+where
+ P: Xof<SEED_SIZE>,
+ F: FieldElement,
+ Poplar1IdpfValue<F>:
+ From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>>,
+{
+ // TODO(cjpatton) spec: Consider not encoding the prefixes here.
+ let mut verify_prng = Poplar1::<P, SEED_SIZE>::init_prng(
+ verify_key,
+ DST_VERIFY_RANDOMNESS,
+ [nonce.as_slice(), agg_param.level.to_be_bytes().as_slice()],
+ );
+
+ let mut out_share = Vec::with_capacity(agg_param.prefixes.len());
+ let mut sketch_share = vec![
+ corr_prng.get(), // a_share
+ corr_prng.get(), // b_share
+ corr_prng.get(), // c_share
+ ];
+
+ let mut idpf_eval_cache = RingBufferCache::new(agg_param.prefixes.len());
+ let idpf = Idpf::<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>::new((), ());
+ for prefix in agg_param.prefixes.iter() {
+ let share = Poplar1IdpfValue::<F>::from(idpf.eval(
+ agg_id,
+ public_share,
+ idpf_key,
+ prefix,
+ nonce,
+ &mut idpf_eval_cache,
+ )?);
+
+ let r = verify_prng.get();
+ let checked_data_share = share.0[0] * r;
+ sketch_share[0] += checked_data_share;
+ sketch_share[1] += checked_data_share * r;
+ sketch_share[2] += share.0[1] * r;
+ out_share.push(share.0[0]);
+ }
+
+ Ok((out_share, sketch_share))
+}
+
+/// Compute the Aggregator's share of the sketch verifier. The shares should sum to zero.
+#[allow(non_snake_case)]
+fn finish_sketch<F: FieldElement>(
+ sketch: [F; 3],
+ A_share: F,
+ B_share: F,
+ is_leader: bool,
+) -> Vec<F> {
+ let mut next_sketch_share = A_share * sketch[0] + B_share;
+ if !is_leader {
+ next_sketch_share += sketch[0] * sketch[0] - sketch[1] - sketch[2];
+ }
+ vec![next_sketch_share]
+}
+
+fn next_message<F: FieldElement>(
+ mut share_0: Vec<F>,
+ share_1: Vec<F>,
+) -> Result<Option<[F; 3]>, VdafError> {
+ merge_vector(&mut share_0, &share_1)?;
+
+ if share_0.len() == 1 {
+ if share_0[0] != F::zero() {
+ Err(VdafError::Uncategorized(
+ "sketch verification failed".into(),
+ )) // Invalid sketch
+ } else {
+ Ok(None) // Sketch verification succeeded
+ }
+ } else if share_0.len() == 3 {
+ Ok(Some([share_0[0], share_0[1], share_0[2]])) // Sketch verification continues
+ } else {
+ Err(VdafError::Uncategorized(format!(
+ "unexpected sketch length ({})",
+ share_0.len()
+ )))
+ }
+}
+
+fn aggregate<M: IntoIterator<Item = Poplar1FieldVec>>(
+ is_leaf: bool,
+ len: usize,
+ shares: M,
+) -> Result<Poplar1FieldVec, VdafError> {
+ let mut result = Poplar1FieldVec::zero(is_leaf, len);
+ for share in shares.into_iter() {
+ result.accumulate(&share)?;
+ }
+ Ok(result)
+}
+
+/// A vector of two field elements.
+///
+/// This represents the values that Poplar1 programs into IDPFs while sharding.
+#[derive(Debug, Clone, Copy)]
+pub struct Poplar1IdpfValue<F>([F; 2]);
+
+impl<F> Poplar1IdpfValue<F> {
+ /// Create a new value from a pair of field elements.
+ pub fn new(array: [F; 2]) -> Self {
+ Self(array)
+ }
+}
+
+impl<F> IdpfValue for Poplar1IdpfValue<F>
+where
+ F: FieldElement,
+{
+ type ValueParameter = ();
+
+ fn zero(_: &()) -> Self {
+ Self([F::zero(); 2])
+ }
+
+ fn generate<S: RngCore>(seed_stream: &mut S, _: &()) -> Self {
+ Self([F::generate(seed_stream, &()), F::generate(seed_stream, &())])
+ }
+
+ fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+ ConditionallySelectable::conditional_select(a, b, choice)
+ }
+}
+
+impl<F> Add for Poplar1IdpfValue<F>
+where
+ F: Copy + Add<Output = F>,
+{
+ type Output = Self;
+
+ fn add(self, rhs: Self) -> Self {
+ Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]])
+ }
+}
+
+impl<F> AddAssign for Poplar1IdpfValue<F>
+where
+ F: Copy + AddAssign,
+{
+ fn add_assign(&mut self, rhs: Self) {
+ self.0[0] += rhs.0[0];
+ self.0[1] += rhs.0[1];
+ }
+}
+
+impl<F> Sub for Poplar1IdpfValue<F>
+where
+ F: Copy + Sub<Output = F>,
+{
+ type Output = Self;
+
+ fn sub(self, rhs: Self) -> Self {
+ Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]])
+ }
+}
+
+impl<F> PartialEq for Poplar1IdpfValue<F>
+where
+ F: PartialEq,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+}
+
+impl<F> ConstantTimeEq for Poplar1IdpfValue<F>
+where
+ F: ConstantTimeEq,
+{
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl<F> Encode for Poplar1IdpfValue<F>
+where
+ F: FieldElement,
+{
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0[0].encode(bytes);
+ self.0[1].encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(F::ENCODED_SIZE * 2)
+ }
+}
+
+impl<F> Decode for Poplar1IdpfValue<F>
+where
+ F: Decode,
+{
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ Ok(Self([F::decode(bytes)?, F::decode(bytes)?]))
+ }
+}
+
+impl<F> ConditionallySelectable for Poplar1IdpfValue<F>
+where
+ F: ConditionallySelectable,
+{
+ fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
+ Self([
+ F::conditional_select(&a.0[0], &b.0[0], choice),
+ F::conditional_select(&a.0[1], &b.0[1], choice),
+ ])
+ }
+}
+
+impl<F> ConditionallyNegatable for Poplar1IdpfValue<F>
+where
+ F: ConditionallyNegatable,
+{
+ fn conditional_negate(&mut self, choice: subtle::Choice) {
+ F::conditional_negate(&mut self.0[0], choice);
+ F::conditional_negate(&mut self.0[1], choice);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::vdaf::{equality_comparison_test, run_vdaf_prepare};
+ use assert_matches::assert_matches;
+ use rand::prelude::*;
+ use serde::Deserialize;
+ use std::collections::HashSet;
+
+ fn test_prepare<P: Xof<SEED_SIZE>, const SEED_SIZE: usize>(
+ vdaf: &Poplar1<P, SEED_SIZE>,
+ verify_key: &[u8; SEED_SIZE],
+ nonce: &[u8; 16],
+ public_share: &Poplar1PublicShare,
+ input_shares: &[Poplar1InputShare<SEED_SIZE>],
+ agg_param: &Poplar1AggregationParam,
+ expected_result: Vec<u64>,
+ ) {
+ let out_shares = run_vdaf_prepare(
+ vdaf,
+ verify_key,
+ agg_param,
+ nonce,
+ public_share.clone(),
+ input_shares.to_vec(),
+ )
+ .unwrap();
+
+ // Convert aggregate shares and unshard.
+ let agg_share_0 = vdaf.aggregate(agg_param, [out_shares[0].clone()]).unwrap();
+ let agg_share_1 = vdaf.aggregate(agg_param, [out_shares[1].clone()]).unwrap();
+ let result = vdaf
+ .unshard(agg_param, [agg_share_0, agg_share_1], 1)
+ .unwrap();
+ assert_eq!(
+ result, expected_result,
+ "unexpected result (level={})",
+ agg_param.level
+ );
+ }
+
+ fn run_heavy_hitters<B: AsRef<[u8]>, P: Xof<SEED_SIZE>, const SEED_SIZE: usize>(
+ vdaf: &Poplar1<P, SEED_SIZE>,
+ verify_key: &[u8; SEED_SIZE],
+ threshold: usize,
+ measurements: impl IntoIterator<Item = B>,
+ expected_result: impl IntoIterator<Item = B>,
+ ) {
+ let mut rng = thread_rng();
+
+ // Sharding step
+ let reports: Vec<(
+ [u8; 16],
+ Poplar1PublicShare,
+ Vec<Poplar1InputShare<SEED_SIZE>>,
+ )> = measurements
+ .into_iter()
+ .map(|measurement| {
+ let nonce = rng.gen();
+ let (public_share, input_shares) = vdaf
+ .shard(&IdpfInput::from_bytes(measurement.as_ref()), &nonce)
+ .unwrap();
+ (nonce, public_share, input_shares)
+ })
+ .collect();
+
+ let mut agg_param = Poplar1AggregationParam {
+ level: 0,
+ prefixes: vec![
+ IdpfInput::from_bools(&[false]),
+ IdpfInput::from_bools(&[true]),
+ ],
+ };
+
+ let mut agg_result = Vec::new();
+ for level in 0..vdaf.bits {
+ let mut out_shares_0 = Vec::with_capacity(reports.len());
+ let mut out_shares_1 = Vec::with_capacity(reports.len());
+
+ // Preparation step
+ for (nonce, public_share, input_shares) in reports.iter() {
+ let out_shares = run_vdaf_prepare(
+ vdaf,
+ verify_key,
+ &agg_param,
+ nonce,
+ public_share.clone(),
+ input_shares.to_vec(),
+ )
+ .unwrap();
+
+ out_shares_0.push(out_shares[0].clone());
+ out_shares_1.push(out_shares[1].clone());
+ }
+
+ // Aggregation step
+ let agg_share_0 = vdaf.aggregate(&agg_param, out_shares_0).unwrap();
+ let agg_share_1 = vdaf.aggregate(&agg_param, out_shares_1).unwrap();
+
+ // Unsharding step
+ agg_result = vdaf
+ .unshard(&agg_param, [agg_share_0, agg_share_1], reports.len())
+ .unwrap();
+
+ agg_param.level += 1;
+
+ // Unless this is the last level of the tree, construct the next set of candidate
+ // prefixes.
+ if level < vdaf.bits - 1 {
+ let mut next_prefixes = Vec::new();
+ for (prefix, count) in agg_param.prefixes.into_iter().zip(agg_result.iter()) {
+ if *count >= threshold as u64 {
+ next_prefixes.push(prefix.clone_with_suffix(&[false]));
+ next_prefixes.push(prefix.clone_with_suffix(&[true]));
+ }
+ }
+
+ agg_param.prefixes = next_prefixes;
+ }
+ }
+
+ let got: HashSet<IdpfInput> = agg_param
+ .prefixes
+ .into_iter()
+ .zip(agg_result.iter())
+ .filter(|(_prefix, count)| **count >= threshold as u64)
+ .map(|(prefix, _count)| prefix)
+ .collect();
+
+ let want: HashSet<IdpfInput> = expected_result
+ .into_iter()
+ .map(|bytes| IdpfInput::from_bytes(bytes.as_ref()))
+ .collect();
+
+ assert_eq!(got, want);
+ }
+
+ #[test]
+ fn shard_prepare() {
+ let mut rng = thread_rng();
+ let vdaf = Poplar1::new_shake128(64);
+ let verify_key = rng.gen();
+ let input = IdpfInput::from_bytes(b"12341324");
+ let nonce = rng.gen();
+ let (public_share, input_shares) = vdaf.shard(&input, &nonce).unwrap();
+
+ test_prepare(
+ &vdaf,
+ &verify_key,
+ &nonce,
+ &public_share,
+ &input_shares,
+ &Poplar1AggregationParam {
+ level: 7,
+ prefixes: vec![
+ IdpfInput::from_bytes(b"0"),
+ IdpfInput::from_bytes(b"1"),
+ IdpfInput::from_bytes(b"2"),
+ IdpfInput::from_bytes(b"f"),
+ ],
+ },
+ vec![0, 1, 0, 0],
+ );
+
+ for level in 0..vdaf.bits {
+ test_prepare(
+ &vdaf,
+ &verify_key,
+ &nonce,
+ &public_share,
+ &input_shares,
+ &Poplar1AggregationParam {
+ level: level as u16,
+ prefixes: vec![input.prefix(level)],
+ },
+ vec![1],
+ );
+ }
+ }
+
+ #[test]
+ fn heavy_hitters() {
+ let mut rng = thread_rng();
+ let verify_key = rng.gen();
+ let vdaf = Poplar1::new_shake128(8);
+
+ run_heavy_hitters(
+ &vdaf,
+ &verify_key,
+ 2, // threshold
+ [
+ "a", "b", "c", "d", "e", "f", "g", "g", "h", "i", "i", "i", "j", "j", "k", "l",
+ ], // measurements
+ ["g", "i", "j"], // heavy hitters
+ );
+ }
+
+ #[test]
+ fn encoded_len() {
+ // Input share
+ let input_share = Poplar1InputShare {
+ idpf_key: Seed::<16>::generate().unwrap(),
+ corr_seed: Seed::<16>::generate().unwrap(),
+ corr_inner: vec![
+ [Field64::one(), <Field64 as FieldElement>::zero()],
+ [Field64::one(), <Field64 as FieldElement>::zero()],
+ [Field64::one(), <Field64 as FieldElement>::zero()],
+ ],
+ corr_leaf: [Field255::one(), <Field255 as FieldElement>::zero()],
+ };
+ assert_eq!(
+ input_share.get_encoded().len(),
+ input_share.encoded_len().unwrap()
+ );
+
+ // Prepaare message variants
+ let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::SketchInner([
+ Field64::one(),
+ Field64::one(),
+ Field64::one(),
+ ]));
+ assert_eq!(
+ prep_msg.get_encoded().len(),
+ prep_msg.encoded_len().unwrap()
+ );
+ let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::SketchLeaf([
+ Field255::one(),
+ Field255::one(),
+ Field255::one(),
+ ]));
+ assert_eq!(
+ prep_msg.get_encoded().len(),
+ prep_msg.encoded_len().unwrap()
+ );
+ let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::Done);
+ assert_eq!(
+ prep_msg.get_encoded().len(),
+ prep_msg.encoded_len().unwrap()
+ );
+
+ // Field vector variants.
+ let field_vec = Poplar1FieldVec::Inner(vec![Field64::one(); 23]);
+ assert_eq!(
+ field_vec.get_encoded().len(),
+ field_vec.encoded_len().unwrap()
+ );
+ let field_vec = Poplar1FieldVec::Leaf(vec![Field255::one(); 23]);
+ assert_eq!(
+ field_vec.get_encoded().len(),
+ field_vec.encoded_len().unwrap()
+ );
+
+ // Aggregation parameter.
+ let agg_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
+ IdpfInput::from_bytes(b"ab"),
+ IdpfInput::from_bytes(b"cd"),
+ ]))
+ .unwrap();
+ assert_eq!(
+ agg_param.get_encoded().len(),
+ agg_param.encoded_len().unwrap()
+ );
+ let agg_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([
+ IdpfInput::from_bools(&[false]),
+ IdpfInput::from_bools(&[true]),
+ ]))
+ .unwrap();
+ assert_eq!(
+ agg_param.get_encoded().len(),
+ agg_param.encoded_len().unwrap()
+ );
+ }
+
+ #[test]
+ fn round_trip_prepare_state() {
+ let vdaf = Poplar1::new_shake128(1);
+ for (agg_id, prep_state) in [
+ (
+ 0,
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(0),
+ B_share: Field64::from(1),
+ is_leader: true,
+ },
+ output_share: Vec::from([Field64::from(2), Field64::from(3), Field64::from(4)]),
+ })),
+ ),
+ (
+ 1,
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(5),
+ B_share: Field64::from(6),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field64::from(7), Field64::from(8), Field64::from(9)]),
+ })),
+ ),
+ (
+ 0,
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([
+ Field64::from(10),
+ Field64::from(11),
+ Field64::from(12),
+ ]),
+ })),
+ ),
+ (
+ 1,
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([
+ Field64::from(13),
+ Field64::from(14),
+ Field64::from(15),
+ ]),
+ })),
+ ),
+ (
+ 0,
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(16),
+ B_share: Field255::from(17),
+ is_leader: true,
+ },
+ output_share: Vec::from([
+ Field255::from(18),
+ Field255::from(19),
+ Field255::from(20),
+ ]),
+ })),
+ ),
+ (
+ 1,
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(21),
+ B_share: Field255::from(22),
+ is_leader: false,
+ },
+ output_share: Vec::from([
+ Field255::from(23),
+ Field255::from(24),
+ Field255::from(25),
+ ]),
+ })),
+ ),
+ (
+ 0,
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([
+ Field255::from(26),
+ Field255::from(27),
+ Field255::from(28),
+ ]),
+ })),
+ ),
+ (
+ 1,
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([
+ Field255::from(29),
+ Field255::from(30),
+ Field255::from(31),
+ ]),
+ })),
+ ),
+ ] {
+ let encoded_prep_state = prep_state.get_encoded();
+ assert_eq!(prep_state.encoded_len(), Some(encoded_prep_state.len()));
+ let decoded_prep_state =
+ Poplar1PrepareState::get_decoded_with_param(&(&vdaf, agg_id), &encoded_prep_state)
+ .unwrap();
+ assert_eq!(prep_state, decoded_prep_state);
+ }
+ }
+
+ #[test]
+ fn round_trip_agg_param() {
+ // These test cases were generated using the reference Sage implementation.
+ // (https://github.com/cfrg/draft-irtf-cfrg-vdaf/tree/main/poc) Sage statements used to
+ // generate each test case are given in comments.
+ for (prefixes, reference_encoding) in [
+ // poplar.encode_agg_param(0, [0])
+ (
+ Vec::from([IdpfInput::from_bools(&[false])]),
+ [0, 0, 0, 0, 0, 1, 0].as_slice(),
+ ),
+ // poplar.encode_agg_param(0, [1])
+ (
+ Vec::from([IdpfInput::from_bools(&[true])]),
+ [0, 0, 0, 0, 0, 1, 1].as_slice(),
+ ),
+ // poplar.encode_agg_param(0, [0, 1])
+ (
+ Vec::from([
+ IdpfInput::from_bools(&[false]),
+ IdpfInput::from_bools(&[true]),
+ ]),
+ [0, 0, 0, 0, 0, 2, 2].as_slice(),
+ ),
+ // poplar.encode_agg_param(1, [0b00, 0b01, 0b10, 0b11])
+ (
+ Vec::from([
+ IdpfInput::from_bools(&[false, false]),
+ IdpfInput::from_bools(&[false, true]),
+ IdpfInput::from_bools(&[true, false]),
+ IdpfInput::from_bools(&[true, true]),
+ ]),
+ [0, 1, 0, 0, 0, 4, 0xe4].as_slice(),
+ ),
+ // poplar.encode_agg_param(1, [0b00, 0b10, 0b11])
+ (
+ Vec::from([
+ IdpfInput::from_bools(&[false, false]),
+ IdpfInput::from_bools(&[true, false]),
+ IdpfInput::from_bools(&[true, true]),
+ ]),
+ [0, 1, 0, 0, 0, 3, 0x38].as_slice(),
+ ),
+ // poplar.encode_agg_param(2, [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111])
+ (
+ Vec::from([
+ IdpfInput::from_bools(&[false, false, false]),
+ IdpfInput::from_bools(&[false, false, true]),
+ IdpfInput::from_bools(&[false, true, false]),
+ IdpfInput::from_bools(&[false, true, true]),
+ IdpfInput::from_bools(&[true, false, false]),
+ IdpfInput::from_bools(&[true, false, true]),
+ IdpfInput::from_bools(&[true, true, false]),
+ IdpfInput::from_bools(&[true, true, true]),
+ ]),
+ [0, 2, 0, 0, 0, 8, 0xfa, 0xc6, 0x88].as_slice(),
+ ),
+ // poplar.encode_agg_param(9, [0b01_1011_0010, 0b10_1101_1010])
+ (
+ Vec::from([
+ IdpfInput::from_bools(&[
+ false, true, true, false, true, true, false, false, true, false,
+ ]),
+ IdpfInput::from_bools(&[
+ true, false, true, true, false, true, true, false, true, false,
+ ]),
+ ]),
+ [0, 9, 0, 0, 0, 2, 0x0b, 0x69, 0xb2].as_slice(),
+ ),
+ // poplar.encode_agg_param(15, [0xcafe])
+ (
+ Vec::from([IdpfInput::from_bytes(b"\xca\xfe")]),
+ [0, 15, 0, 0, 0, 1, 0xca, 0xfe].as_slice(),
+ ),
+ ] {
+ let agg_param = Poplar1AggregationParam::try_from_prefixes(prefixes).unwrap();
+ let encoded = agg_param.get_encoded();
+ assert_eq!(encoded, reference_encoding);
+ let decoded = Poplar1AggregationParam::get_decoded(reference_encoding).unwrap();
+ assert_eq!(decoded, agg_param);
+ }
+ }
+
+ #[test]
+ fn agg_param_wrong_unused_bit() {
+ let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 1, 2]).unwrap_err();
+ assert_matches!(err, CodecError::UnexpectedValue);
+ }
+
+ #[test]
+ fn agg_param_ordering() {
+ let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 1]).unwrap_err();
+ assert_matches!(err, CodecError::Other(_));
+ let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 0]).unwrap_err();
+ assert_matches!(err, CodecError::Other(_));
+ let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 3]).unwrap_err();
+ assert_matches!(err, CodecError::Other(_));
+ }
+
+ #[derive(Debug, Deserialize)]
+ struct HexEncoded(#[serde(with = "hex")] Vec<u8>);
+
+ impl AsRef<[u8]> for HexEncoded {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+ }
+
+ #[derive(Debug, Deserialize)]
+ struct PoplarTestVector {
+ agg_param: (usize, Vec<u64>),
+ agg_result: Vec<u64>,
+ agg_shares: Vec<HexEncoded>,
+ bits: usize,
+ prep: Vec<PreparationTestVector>,
+ verify_key: HexEncoded,
+ }
+
+ #[derive(Debug, Deserialize)]
+ struct PreparationTestVector {
+ input_shares: Vec<HexEncoded>,
+ measurement: u64,
+ nonce: HexEncoded,
+ out_shares: Vec<Vec<HexEncoded>>,
+ prep_messages: Vec<HexEncoded>,
+ prep_shares: Vec<Vec<HexEncoded>>,
+ public_share: HexEncoded,
+ rand: HexEncoded,
+ }
+
+ fn check_test_vec(input: &str) {
+ let test_vector: PoplarTestVector = serde_json::from_str(input).unwrap();
+ assert_eq!(test_vector.prep.len(), 1);
+ let prep = &test_vector.prep[0];
+ let measurement_bits = (0..test_vector.bits)
+ .rev()
+ .map(|i| (prep.measurement >> i) & 1 != 0)
+ .collect::<BitVec>();
+ let measurement = IdpfInput::from(measurement_bits);
+ let (agg_param_level, agg_param_prefixes_int) = test_vector.agg_param;
+ let agg_param_prefixes = agg_param_prefixes_int
+ .iter()
+ .map(|int| {
+ let bits = (0..=agg_param_level)
+ .rev()
+ .map(|i| (*int >> i) & 1 != 0)
+ .collect::<BitVec>();
+ bits.into()
+ })
+ .collect::<Vec<IdpfInput>>();
+ let agg_param = Poplar1AggregationParam::try_from_prefixes(agg_param_prefixes).unwrap();
+ let verify_key = test_vector.verify_key.as_ref().try_into().unwrap();
+ let nonce = prep.nonce.as_ref().try_into().unwrap();
+
+ let mut idpf_random = [[0u8; 16]; 2];
+ let mut poplar_random = [[0u8; 16]; 3];
+ for (input, output) in prep
+ .rand
+ .as_ref()
+ .chunks_exact(16)
+ .zip(idpf_random.iter_mut().chain(poplar_random.iter_mut()))
+ {
+ output.copy_from_slice(input);
+ }
+
+ // Shard measurement.
+ let poplar = Poplar1::new_shake128(test_vector.bits);
+ let (public_share, input_shares) = poplar
+ .shard_with_random(&measurement, &nonce, &idpf_random, &poplar_random)
+ .unwrap();
+
+ // Run aggregation.
+ let (init_prep_state_0, init_prep_share_0) = poplar
+ .prepare_init(
+ &verify_key,
+ 0,
+ &agg_param,
+ &nonce,
+ &public_share,
+ &input_shares[0],
+ )
+ .unwrap();
+ let (init_prep_state_1, init_prep_share_1) = poplar
+ .prepare_init(
+ &verify_key,
+ 1,
+ &agg_param,
+ &nonce,
+ &public_share,
+ &input_shares[1],
+ )
+ .unwrap();
+
+ let r1_prep_msg = poplar
+ .prepare_shares_to_prepare_message(
+ &agg_param,
+ [init_prep_share_0.clone(), init_prep_share_1.clone()],
+ )
+ .unwrap();
+
+ let (r1_prep_state_0, r1_prep_share_0) = assert_matches!(
+ poplar
+ .prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone())
+ .unwrap(),
+ PrepareTransition::Continue(state, share) => (state, share)
+ );
+ let (r1_prep_state_1, r1_prep_share_1) = assert_matches!(
+ poplar
+ .prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone())
+ .unwrap(),
+ PrepareTransition::Continue(state, share) => (state, share)
+ );
+
+ let r2_prep_msg = poplar
+ .prepare_shares_to_prepare_message(
+ &agg_param,
+ [r1_prep_share_0.clone(), r1_prep_share_1.clone()],
+ )
+ .unwrap();
+
+ let out_share_0 = assert_matches!(
+ poplar
+ .prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone())
+ .unwrap(),
+ PrepareTransition::Finish(out) => out
+ );
+ let out_share_1 = assert_matches!(
+ poplar
+ .prepare_next(r1_prep_state_1, r2_prep_msg.clone())
+ .unwrap(),
+ PrepareTransition::Finish(out) => out
+ );
+
+ let agg_share_0 = poplar.aggregate(&agg_param, [out_share_0.clone()]).unwrap();
+ let agg_share_1 = poplar.aggregate(&agg_param, [out_share_1.clone()]).unwrap();
+
+ // Collect result.
+ let agg_result = poplar
+ .unshard(&agg_param, [agg_share_0.clone(), agg_share_1.clone()], 1)
+ .unwrap();
+
+ // Check all intermediate results against the test vector, and exercise both encoding and decoding.
+ assert_eq!(
+ public_share,
+ Poplar1PublicShare::get_decoded_with_param(&poplar, prep.public_share.as_ref())
+ .unwrap()
+ );
+ assert_eq!(&public_share.get_encoded(), prep.public_share.as_ref());
+ assert_eq!(
+ input_shares[0],
+ Poplar1InputShare::get_decoded_with_param(&(&poplar, 0), prep.input_shares[0].as_ref())
+ .unwrap()
+ );
+ assert_eq!(
+ &input_shares[0].get_encoded(),
+ prep.input_shares[0].as_ref()
+ );
+ assert_eq!(
+ input_shares[1],
+ Poplar1InputShare::get_decoded_with_param(&(&poplar, 1), prep.input_shares[1].as_ref())
+ .unwrap()
+ );
+ assert_eq!(
+ &input_shares[1].get_encoded(),
+ prep.input_shares[1].as_ref()
+ );
+ assert_eq!(
+ init_prep_share_0,
+ Poplar1FieldVec::get_decoded_with_param(
+ &init_prep_state_0,
+ prep.prep_shares[0][0].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(
+ &init_prep_share_0.get_encoded(),
+ prep.prep_shares[0][0].as_ref()
+ );
+ assert_eq!(
+ init_prep_share_1,
+ Poplar1FieldVec::get_decoded_with_param(
+ &init_prep_state_1,
+ prep.prep_shares[0][1].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(
+ &init_prep_share_1.get_encoded(),
+ prep.prep_shares[0][1].as_ref()
+ );
+ assert_eq!(
+ r1_prep_msg,
+ Poplar1PrepareMessage::get_decoded_with_param(
+ &init_prep_state_0,
+ prep.prep_messages[0].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(&r1_prep_msg.get_encoded(), prep.prep_messages[0].as_ref());
+
+ assert_eq!(
+ r1_prep_share_0,
+ Poplar1FieldVec::get_decoded_with_param(
+ &r1_prep_state_0,
+ prep.prep_shares[1][0].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(
+ &r1_prep_share_0.get_encoded(),
+ prep.prep_shares[1][0].as_ref()
+ );
+ assert_eq!(
+ r1_prep_share_1,
+ Poplar1FieldVec::get_decoded_with_param(
+ &r1_prep_state_0,
+ prep.prep_shares[1][1].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(
+ &r1_prep_share_1.get_encoded(),
+ prep.prep_shares[1][1].as_ref()
+ );
+ assert_eq!(
+ r2_prep_msg,
+ Poplar1PrepareMessage::get_decoded_with_param(
+ &r1_prep_state_0,
+ prep.prep_messages[1].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(&r2_prep_msg.get_encoded(), prep.prep_messages[1].as_ref());
+ for (out_share, expected_out_share) in [
+ (out_share_0, &prep.out_shares[0]),
+ (out_share_1, &prep.out_shares[1]),
+ ] {
+ match out_share {
+ Poplar1FieldVec::Inner(vec) => {
+ assert_eq!(vec.len(), expected_out_share.len());
+ for (element, expected) in vec.iter().zip(expected_out_share.iter()) {
+ assert_eq!(&element.get_encoded(), expected.as_ref());
+ }
+ }
+ Poplar1FieldVec::Leaf(vec) => {
+ assert_eq!(vec.len(), expected_out_share.len());
+ for (element, expected) in vec.iter().zip(expected_out_share.iter()) {
+ assert_eq!(&element.get_encoded(), expected.as_ref());
+ }
+ }
+ };
+ }
+ assert_eq!(
+ agg_share_0,
+ Poplar1FieldVec::get_decoded_with_param(
+ &(&poplar, &agg_param),
+ test_vector.agg_shares[0].as_ref()
+ )
+ .unwrap()
+ );
+
+ assert_eq!(
+ &agg_share_0.get_encoded(),
+ test_vector.agg_shares[0].as_ref()
+ );
+ assert_eq!(
+ agg_share_1,
+ Poplar1FieldVec::get_decoded_with_param(
+ &(&poplar, &agg_param),
+ test_vector.agg_shares[1].as_ref()
+ )
+ .unwrap()
+ );
+ assert_eq!(
+ &agg_share_1.get_encoded(),
+ test_vector.agg_shares[1].as_ref()
+ );
+ assert_eq!(agg_result, test_vector.agg_result);
+ }
+
+ #[test]
+ fn test_vec_poplar1_0() {
+ check_test_vec(include_str!("test_vec/07/Poplar1_0.json"));
+ }
+
+ #[test]
+ fn test_vec_poplar1_1() {
+ check_test_vec(include_str!("test_vec/07/Poplar1_1.json"));
+ }
+
+ #[test]
+ fn test_vec_poplar1_2() {
+ check_test_vec(include_str!("test_vec/07/Poplar1_2.json"));
+ }
+
+ #[test]
+ fn test_vec_poplar1_3() {
+ check_test_vec(include_str!("test_vec/07/Poplar1_3.json"));
+ }
+
+ #[test]
+ fn input_share_equality_test() {
+ equality_comparison_test(&[
+ // Default.
+ Poplar1InputShare {
+ idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
+ corr_seed: Seed([16, 17, 18]),
+ corr_inner: Vec::from([
+ [Field64::from(19), Field64::from(20)],
+ [Field64::from(21), Field64::from(22)],
+ [Field64::from(23), Field64::from(24)],
+ ]),
+ corr_leaf: [Field255::from(25), Field255::from(26)],
+ },
+ // Modified idpf_key.
+ Poplar1InputShare {
+ idpf_key: Seed([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
+ corr_seed: Seed([16, 17, 18]),
+ corr_inner: Vec::from([
+ [Field64::from(19), Field64::from(20)],
+ [Field64::from(21), Field64::from(22)],
+ [Field64::from(23), Field64::from(24)],
+ ]),
+ corr_leaf: [Field255::from(25), Field255::from(26)],
+ },
+ // Modified corr_seed.
+ Poplar1InputShare {
+ idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
+ corr_seed: Seed([18, 17, 16]),
+ corr_inner: Vec::from([
+ [Field64::from(19), Field64::from(20)],
+ [Field64::from(21), Field64::from(22)],
+ [Field64::from(23), Field64::from(24)],
+ ]),
+ corr_leaf: [Field255::from(25), Field255::from(26)],
+ },
+ // Modified corr_inner.
+ Poplar1InputShare {
+ idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
+ corr_seed: Seed([16, 17, 18]),
+ corr_inner: Vec::from([
+ [Field64::from(24), Field64::from(23)],
+ [Field64::from(22), Field64::from(21)],
+ [Field64::from(20), Field64::from(19)],
+ ]),
+ corr_leaf: [Field255::from(25), Field255::from(26)],
+ },
+ // Modified corr_leaf.
+ Poplar1InputShare {
+ idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
+ corr_seed: Seed([16, 17, 18]),
+ corr_inner: Vec::from([
+ [Field64::from(19), Field64::from(20)],
+ [Field64::from(21), Field64::from(22)],
+ [Field64::from(23), Field64::from(24)],
+ ]),
+ corr_leaf: [Field255::from(26), Field255::from(25)],
+ },
+ ])
+ }
+
+ #[test]
+ fn prepare_state_equality_test() {
+ // This test effectively covers PrepareStateVariant, PrepareState, SketchState as well.
+ equality_comparison_test(&[
+ // Inner, round one. (default)
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(0),
+ B_share: Field64::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field64::from(2), Field64::from(3)]),
+ })),
+ // Inner, round one, modified A_share.
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(100),
+ B_share: Field64::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field64::from(2), Field64::from(3)]),
+ })),
+ // Inner, round one, modified B_share.
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(0),
+ B_share: Field64::from(101),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field64::from(2), Field64::from(3)]),
+ })),
+ // Inner, round one, modified is_leader.
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(0),
+ B_share: Field64::from(1),
+ is_leader: true,
+ },
+ output_share: Vec::from([Field64::from(2), Field64::from(3)]),
+ })),
+ // Inner, round one, modified output_share.
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field64::from(0),
+ B_share: Field64::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field64::from(3), Field64::from(2)]),
+ })),
+ // Inner, round two. (default)
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([Field64::from(2), Field64::from(3)]),
+ })),
+ // Inner, round two, modified output_share.
+ Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([Field64::from(3), Field64::from(2)]),
+ })),
+ // Leaf, round one. (default)
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(0),
+ B_share: Field255::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field255::from(2), Field255::from(3)]),
+ })),
+ // Leaf, round one, modified A_share.
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(100),
+ B_share: Field255::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field255::from(2), Field255::from(3)]),
+ })),
+ // Leaf, round one, modified B_share.
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(0),
+ B_share: Field255::from(101),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field255::from(2), Field255::from(3)]),
+ })),
+ // Leaf, round one, modified is_leader.
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(0),
+ B_share: Field255::from(1),
+ is_leader: true,
+ },
+ output_share: Vec::from([Field255::from(2), Field255::from(3)]),
+ })),
+ // Leaf, round one, modified output_share.
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundOne {
+ A_share: Field255::from(0),
+ B_share: Field255::from(1),
+ is_leader: false,
+ },
+ output_share: Vec::from([Field255::from(3), Field255::from(2)]),
+ })),
+ // Leaf, round two. (default)
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([Field255::from(2), Field255::from(3)]),
+ })),
+ // Leaf, round two, modified output_share.
+ Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState {
+ sketch: SketchState::RoundTwo,
+ output_share: Vec::from([Field255::from(3), Field255::from(2)]),
+ })),
+ ])
+ }
+
+ #[test]
+ fn field_vec_equality_test() {
+ equality_comparison_test(&[
+ // Inner. (default)
+ Poplar1FieldVec::Inner(Vec::from([Field64::from(0), Field64::from(1)])),
+ // Inner, modified value.
+ Poplar1FieldVec::Inner(Vec::from([Field64::from(1), Field64::from(0)])),
+ // Leaf. (deafult)
+ Poplar1FieldVec::Leaf(Vec::from([Field255::from(0), Field255::from(1)])),
+ // Leaf, modified value.
+ Poplar1FieldVec::Leaf(Vec::from([Field255::from(1), Field255::from(0)])),
+ ])
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs
new file mode 100644
index 0000000000..4669c47d00
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2.rs
@@ -0,0 +1,543 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Backwards-compatible port of the ENPA Prio system to a VDAF.
+
+use crate::{
+ codec::{CodecError, Decode, Encode, ParameterizedDecode},
+ field::{
+ decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldPrio2,
+ },
+ prng::Prng,
+ vdaf::{
+ prio2::{
+ client::{self as v2_client, proof_length},
+ server as v2_server,
+ },
+ xof::Seed,
+ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
+ PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
+ },
+};
+use hmac::{Hmac, Mac};
+use rand_core::RngCore;
+use sha2::Sha256;
+use std::{convert::TryFrom, io::Cursor};
+use subtle::{Choice, ConstantTimeEq};
+
+mod client;
+mod server;
+#[cfg(test)]
+mod test_vector;
+
+/// The Prio2 VDAF. It supports the same measurement type as
+/// [`Prio3SumVec`](crate::vdaf::prio3::Prio3SumVec) with `bits == 1` but uses the proof system and
+/// finite field deployed in ENPA.
+#[derive(Clone, Debug)]
+pub struct Prio2 {
+ input_len: usize,
+}
+
+impl Prio2 {
+ /// Returns an instance of the VDAF for the given input length.
+ pub fn new(input_len: usize) -> Result<Self, VdafError> {
+ let n = (input_len + 1).next_power_of_two();
+ if let Ok(size) = u32::try_from(2 * n) {
+ if size > FieldPrio2::generator_order() {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+ } else {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds memory capacity".into(),
+ ));
+ }
+
+ Ok(Prio2 { input_len })
+ }
+
+ /// Prepare an input share for aggregation using the given field element `query_rand` to
+ /// compute the verifier share.
+ ///
+ /// In the [`Aggregator`] trait implementation for [`Prio2`], the query randomness is computed
+ /// jointly by the Aggregators. This method is designed to be used in applications, like ENPA,
+ /// in which the query randomness is instead chosen by a third-party.
+ pub fn prepare_init_with_query_rand(
+ &self,
+ query_rand: FieldPrio2,
+ input_share: &Share<FieldPrio2, 32>,
+ is_leader: bool,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let expanded_data: Option<Vec<FieldPrio2>> = match input_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ Some(prng.take(proof_length(self.input_len)).collect())
+ }
+ };
+ let data = match input_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_data.as_ref().unwrap(),
+ };
+
+ let verifier_share = v2_server::generate_verification_message(
+ self.input_len,
+ query_rand,
+ data, // Combined input and proof shares
+ is_leader,
+ )
+ .map_err(|e| VdafError::Uncategorized(e.to_string()))?;
+
+ Ok((
+ Prio2PrepareState(input_share.truncated(self.input_len)),
+ Prio2PrepareShare(verifier_share),
+ ))
+ }
+
+ /// Choose a random point for polynomial evaluation.
+ ///
+ /// The point returned is not one of the roots used for polynomial interpolation.
+ pub(crate) fn choose_eval_at<S>(&self, prng: &mut Prng<FieldPrio2, S>) -> FieldPrio2
+ where
+ S: RngCore,
+ {
+ // Make sure the query randomness isn't a root of unity. Evaluating the proof at any of
+ // these points would be a privacy violation, since these points were used by the prover to
+ // construct the wire polynomials.
+ let n = (self.input_len + 1).next_power_of_two();
+ let proof_length = 2 * n;
+ loop {
+ let eval_at: FieldPrio2 = prng.get();
+ // Unwrap safety: the constructor checks that this conversion succeeds.
+ if eval_at.pow(u32::try_from(proof_length).unwrap()) != FieldPrio2::one() {
+ return eval_at;
+ }
+ }
+ }
+}
+
+impl Vdaf for Prio2 {
+ const ID: u32 = 0xFFFF0000;
+ type Measurement = Vec<u32>;
+ type AggregateResult = Vec<u32>;
+ type AggregationParam = ();
+ type PublicShare = ();
+ type InputShare = Share<FieldPrio2, 32>;
+ type OutputShare = OutputShare<FieldPrio2>;
+ type AggregateShare = AggregateShare<FieldPrio2>;
+
+ fn num_aggregators(&self) -> usize {
+ // Prio2 can easily be extended to support more than two Aggregators.
+ 2
+ }
+}
+
+impl Client<16> for Prio2 {
+ fn shard(
+ &self,
+ measurement: &Vec<u32>,
+ _nonce: &[u8; 16],
+ ) -> Result<(Self::PublicShare, Vec<Share<FieldPrio2, 32>>), VdafError> {
+ if measurement.len() != self.input_len {
+ return Err(VdafError::Uncategorized("incorrect input length".into()));
+ }
+ let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len());
+ for int in measurement {
+ input.push((*int).into());
+ }
+
+ let mut mem = v2_client::ClientMemory::new(self.input_len)?;
+ let copy_data = |share_data: &mut [FieldPrio2]| {
+ share_data[..].clone_from_slice(&input);
+ };
+ let mut leader_data = mem.prove_with(self.input_len, copy_data);
+
+ let helper_seed = Seed::generate()?;
+ let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref());
+ for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) {
+ *s1 -= d;
+ }
+
+ Ok((
+ (),
+ vec![Share::Leader(leader_data), Share::Helper(helper_seed)],
+ ))
+ }
+}
+
+/// State of each [`Aggregator`] during the Preparation phase.
+#[derive(Clone, Debug)]
+pub struct Prio2PrepareState(Share<FieldPrio2, 32>);
+
+impl PartialEq for Prio2PrepareState {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl Eq for Prio2PrepareState {}
+
+impl ConstantTimeEq for Prio2PrepareState {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl Encode for Prio2PrepareState {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ self.0.encoded_len()
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let share_decoder = if *agg_id == 0 {
+ ShareDecodingParameter::Leader(prio2.input_len)
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let out_share = Share::decode_with_param(&share_decoder, bytes)?;
+ Ok(Self(out_share))
+ }
+}
+
+/// Message emitted by each [`Aggregator`] during the Preparation phase.
+#[derive(Clone, Debug)]
+pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);
+
+impl Encode for Prio2PrepareShare {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.0.f_r.encode(bytes);
+ self.0.g_r.encode(bytes);
+ self.0.h_r.encode(bytes);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(FieldPrio2::ENCODED_SIZE * 3)
+ }
+}
+
+impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
+ fn decode_with_param(
+ _state: &Prio2PrepareState,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ Ok(Self(v2_server::VerificationMessage {
+ f_r: FieldPrio2::decode(bytes)?,
+ g_r: FieldPrio2::decode(bytes)?,
+ h_r: FieldPrio2::decode(bytes)?,
+ }))
+ }
+}
+
+impl Aggregator<32, 16> for Prio2 {
+ type PrepareState = Prio2PrepareState;
+ type PrepareShare = Prio2PrepareShare;
+ type PrepareMessage = ();
+
+ fn prepare_init(
+ &self,
+ agg_key: &[u8; 32],
+ agg_id: usize,
+ _agg_param: &Self::AggregationParam,
+ nonce: &[u8; 16],
+ _public_share: &Self::PublicShare,
+ input_share: &Share<FieldPrio2, 32>,
+ ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
+ let is_leader = role_try_from(agg_id)?;
+
+ // In the ENPA Prio system, the query randomness is generated by a third party and
+ // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
+ // randomness is derived from a nonce selected by the client. For Prio2 we compute the
+ // query using HMAC-SHA256 evaluated over the nonce.
+ //
+ // Unwrap safety: new_from_slice() is infallible for Hmac.
+ let mut mac = Hmac::<Sha256>::new_from_slice(agg_key).unwrap();
+ mac.update(nonce);
+ let hmac_tag = mac.finalize();
+ let mut prng = Prng::from_prio2_seed(&hmac_tag.into_bytes().into());
+ let query_rand = self.choose_eval_at(&mut prng);
+
+ self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
+ }
+
+ fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Prio2PrepareShare>>(
+ &self,
+ _: &Self::AggregationParam,
+ inputs: M,
+ ) -> Result<(), VdafError> {
+ let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
+ inputs.into_iter().map(|msg| msg.0).collect();
+ if verifier_shares.len() != 2 {
+ return Err(VdafError::Uncategorized(
+ "wrong number of verifier shares".into(),
+ ));
+ }
+
+ if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn prepare_next(
+ &self,
+ state: Prio2PrepareState,
+ _input: (),
+ ) -> Result<PrepareTransition<Self, 32, 16>, VdafError> {
+ let data = match state.0 {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let prng = Prng::from_prio2_seed(seed.as_ref());
+ prng.take(self.input_len).collect()
+ }
+ };
+ Ok(PrepareTransition::Finish(OutputShare::from(data)))
+ }
+
+ fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &Self::AggregationParam,
+ out_shares: M,
+ ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
+ let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for out_share in out_shares.into_iter() {
+ agg_share.accumulate(&out_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+impl Collector for Prio2 {
+ fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
+ &self,
+ _agg_param: &Self::AggregationParam,
+ agg_shares: M,
+ _num_measurements: usize,
+ ) -> Result<Vec<u32>, VdafError> {
+ let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(agg.0.into_iter().map(u32::from).collect())
+ }
+}
+
+impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
+ fn decode_with_param(
+ (prio2, agg_id): &(&'a Prio2, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
+ let decoder = if is_leader {
+ ShareDecodingParameter::Leader(proof_length(prio2.input_len))
+ } else {
+ ShareDecodingParameter::Helper
+ };
+
+ Share::decode_with_param(&decoder, bytes)
+ }
+}
+
+impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for OutputShare<F>
+where
+ F: FieldElement,
+{
+ fn decode_with_param(
+ (prio2, _): &(&'a Prio2, &'a ()),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ decode_fieldvec(prio2.input_len, bytes).map(Self)
+ }
+}
+
+impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for AggregateShare<F>
+where
+ F: FieldElement,
+{
+ fn decode_with_param(
+ (prio2, _): &(&'a Prio2, &'a ()),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ decode_fieldvec(prio2.input_len, bytes).map(Self)
+ }
+}
+
+fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
+ match agg_id {
+ 0 => Ok(true),
+ 1 => Ok(false),
+ _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::vdaf::{
+ equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector,
+ run_vdaf,
+ };
+ use assert_matches::assert_matches;
+ use rand::prelude::*;
+
+ #[test]
+ fn run_prio2() {
+ let prio2 = Prio2::new(6).unwrap();
+
+ assert_eq!(
+ run_vdaf(
+ &prio2,
+ &(),
+ [
+ vec![0, 0, 0, 0, 1, 0],
+ vec![0, 1, 0, 0, 0, 0],
+ vec![0, 1, 1, 0, 0, 0],
+ vec![1, 1, 1, 0, 0, 0],
+ vec![0, 0, 0, 0, 1, 1],
+ ]
+ )
+ .unwrap(),
+ vec![1, 3, 2, 0, 2, 1],
+ );
+ }
+
+ #[test]
+ fn prepare_state_serialization() {
+ let mut rng = thread_rng();
+ let verify_key = rng.gen::<[u8; 32]>();
+ let nonce = rng.gen::<[u8; 16]>();
+ let data = vec![0, 0, 1, 1, 0];
+ let prio2 = Prio2::new(data.len()).unwrap();
+ let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (prepare_state, prepare_share) = prio2
+ .prepare_init(
+ &verify_key,
+ agg_id,
+ &(),
+ &[0; 16],
+ &public_share,
+ input_share,
+ )
+ .unwrap();
+
+ let encoded_prepare_state = prepare_state.get_encoded();
+ let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param(
+ &(&prio2, agg_id),
+ &encoded_prepare_state,
+ )
+ .expect("failed to decode prepare state");
+ assert_eq!(decoded_prepare_state, prepare_state);
+ assert_eq!(
+ prepare_state.encoded_len().unwrap(),
+ encoded_prepare_state.len()
+ );
+
+ let encoded_prepare_share = prepare_share.get_encoded();
+ let decoded_prepare_share =
+ Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
+ .expect("failed to decode prepare share");
+ assert_eq!(decoded_prepare_share.0.f_r, prepare_share.0.f_r);
+ assert_eq!(decoded_prepare_share.0.g_r, prepare_share.0.g_r);
+ assert_eq!(decoded_prepare_share.0.h_r, prepare_share.0.h_r);
+ assert_eq!(
+ prepare_share.encoded_len().unwrap(),
+ encoded_prepare_share.len()
+ );
+ }
+ }
+
+ #[test]
+ fn roundtrip_output_share() {
+ let vdaf = Prio2::new(31).unwrap();
+ fieldvec_roundtrip_test::<FieldPrio2, Prio2, OutputShare<FieldPrio2>>(&vdaf, &(), 31);
+ }
+
+ #[test]
+ fn roundtrip_aggregate_share() {
+ let vdaf = Prio2::new(31).unwrap();
+ fieldvec_roundtrip_test::<FieldPrio2, Prio2, AggregateShare<FieldPrio2>>(&vdaf, &(), 31);
+ }
+
+ #[test]
+ fn priov2_backward_compatibility() {
+ let test_vector: Priov2TestVector =
+ serde_json::from_str(include_str!("test_vec/prio2/fieldpriov2.json")).unwrap();
+ let vdaf = Prio2::new(test_vector.dimension).unwrap();
+ let mut leader_output_shares = Vec::new();
+ let mut helper_output_shares = Vec::new();
+ for (server_1_share, server_2_share) in test_vector
+ .server_1_decrypted_shares
+ .iter()
+ .zip(&test_vector.server_2_decrypted_shares)
+ {
+ let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap();
+ let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap();
+ let (prepare_state_1, prepare_share_1) = vdaf
+ .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1)
+ .unwrap();
+ let (prepare_state_2, prepare_share_2) = vdaf
+ .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2)
+ .unwrap();
+ vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2])
+ .unwrap();
+ let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap();
+ let output_share_1 =
+ assert_matches!(transition_1, PrepareTransition::Finish(out) => out);
+ let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap();
+ let output_share_2 =
+ assert_matches!(transition_2, PrepareTransition::Finish(out) => out);
+ leader_output_shares.push(output_share_1);
+ helper_output_shares.push(output_share_2);
+ }
+
+ let leader_aggregate_share = vdaf.aggregate(&(), leader_output_shares).unwrap();
+ let helper_aggregate_share = vdaf.aggregate(&(), helper_output_shares).unwrap();
+ let aggregate_result = vdaf
+ .unshard(
+ &(),
+ [leader_aggregate_share, helper_aggregate_share],
+ test_vector.server_1_decrypted_shares.len(),
+ )
+ .unwrap();
+ let reconstructed = aggregate_result
+ .into_iter()
+ .map(FieldPrio2::from)
+ .collect::<Vec<_>>();
+
+ assert_eq!(reconstructed, test_vector.reference_sum);
+ }
+
+ #[test]
+ fn prepare_state_equality_test() {
+ equality_comparison_test(&[
+ Prio2PrepareState(Share::Leader(Vec::from([
+ FieldPrio2::from(0),
+ FieldPrio2::from(1),
+ ]))),
+ Prio2PrepareState(Share::Leader(Vec::from([
+ FieldPrio2::from(1),
+ FieldPrio2::from(0),
+ ]))),
+ Prio2PrepareState(Share::Helper(Seed(
+ (0..32).collect::<Vec<_>>().try_into().unwrap(),
+ ))),
+ Prio2PrepareState(Share::Helper(Seed(
+ (1..33).collect::<Vec<_>>().try_into().unwrap(),
+ ))),
+ ])
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio2/client.rs b/third_party/rust/prio/src/vdaf/prio2/client.rs
new file mode 100644
index 0000000000..dbce39ee3f
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2/client.rs
@@ -0,0 +1,306 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Primitives for the Prio2 client.
+
+use crate::{
+ field::{FftFriendlyFieldElement, FieldError},
+ polynomial::{poly_fft, PolyAuxMemory},
+ prng::{Prng, PrngError},
+ vdaf::{xof::SeedStreamAes128, VdafError},
+};
+
+use std::convert::TryFrom;
+
+/// Errors that might be emitted by the client.
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum ClientError {
+ /// PRNG error
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+ /// VDAF error
+ #[error("vdaf error: {0}")]
+ Vdaf(#[from] VdafError),
+ /// failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+}
+
+/// Serialization errors
+#[derive(Debug, thiserror::Error)]
+pub enum SerializeError {
+ /// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length
+ #[error("serialized input has wrong length")]
+ UnpackInputSizeMismatch,
+ /// Finite field operation error.
+ #[error("finite field operation error")]
+ Field(#[from] FieldError),
+}
+
+#[derive(Debug)]
+pub(crate) struct ClientMemory<F> {
+ prng: Prng<F, SeedStreamAes128>,
+ points_f: Vec<F>,
+ points_g: Vec<F>,
+ evals_f: Vec<F>,
+ evals_g: Vec<F>,
+ poly_mem: PolyAuxMemory<F>,
+}
+
+impl<F: FftFriendlyFieldElement> ClientMemory<F> {
+ pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> {
+ let n = (dimension + 1).next_power_of_two();
+ if let Ok(size) = F::Integer::try_from(2 * n) {
+ if size > F::generator_order() {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+ } else {
+ return Err(VdafError::Uncategorized(
+ "input size exceeds field capacity".into(),
+ ));
+ }
+
+ Ok(Self {
+ prng: Prng::new()?,
+ points_f: vec![F::zero(); n],
+ points_g: vec![F::zero(); n],
+ evals_f: vec![F::zero(); 2 * n],
+ evals_g: vec![F::zero(); 2 * n],
+ poly_mem: PolyAuxMemory::new(n),
+ })
+ }
+}
+
+impl<F: FftFriendlyFieldElement> ClientMemory<F> {
+ pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F>
+ where
+ G: FnOnce(&mut [F]),
+ {
+ let mut proof = vec![F::zero(); proof_length(dimension)];
+ // unpack one long vector to different subparts
+ let unpacked = unpack_proof_mut(&mut proof, dimension).unwrap();
+ // initialize the data part
+ init_function(unpacked.data);
+ // fill in the rest
+ construct_proof(
+ unpacked.data,
+ dimension,
+ unpacked.f0,
+ unpacked.g0,
+ unpacked.h0,
+ unpacked.points_h_packed,
+ self,
+ );
+
+ proof
+ }
+}
+
+/// Returns the number of field elements in the proof for given dimension of
+/// data elements
+///
+/// Proof is a vector, where the first `dimension` elements are the data
+/// elements, the next 3 elements are the zero terms for polynomials f, g and h
+/// and the remaining elements are non-zero points of h(x).
+pub(crate) fn proof_length(dimension: usize) -> usize {
+ // number of data items + number of zero terms + N
+ dimension + 3 + (dimension + 1).next_power_of_two()
+}
+
+/// Unpacked proof with subcomponents
+#[derive(Debug)]
+pub(crate) struct UnpackedProof<'a, F: FftFriendlyFieldElement> {
+ /// Data
+ pub data: &'a [F],
+ /// Zeroth coefficient of polynomial f
+ pub f0: &'a F,
+ /// Zeroth coefficient of polynomial g
+ pub g0: &'a F,
+ /// Zeroth coefficient of polynomial h
+ pub h0: &'a F,
+ /// Non-zero points of polynomial h
+ pub points_h_packed: &'a [F],
+}
+
+/// Unpacked proof with mutable subcomponents
+#[derive(Debug)]
+pub(crate) struct UnpackedProofMut<'a, F: FftFriendlyFieldElement> {
+ /// Data
+ pub data: &'a mut [F],
+ /// Zeroth coefficient of polynomial f
+ pub f0: &'a mut F,
+ /// Zeroth coefficient of polynomial g
+ pub g0: &'a mut F,
+ /// Zeroth coefficient of polynomial h
+ pub h0: &'a mut F,
+ /// Non-zero points of polynomial h
+ pub points_h_packed: &'a mut [F],
+}
+
+/// Unpacks the proof vector into subcomponents
+pub(crate) fn unpack_proof<F: FftFriendlyFieldElement>(
+ proof: &[F],
+ dimension: usize,
+) -> Result<UnpackedProof<F>, SerializeError> {
+ // check the proof length
+ if proof.len() != proof_length(dimension) {
+ return Err(SerializeError::UnpackInputSizeMismatch);
+ }
+ // split share into components
+ let (data, rest) = proof.split_at(dimension);
+ if let ([f0, g0, h0], points_h_packed) = rest.split_at(3) {
+ Ok(UnpackedProof {
+ data,
+ f0,
+ g0,
+ h0,
+ points_h_packed,
+ })
+ } else {
+ Err(SerializeError::UnpackInputSizeMismatch)
+ }
+}
+
+/// Unpacks a mutable proof vector into mutable subcomponents
+pub(crate) fn unpack_proof_mut<F: FftFriendlyFieldElement>(
+ proof: &mut [F],
+ dimension: usize,
+) -> Result<UnpackedProofMut<F>, SerializeError> {
+ // check the share length
+ if proof.len() != proof_length(dimension) {
+ return Err(SerializeError::UnpackInputSizeMismatch);
+ }
+ // split share into components
+ let (data, rest) = proof.split_at_mut(dimension);
+ if let ([f0, g0, h0], points_h_packed) = rest.split_at_mut(3) {
+ Ok(UnpackedProofMut {
+ data,
+ f0,
+ g0,
+ h0,
+ points_h_packed,
+ })
+ } else {
+ Err(SerializeError::UnpackInputSizeMismatch)
+ }
+}
+
+fn interpolate_and_evaluate_at_2n<F: FftFriendlyFieldElement>(
+ n: usize,
+ points_in: &[F],
+ evals_out: &mut [F],
+ mem: &mut PolyAuxMemory<F>,
+) {
+ // interpolate through roots of unity
+ poly_fft(
+ &mut mem.coeffs,
+ points_in,
+ &mem.roots_n_inverted,
+ n,
+ true,
+ &mut mem.fft_memory,
+ );
+ // evaluate at 2N roots of unity
+ poly_fft(
+ evals_out,
+ &mem.coeffs,
+ &mem.roots_2n,
+ 2 * n,
+ false,
+ &mut mem.fft_memory,
+ );
+}
+
+/// Proof construction
+///
+/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation
+/// This constructs the output \pi by doing the necessesary calculations
+fn construct_proof<F: FftFriendlyFieldElement>(
+ data: &[F],
+ dimension: usize,
+ f0: &mut F,
+ g0: &mut F,
+ h0: &mut F,
+ points_h_packed: &mut [F],
+ mem: &mut ClientMemory<F>,
+) {
+ let n = (dimension + 1).next_power_of_two();
+
+ // set zero terms to random
+ *f0 = mem.prng.get();
+ *g0 = mem.prng.get();
+ mem.points_f[0] = *f0;
+ mem.points_g[0] = *g0;
+
+ // set zero term for the proof polynomial
+ *h0 = *f0 * *g0;
+
+ // set f_i = data_(i - 1)
+ // set g_i = f_i - 1
+ for ((f_coeff, g_coeff), data_val) in mem.points_f[1..1 + dimension]
+ .iter_mut()
+ .zip(mem.points_g[1..1 + dimension].iter_mut())
+ .zip(data[..dimension].iter())
+ {
+ *f_coeff = *data_val;
+ *g_coeff = *data_val - F::one();
+ }
+
+ // interpolate and evaluate at roots of unity
+ interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem);
+ interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem);
+
+ // calculate the proof polynomial as evals_f(r) * evals_g(r)
+ // only add non-zero points
+ let mut j: usize = 0;
+ let mut i: usize = 1;
+ while i < 2 * n {
+ points_h_packed[j] = mem.evals_f[i] * mem.evals_g[i];
+ j += 1;
+ i += 2;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use assert_matches::assert_matches;
+
+ use crate::{
+ field::{Field64, FieldPrio2},
+ vdaf::prio2::client::{proof_length, unpack_proof, unpack_proof_mut, SerializeError},
+ };
+
+ #[test]
+ fn test_unpack_share_mut() {
+ let dim = 15;
+ let len = proof_length(dim);
+
+ let mut share = vec![FieldPrio2::from(0); len];
+ let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
+ *unpacked.f0 = FieldPrio2::from(12);
+ assert_eq!(share[dim], 12);
+
+ let mut short_share = vec![FieldPrio2::from(0); len - 1];
+ assert_matches!(
+ unpack_proof_mut(&mut short_share, dim),
+ Err(SerializeError::UnpackInputSizeMismatch)
+ );
+ }
+
+ #[test]
+ fn test_unpack_share() {
+ let dim = 15;
+ let len = proof_length(dim);
+
+ let share = vec![Field64::from(0); len];
+ unpack_proof(&share, dim).unwrap();
+
+ let short_share = vec![Field64::from(0); len - 1];
+ assert_matches!(
+ unpack_proof(&short_share, dim),
+ Err(SerializeError::UnpackInputSizeMismatch)
+ );
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio2/server.rs b/third_party/rust/prio/src/vdaf/prio2/server.rs
new file mode 100644
index 0000000000..11c161babf
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2/server.rs
@@ -0,0 +1,386 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! Primitives for the Prio2 server.
+use crate::{
+ field::{FftFriendlyFieldElement, FieldError},
+ polynomial::poly_interpret_eval,
+ prng::PrngError,
+ vdaf::prio2::client::{unpack_proof, SerializeError},
+};
+use serde::{Deserialize, Serialize};
+
+/// Possible errors from server operations
+#[derive(Debug, thiserror::Error)]
+pub enum ServerError {
+ /// Unexpected Share Length
+ #[allow(unused)]
+ #[error("unexpected share length")]
+ ShareLength,
+ /// Finite field operation error
+ #[error("finite field operation error")]
+ Field(#[from] FieldError),
+ /// Serialization/deserialization error
+ #[error("serialization/deserialization error")]
+ Serialize(#[from] SerializeError),
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+ /// PRNG error.
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+}
+
+/// Verification message for proof validation
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct VerificationMessage<F> {
+ /// f evaluated at random point
+ pub f_r: F,
+ /// g evaluated at random point
+ pub g_r: F,
+ /// h evaluated at random point
+ pub h_r: F,
+}
+
+/// Given a proof and evaluation point, this constructs the verification
+/// message.
+pub(crate) fn generate_verification_message<F: FftFriendlyFieldElement>(
+ dimension: usize,
+ eval_at: F,
+ proof: &[F],
+ is_first_server: bool,
+) -> Result<VerificationMessage<F>, ServerError> {
+ let unpacked = unpack_proof(proof, dimension)?;
+ let n: usize = (dimension + 1).next_power_of_two();
+ let proof_length = 2 * n;
+ let mut fft_in = vec![F::zero(); proof_length];
+ let mut fft_mem = vec![F::zero(); proof_length];
+
+ // construct and evaluate polynomial f at the random point
+ fft_in[0] = *unpacked.f0;
+ fft_in[1..unpacked.data.len() + 1].copy_from_slice(unpacked.data);
+ let f_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem);
+
+ // construct and evaluate polynomial g at the random point
+ fft_in[0] = *unpacked.g0;
+ if is_first_server {
+ for x in fft_in[1..unpacked.data.len() + 1].iter_mut() {
+ *x -= F::one();
+ }
+ }
+ let g_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem);
+
+ // construct and evaluate polynomial h at the random point
+ fft_in[0] = *unpacked.h0;
+ fft_in[1] = unpacked.points_h_packed[0];
+ for (x, chunk) in unpacked.points_h_packed[1..]
+ .iter()
+ .zip(fft_in[2..proof_length].chunks_exact_mut(2))
+ {
+ chunk[0] = F::zero();
+ chunk[1] = *x;
+ }
+ let h_r = poly_interpret_eval(&fft_in, eval_at, &mut fft_mem);
+
+ Ok(VerificationMessage { f_r, g_r, h_r })
+}
+
+/// Decides if the distributed proof is valid
+pub(crate) fn is_valid_share<F: FftFriendlyFieldElement>(
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+) -> bool {
+ // reconstruct f_r, g_r, h_r
+ let f_r = v1.f_r + v2.f_r;
+ let g_r = v1.g_r + v2.g_r;
+ let h_r = v1.h_r + v2.h_r;
+ // validity check
+ f_r * g_r == h_r
+}
+
+#[cfg(test)]
+mod test_util {
+ use crate::{
+ field::{merge_vector, FftFriendlyFieldElement},
+ prng::Prng,
+ vdaf::prio2::client::proof_length,
+ };
+
+ use super::{generate_verification_message, is_valid_share, ServerError, VerificationMessage};
+
+ /// Main workhorse of the server.
+ #[derive(Debug)]
+ pub(crate) struct Server<F> {
+ dimension: usize,
+ is_first_server: bool,
+ accumulator: Vec<F>,
+ }
+
+ impl<F: FftFriendlyFieldElement> Server<F> {
+ /// Construct a new server instance
+ ///
+ /// Params:
+ /// * `dimension`: the number of elements in the aggregation vector.
+ /// * `is_first_server`: only one of the servers should have this true.
+ pub fn new(dimension: usize, is_first_server: bool) -> Result<Server<F>, ServerError> {
+ Ok(Server {
+ dimension,
+ is_first_server,
+ accumulator: vec![F::zero(); dimension],
+ })
+ }
+
+ /// Deserialize
+ fn deserialize_share(&self, share: &[u8]) -> Result<Vec<F>, ServerError> {
+ let len = proof_length(self.dimension);
+ Ok(if self.is_first_server {
+ F::byte_slice_into_vec(share)?
+ } else {
+ if share.len() != 32 {
+ return Err(ServerError::ShareLength);
+ }
+
+ Prng::from_prio2_seed(&share.try_into().unwrap())
+ .take(len)
+ .collect()
+ })
+ }
+
+ /// Generate verification message from an encrypted share
+ ///
+ /// This decrypts the share of the proof and constructs the
+ /// [`VerificationMessage`](struct.VerificationMessage.html).
+ /// The `eval_at` field should be generate by
+ /// [choose_eval_at](#method.choose_eval_at).
+ pub fn generate_verification_message(
+ &mut self,
+ eval_at: F,
+ share: &[u8],
+ ) -> Result<VerificationMessage<F>, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ generate_verification_message(
+ self.dimension,
+ eval_at,
+ &share_field,
+ self.is_first_server,
+ )
+ }
+
+ /// Add the content of the encrypted share into the accumulator
+ ///
+ /// This only changes the accumulator if the verification messages `v1` and
+ /// `v2` indicate that the share passed validation.
+ pub fn aggregate(
+ &mut self,
+ share: &[u8],
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+ ) -> Result<bool, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ let is_valid = is_valid_share(v1, v2);
+ if is_valid {
+ // Add to the accumulator. share_field also includes the proof
+ // encoding, so we slice off the first dimension fields, which are
+ // the actual data share.
+ merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
+ }
+
+ Ok(is_valid)
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ codec::Encode,
+ field::{FieldElement, FieldPrio2},
+ prng::Prng,
+ vdaf::{
+ prio2::{client::unpack_proof_mut, server::test_util::Server, Prio2},
+ Client,
+ },
+ };
+ use rand::{random, Rng};
+
+ fn secret_share(share: &mut [FieldPrio2]) -> Vec<FieldPrio2> {
+ let mut rng = rand::thread_rng();
+ let mut share2 = vec![FieldPrio2::zero(); share.len()];
+ for (f1, f2) in share.iter_mut().zip(share2.iter_mut()) {
+ let f = FieldPrio2::from(rng.gen::<u32>());
+ *f2 = f;
+ *f1 -= f;
+ }
+ share2
+ }
+
+ #[test]
+ fn test_validation() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<FieldPrio2> = proof_u32.iter().map(|x| FieldPrio2::from(*x)).collect();
+ let share2 = secret_share(&mut proof);
+ let eval_at = FieldPrio2::from(12313);
+
+ let v1 = generate_verification_message(dim, eval_at, &proof, true).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false).unwrap();
+ assert!(is_valid_share(&v1, &v2));
+ }
+
+ #[test]
+ fn test_verification_message_serde() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<FieldPrio2> = proof_u32.iter().map(|x| FieldPrio2::from(*x)).collect();
+ let share2 = secret_share(&mut proof);
+ let eval_at = FieldPrio2::from(12313);
+
+ let v1 = generate_verification_message(dim, eval_at, &proof, true).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false).unwrap();
+
+ // serialize and deserialize the first verification message
+ let serialized = serde_json::to_string(&v1).unwrap();
+ let deserialized: VerificationMessage<FieldPrio2> =
+ serde_json::from_str(&serialized).unwrap();
+
+ assert!(is_valid_share(&deserialized, &v2));
+ }
+
+ #[derive(Debug, Clone, Copy, PartialEq)]
+ enum Tweak {
+ None,
+ WrongInput,
+ DataPartOfShare,
+ ZeroTermF,
+ ZeroTermG,
+ ZeroTermH,
+ PointsH,
+ VerificationF,
+ VerificationG,
+ VerificationH,
+ }
+
+ fn tweaks(tweak: Tweak) {
+ let dim = 123;
+
+ let mut server1 = Server::<FieldPrio2>::new(dim, true).unwrap();
+ let mut server2 = Server::new(dim, false).unwrap();
+
+ // all zero data
+ let mut data = vec![0; dim];
+
+ if let Tweak::WrongInput = tweak {
+ data[0] = 2;
+ }
+
+ let vdaf = Prio2::new(dim).unwrap();
+ let (_, shares) = vdaf.shard(&data, &[0; 16]).unwrap();
+ let share1_original = shares[0].get_encoded();
+ let share2 = shares[1].get_encoded();
+
+ let mut share1_field = FieldPrio2::byte_slice_into_vec(&share1_original).unwrap();
+ let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap();
+
+ let one = FieldPrio2::from(1);
+
+ match tweak {
+ Tweak::DataPartOfShare => unpacked_share1.data[0] += one,
+ Tweak::ZeroTermF => *unpacked_share1.f0 += one,
+ Tweak::ZeroTermG => *unpacked_share1.g0 += one,
+ Tweak::ZeroTermH => *unpacked_share1.h0 += one,
+ Tweak::PointsH => unpacked_share1.points_h_packed[0] += one,
+ _ => (),
+ };
+
+ // reserialize altered share1
+ let share1_modified = FieldPrio2::slice_into_byte_vec(&share1_field);
+
+ let mut prng = Prng::from_prio2_seed(&random());
+ let eval_at = vdaf.choose_eval_at(&mut prng);
+
+ let mut v1 = server1
+ .generate_verification_message(eval_at, &share1_modified)
+ .unwrap();
+ let v2 = server2
+ .generate_verification_message(eval_at, &share2)
+ .unwrap();
+
+ match tweak {
+ Tweak::VerificationF => v1.f_r += one,
+ Tweak::VerificationG => v1.g_r += one,
+ Tweak::VerificationH => v1.h_r += one,
+ _ => (),
+ }
+
+ let should_be_valid = matches!(tweak, Tweak::None);
+ assert_eq!(
+ server1.aggregate(&share1_modified, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ assert_eq!(
+ server2.aggregate(&share2, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ }
+
+ #[test]
+ fn tweak_none() {
+ tweaks(Tweak::None);
+ }
+
+ #[test]
+ fn tweak_input() {
+ tweaks(Tweak::WrongInput);
+ }
+
+ #[test]
+ fn tweak_data() {
+ tweaks(Tweak::DataPartOfShare);
+ }
+
+ #[test]
+ fn tweak_f_zero() {
+ tweaks(Tweak::ZeroTermF);
+ }
+
+ #[test]
+ fn tweak_g_zero() {
+ tweaks(Tweak::ZeroTermG);
+ }
+
+ #[test]
+ fn tweak_h_zero() {
+ tweaks(Tweak::ZeroTermH);
+ }
+
+ #[test]
+ fn tweak_h_points() {
+ tweaks(Tweak::PointsH);
+ }
+
+ #[test]
+ fn tweak_f_verif() {
+ tweaks(Tweak::VerificationF);
+ }
+
+ #[test]
+ fn tweak_g_verif() {
+ tweaks(Tweak::VerificationG);
+ }
+
+ #[test]
+ fn tweak_h_verif() {
+ tweaks(Tweak::VerificationH);
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
new file mode 100644
index 0000000000..ae2b8b0f9d
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Test vectors of serialized Prio inputs, enabling backward compatibility testing.
+
+use crate::{field::FieldPrio2, vdaf::prio2::client::ClientError};
+use serde::{Deserialize, Serialize};
+use std::fmt::Debug;
+
+/// Errors propagated by functions in this module.
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum TestVectorError {
+ /// Error from Prio client
+ #[error("Prio client error {0}")]
+ Client(#[from] ClientError),
+}
+
+/// A test vector of serialized Priov2 inputs, along with a reference sum. The field is always
+/// [`FieldPrio2`].
+#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub(crate) struct Priov2TestVector {
+ /// Dimension (number of buckets) of the inputs
+ pub dimension: usize,
+ /// Decrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA"
+ /// server. The inner `Vec`s are encrypted bytes.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_1_decrypted_shares: Vec<Vec<u8>>,
+ /// Decrypted share of Priov2 format inputs for the non-"first" a.k.a.
+ /// "facilitator" server.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_2_decrypted_shares: Vec<Vec<u8>>,
+ /// The sum over the inputs.
+ #[serde(
+ serialize_with = "base64::serialize_field",
+ deserialize_with = "base64::deserialize_field"
+ )]
+ pub reference_sum: Vec<FieldPrio2>,
+}
+
+mod base64 {
+ //! Custom serialization module used for some members of struct
+ //! `Priov2TestVector` so that byte slices are serialized as base64 strings
+ //! instead of an array of an array of integers when serializing to JSON.
+ //
+ // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
+ use crate::field::{FieldElement, FieldPrio2};
+ use base64::{engine::Engine, prelude::BASE64_STANDARD};
+ use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
+
+ pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
+ let base64_vec = v
+ .iter()
+ .map(|bytes| BASE64_STANDARD.encode(bytes))
+ .collect();
+ <Vec<String>>::serialize(&base64_vec, s)
+ }
+
+ pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
+ <Vec<String>>::deserialize(d)?
+ .iter()
+ .map(|s| BASE64_STANDARD.decode(s.as_bytes()).map_err(Error::custom))
+ .collect()
+ }
+
+ pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
+ String::serialize(
+ &BASE64_STANDARD.encode(FieldPrio2::slice_into_byte_vec(v)),
+ s,
+ )
+ }
+
+ pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
+ let bytes = BASE64_STANDARD
+ .decode(String::deserialize(d)?.as_bytes())
+ .map_err(Error::custom)?;
+ FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs
new file mode 100644
index 0000000000..4a7cdefb84
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3.rs
@@ -0,0 +1,2127 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! **WARNING:** This code has not undergone significant security analysis. Use at your own risk.
+//!
+//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented
+//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO
+//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication
+//! cost. The security of the construction was analyzed in [[DPRS23]].
+//!
+//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-07]] into
+//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of
+//! which are instantiated here:
+//!
+//! - [`Prio3Count`] for aggregating a counter (*)
+//! - [`Prio3Sum`] for copmputing the sum of integers (*)
+//! - [`Prio3SumVec`] for aggregating a vector of integers
+//! - [`Prio3Histogram`] for estimating a distribution via a histogram (*)
+//!
+//! Additional types can be constructed from [`Prio3`] as needed.
+//!
+//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! [BBCG+19]: https://ia.cr/2019/188
+//! [CGB17]: https://crypto.stanford.edu/prio/
+//! [DPRS23]: https://ia.cr/2023/130
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+use super::xof::XofShake128;
+#[cfg(feature = "experimental")]
+use super::AggregatorWithNoise;
+use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
+#[cfg(feature = "experimental")]
+use crate::dp::DifferentialPrivacyStrategy;
+use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement};
+use crate::field::{Field128, Field64};
+#[cfg(feature = "multithreaded")]
+use crate::flp::gadgets::ParallelSumMultithreaded;
+#[cfg(feature = "experimental")]
+use crate::flp::gadgets::PolyEval;
+use crate::flp::gadgets::{Mul, ParallelSum};
+#[cfg(feature = "experimental")]
+use crate::flp::types::fixedpoint_l2::{
+ compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum,
+};
+use crate::flp::types::{Average, Count, Histogram, Sum, SumVec};
+use crate::flp::Type;
+#[cfg(feature = "experimental")]
+use crate::flp::TypeWithNoise;
+use crate::prng::Prng;
+use crate::vdaf::xof::{IntoFieldVec, Seed, Xof};
+use crate::vdaf::{
+ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
+ Share, ShareDecodingParameter, Vdaf, VdafError,
+};
+#[cfg(feature = "experimental")]
+use fixed::traits::Fixed;
+use std::convert::TryFrom;
+use std::fmt::Debug;
+use std::io::Cursor;
+use std::iter::{self, IntoIterator};
+use std::marker::PhantomData;
+use subtle::{Choice, ConstantTimeEq};
+
+const DST_MEASUREMENT_SHARE: u16 = 1;
+const DST_PROOF_SHARE: u16 = 2;
+const DST_JOINT_RANDOMNESS: u16 = 3;
+const DST_PROVE_RANDOMNESS: u16 = 4;
+const DST_QUERY_RANDOMNESS: u16 = 5;
+const DST_JOINT_RAND_SEED: u16 = 6;
+const DST_JOINT_RAND_PART: u16 = 7;
+
+/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum.
+pub type Prio3Count = Prio3<Count<Field64>, XofShake128, 16>;
+
+impl Prio3Count {
+ /// Construct an instance of Prio3Count with the given number of aggregators.
+ pub fn new_count(num_aggregators: u8) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, Count::new())
+ }
+}
+
+/// The count-vector type. Each measurement is a vector of integers in `[0,2^bits)` and the
+/// aggregate is the element-wise sum.
+pub type Prio3SumVec =
+ Prio3<SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>, XofShake128, 16>;
+
+impl Prio3SumVec {
+ /// Construct an instance of Prio3SumVec with the given number of aggregators. `bits` defines
+ /// the bit width of each summand of the measurement; `len` defines the length of the
+ /// measurement vector.
+ pub fn new_sum_vec(
+ num_aggregators: u8,
+ bits: usize,
+ len: usize,
+ chunk_length: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?)
+ }
+}
+
+/// Like [`Prio3SumVec`] except this type uses multithreading to improve sharding and preparation
+/// time. Note that the improvement is only noticeable for very large input lengths.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+pub type Prio3SumVecMultithreaded =
+ Prio3<SumVec<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>, XofShake128, 16>;
+
+#[cfg(feature = "multithreaded")]
+impl Prio3SumVecMultithreaded {
+ /// Construct an instance of Prio3SumVecMultithreaded with the given number of
+ /// aggregators. `bits` defines the bit width of each summand of the measurement; `len` defines
+ /// the length of the measurement vector.
+ pub fn new_sum_vec_multithreaded(
+ num_aggregators: u8,
+ bits: usize,
+ len: usize,
+ chunk_length: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?)
+ }
+}
+
+/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the
+/// aggregate is the sum.
+pub type Prio3Sum = Prio3<Sum<Field128>, XofShake128, 16>;
+
+impl Prio3Sum {
+ /// Construct an instance of Prio3Sum with the given number of aggregators and required bit
+ /// length. The bit length must not exceed 64.
+ pub fn new_sum(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({bits}) exceeds limit for aggregate type (64)"
+ )));
+ }
+
+ Prio3::new(num_aggregators, Sum::new(bits)?)
+ }
+}
+
+/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers
+/// and the aggregate is the sum represented as 64-bit floats. The preparation phase
+/// ensures the L2 norm of the input vector is < 1.
+///
+/// This is useful for aggregating gradients in a federated version of
+/// [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) with
+/// [differential privacy](https://en.wikipedia.org/wiki/Differential_privacy),
+/// useful, e.g., for [differentially private deep learning](https://arxiv.org/pdf/1607.00133.pdf).
+/// The bound on input norms is required for differential privacy. The fixed point representation
+/// allows an easy conversion to the integer type used in internal computation, while leaving
+/// conversion to the client. The model itself will have floating point parameters, so the output
+/// sum has that type as well.
+#[cfg(feature = "experimental")]
+#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
+pub type Prio3FixedPointBoundedL2VecSum<Fx> = Prio3<
+ FixedPointBoundedL2VecSum<
+ Fx,
+ ParallelSum<Field128, PolyEval<Field128>>,
+ ParallelSum<Field128, Mul<Field128>>,
+ >,
+ XofShake128,
+ 16,
+>;
+
+#[cfg(feature = "experimental")]
+impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSum<Fx> {
+ /// Construct an instance of this VDAF with the given number of aggregators and number of
+ /// vector entries.
+ pub fn new_fixedpoint_boundedl2_vec_sum(
+ num_aggregators: u8,
+ entries: usize,
+ ) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+ Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?)
+ }
+}
+
+/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers
+/// and the aggregate is the sum represented as 64-bit floats. The verification function
+/// ensures the L2 norm of the input vector is < 1.
+#[cfg(all(feature = "experimental", feature = "multithreaded"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "experimental", feature = "multithreaded")))
+)]
+pub type Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3<
+ FixedPointBoundedL2VecSum<
+ Fx,
+ ParallelSumMultithreaded<Field128, PolyEval<Field128>>,
+ ParallelSumMultithreaded<Field128, Mul<Field128>>,
+ >,
+ XofShake128,
+ 16,
+>;
+
+#[cfg(all(feature = "experimental", feature = "multithreaded"))]
+impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> {
+ /// Construct an instance of this VDAF with the given number of aggregators and number of
+ /// vector entries.
+ pub fn new_fixedpoint_boundedl2_vec_sum_multithreaded(
+ num_aggregators: u8,
+ entries: usize,
+ ) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+ Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?)
+ }
+}
+
+/// The histogram type. Each measurement is an integer in `[0, length)` and the result is a
+/// histogram counting the number of occurrences of each measurement.
+pub type Prio3Histogram =
+ Prio3<Histogram<Field128, ParallelSum<Field128, Mul<Field128>>>, XofShake128, 16>;
+
+impl Prio3Histogram {
+ /// Constructs an instance of Prio3Histogram with the given number of aggregators,
+ /// number of buckets, and parallel sum gadget chunk length.
+ pub fn new_histogram(
+ num_aggregators: u8,
+ length: usize,
+ chunk_length: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?)
+ }
+}
+
+/// Like [`Prio3Histogram`] except this type uses multithreading to improve sharding and preparation
+/// time. Note that this improvement is only noticeable for very large input lengths.
+#[cfg(feature = "multithreaded")]
+#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
+pub type Prio3HistogramMultithreaded =
+ Prio3<Histogram<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>, XofShake128, 16>;
+
+#[cfg(feature = "multithreaded")]
+impl Prio3HistogramMultithreaded {
+ /// Construct an instance of Prio3HistogramMultithreaded with the given number of aggregators,
+ /// number of buckets, and parallel sum gadget chunk length.
+ pub fn new_histogram_multithreaded(
+ num_aggregators: u8,
+ length: usize,
+ chunk_length: usize,
+ ) -> Result<Self, VdafError> {
+ Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?)
+ }
+}
+
+/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and
+/// the aggregate is the arithmetic average.
+pub type Prio3Average = Prio3<Average<Field128>, XofShake128, 16>;
+
+impl Prio3Average {
+ /// Construct an instance of Prio3Average with the given number of aggregators and required bit
+ /// length. The bit length must not exceed 64.
+ pub fn new_average(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({bits}) exceeds limit for aggregate type (64)"
+ )));
+ }
+
+ Ok(Prio3 {
+ num_aggregators,
+ typ: Average::new(bits)?,
+ phantom: PhantomData,
+ })
+ }
+}
+
+/// The base type for Prio3.
+///
+/// An instance of Prio3 is determined by:
+///
+/// - a [`Type`] that defines the set of valid input measurements; and
+/// - a [`Xof`] for deriving vectors of field elements from seeds.
+///
+/// New instances can be defined by aliasing the base type. For example, [`Prio3Count`] is an alias
+/// for `Prio3<Count<Field64>, XofShake128, 16>`.
+///
+/// ```
+/// use prio::vdaf::{
+/// Aggregator, Client, Collector, PrepareTransition,
+/// prio3::Prio3,
+/// };
+/// use rand::prelude::*;
+///
+/// let num_shares = 2;
+/// let vdaf = Prio3::new_count(num_shares).unwrap();
+///
+/// let mut out_shares = vec![vec![]; num_shares.into()];
+/// let mut rng = thread_rng();
+/// let verify_key = rng.gen();
+/// let measurements = [0, 1, 1, 1, 0];
+/// for measurement in measurements {
+/// // Shard
+/// let nonce = rng.gen::<[u8; 16]>();
+/// let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap();
+///
+/// // Prepare
+/// let mut prep_states = vec![];
+/// let mut prep_shares = vec![];
+/// for (agg_id, input_share) in input_shares.iter().enumerate() {
+/// let (state, share) = vdaf.prepare_init(
+/// &verify_key,
+/// agg_id,
+/// &(),
+/// &nonce,
+/// &public_share,
+/// input_share
+/// ).unwrap();
+/// prep_states.push(state);
+/// prep_shares.push(share);
+/// }
+/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap();
+///
+/// for (agg_id, state) in prep_states.into_iter().enumerate() {
+/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() {
+/// PrepareTransition::Finish(out_share) => out_share,
+/// _ => panic!("unexpected transition"),
+/// };
+/// out_shares[agg_id].push(out_share);
+/// }
+/// }
+///
+/// // Aggregate
+/// let agg_shares = out_shares.into_iter()
+/// .map(|o| vdaf.aggregate(&(), o).unwrap());
+///
+/// // Unshard
+/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap();
+/// assert_eq!(agg_res, 3);
+/// ```
+#[derive(Clone, Debug)]
+pub struct Prio3<T, P, const SEED_SIZE: usize>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ num_aggregators: u8,
+ typ: T,
+ phantom: PhantomData<P>,
+}
+
+impl<T, P, const SEED_SIZE: usize> Prio3<T, P, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the
+ /// underlying type.
+ pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> {
+ check_num_aggregators(num_aggregators)?;
+ Ok(Self {
+ num_aggregators,
+ typ,
+ phantom: PhantomData,
+ })
+ }
+
+ /// The output length of the underlying FLP.
+ pub fn output_len(&self) -> usize {
+ self.typ.output_len()
+ }
+
+ /// The verifier length of the underlying FLP.
+ pub fn verifier_len(&self) -> usize {
+ self.typ.verifier_len()
+ }
+
+ fn derive_joint_rand_seed<'a>(
+ parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>,
+ ) -> Seed<SEED_SIZE> {
+ let mut xof = P::init(
+ &[0; SEED_SIZE],
+ &Self::domain_separation_tag(DST_JOINT_RAND_SEED),
+ );
+ for part in parts {
+ xof.update(part.as_ref());
+ }
+ xof.into_seed()
+ }
+
+ fn random_size(&self) -> usize {
+ if self.typ.joint_rand_len() == 0 {
+ // Two seeds per helper for measurement and proof shares, plus one seed for proving
+ // randomness.
+ (usize::from(self.num_aggregators - 1) * 2 + 1) * SEED_SIZE
+ } else {
+ (
+ // Two seeds per helper for measurement and proof shares
+ usize::from(self.num_aggregators - 1) * 2
+ // One seed for proving randomness
+ + 1
+ // One seed per aggregator for joint randomness blinds
+ + usize::from(self.num_aggregators)
+ ) * SEED_SIZE
+ }
+ }
+
+ #[allow(clippy::type_complexity)]
+ pub(crate) fn shard_with_random<const N: usize>(
+ &self,
+ measurement: &T::Measurement,
+ nonce: &[u8; N],
+ random: &[u8],
+ ) -> Result<
+ (
+ Prio3PublicShare<SEED_SIZE>,
+ Vec<Prio3InputShare<T::Field, SEED_SIZE>>,
+ ),
+ VdafError,
+ > {
+ if random.len() != self.random_size() {
+ return Err(VdafError::Uncategorized(
+ "incorrect random input length".to_string(),
+ ));
+ }
+ let mut random_seeds = random.chunks_exact(SEED_SIZE);
+ let num_aggregators = self.num_aggregators;
+ let encoded_measurement = self.typ.encode_measurement(measurement)?;
+
+ // Generate the measurement shares and compute the joint randomness.
+ let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1);
+ let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 {
+ Some(Vec::with_capacity(num_aggregators as usize - 1))
+ } else {
+ None
+ };
+ let mut leader_measurement_share = encoded_measurement.clone();
+ for agg_id in 1..num_aggregators {
+ // The Option from the ChunksExact iterator is okay to unwrap because we checked that
+ // the randomness slice is long enough for this VDAF. The slice-to-array conversion
+ // Result is okay to unwrap because the ChunksExact iterator always returns slices of
+ // the correct length.
+ let measurement_share_seed = random_seeds.next().unwrap().try_into().unwrap();
+ let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap();
+ let measurement_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
+ &Seed(measurement_share_seed),
+ &Self::domain_separation_tag(DST_MEASUREMENT_SHARE),
+ &[agg_id],
+ ));
+ let joint_rand_blind =
+ if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() {
+ let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap();
+ let mut joint_rand_part_xof = P::init(
+ &joint_rand_blind,
+ &Self::domain_separation_tag(DST_JOINT_RAND_PART),
+ );
+ joint_rand_part_xof.update(&[agg_id]); // Aggregator ID
+ joint_rand_part_xof.update(nonce);
+
+ let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
+ for (x, y) in leader_measurement_share
+ .iter_mut()
+ .zip(measurement_share_prng)
+ {
+ *x -= y;
+ y.encode(&mut encoding_buffer);
+ joint_rand_part_xof.update(&encoding_buffer);
+ encoding_buffer.clear();
+ }
+
+ helper_joint_rand_parts.push(joint_rand_part_xof.into_seed());
+
+ Some(joint_rand_blind)
+ } else {
+ for (x, y) in leader_measurement_share
+ .iter_mut()
+ .zip(measurement_share_prng)
+ {
+ *x -= y;
+ }
+ None
+ };
+ let helper =
+ HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind);
+ helper_shares.push(helper);
+ }
+
+ let mut leader_blind_opt = None;
+ let public_share = Prio3PublicShare {
+ joint_rand_parts: helper_joint_rand_parts
+ .as_ref()
+ .map(|helper_joint_rand_parts| {
+ let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap();
+ let leader_blind = Seed::from_bytes(leader_blind_bytes);
+
+ let mut joint_rand_part_xof = P::init(
+ leader_blind.as_ref(),
+ &Self::domain_separation_tag(DST_JOINT_RAND_PART),
+ );
+ joint_rand_part_xof.update(&[0]); // Aggregator ID
+ joint_rand_part_xof.update(nonce);
+ let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
+ for x in leader_measurement_share.iter() {
+ x.encode(&mut encoding_buffer);
+ joint_rand_part_xof.update(&encoding_buffer);
+ encoding_buffer.clear();
+ }
+ leader_blind_opt = Some(leader_blind);
+
+ let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed();
+
+ let mut vec = Vec::with_capacity(self.num_aggregators());
+ vec.push(leader_joint_rand_seed_part);
+ vec.extend(helper_joint_rand_parts.iter().cloned());
+ vec
+ }),
+ };
+
+ // Compute the joint randomness.
+ let joint_rand: Vec<T::Field> = public_share
+ .joint_rand_parts
+ .as_ref()
+ .map(|joint_rand_parts| {
+ let joint_rand_seed = Self::derive_joint_rand_seed(joint_rand_parts.iter());
+ P::seed_stream(
+ &joint_rand_seed,
+ &Self::domain_separation_tag(DST_JOINT_RANDOMNESS),
+ &[],
+ )
+ .into_field_vec(self.typ.joint_rand_len())
+ })
+ .unwrap_or_default();
+
+ // Run the proof-generation algorithm.
+ let prove_rand_seed = random_seeds.next().unwrap().try_into().unwrap();
+ let prove_rand = P::seed_stream(
+ &Seed::from_bytes(prove_rand_seed),
+ &Self::domain_separation_tag(DST_PROVE_RANDOMNESS),
+ &[],
+ )
+ .into_field_vec(self.typ.prove_rand_len());
+ let mut leader_proof_share =
+ self.typ
+ .prove(&encoded_measurement, &prove_rand, &joint_rand)?;
+
+ // Generate the proof shares and distribute the joint randomness seed hints.
+ for (j, helper) in helper_shares.iter_mut().enumerate() {
+ let proof_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
+ &helper.proof_share,
+ &Self::domain_separation_tag(DST_PROOF_SHARE),
+ &[j as u8 + 1],
+ ));
+ for (x, y) in leader_proof_share
+ .iter_mut()
+ .zip(proof_share_prng)
+ .take(self.typ.proof_len())
+ {
+ *x -= y;
+ }
+ }
+
+ // Prep the output messages.
+ let mut out = Vec::with_capacity(num_aggregators as usize);
+ out.push(Prio3InputShare {
+ measurement_share: Share::Leader(leader_measurement_share),
+ proof_share: Share::Leader(leader_proof_share),
+ joint_rand_blind: leader_blind_opt,
+ });
+
+ for helper in helper_shares.into_iter() {
+ out.push(Prio3InputShare {
+ measurement_share: Share::Helper(helper.measurement_share),
+ proof_share: Share::Helper(helper.proof_share),
+ joint_rand_blind: helper.joint_rand_blind,
+ });
+ }
+
+ Ok((public_share, out))
+ }
+
+ fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> {
+ if agg_id >= self.num_aggregators as usize {
+ return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
+ }
+ Ok(u8::try_from(agg_id).unwrap())
+ }
+}
+
+impl<T, P, const SEED_SIZE: usize> Vdaf for Prio3<T, P, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ const ID: u32 = T::ID;
+ type Measurement = T::Measurement;
+ type AggregateResult = T::AggregateResult;
+ type AggregationParam = ();
+ type PublicShare = Prio3PublicShare<SEED_SIZE>;
+ type InputShare = Prio3InputShare<T::Field, SEED_SIZE>;
+ type OutputShare = OutputShare<T::Field>;
+ type AggregateShare = AggregateShare<T::Field>;
+
+ fn num_aggregators(&self) -> usize {
+ self.num_aggregators as usize
+ }
+}
+
+/// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase.
+#[derive(Clone, Debug)]
+pub struct Prio3PublicShare<const SEED_SIZE: usize> {
+ /// Contributions to the joint randomness from every aggregator's share.
+ joint_rand_parts: Option<Vec<Seed<SEED_SIZE>>>,
+}
+
+impl<const SEED_SIZE: usize> Encode for Prio3PublicShare<SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
+ for part in joint_rand_parts.iter() {
+ part.encode(bytes);
+ }
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() {
+ // Each seed has the same size.
+ Some(SEED_SIZE * joint_rand_parts.len())
+ } else {
+ Some(0)
+ }
+ }
+}
+
+impl<const SEED_SIZE: usize> PartialEq for Prio3PublicShare<SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<const SEED_SIZE: usize> Eq for Prio3PublicShare<SEED_SIZE> {}
+
+impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PublicShare<SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the presence or absence of the joint_rand_parts.
+ option_ct_eq(
+ self.joint_rand_parts.as_deref(),
+ other.joint_rand_parts.as_deref(),
+ )
+ }
+}
+
+impl<T, P, const SEED_SIZE: usize> ParameterizedDecode<Prio3<T, P, SEED_SIZE>>
+ for Prio3PublicShare<SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3<T, P, SEED_SIZE>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ if decoding_parameter.typ.joint_rand_len() > 0 {
+ let joint_rand_parts = iter::repeat_with(|| Seed::<SEED_SIZE>::decode(bytes))
+ .take(decoding_parameter.num_aggregators.into())
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(Self {
+ joint_rand_parts: Some(joint_rand_parts),
+ })
+ } else {
+ Ok(Self {
+ joint_rand_parts: None,
+ })
+ }
+ }
+}
+
+/// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase.
+#[derive(Clone, Debug)]
+pub struct Prio3InputShare<F, const SEED_SIZE: usize> {
+ /// The measurement share.
+ measurement_share: Share<F, SEED_SIZE>,
+
+ /// The proof share.
+ proof_share: Share<F, SEED_SIZE>,
+
+ /// Blinding seed used by the Aggregator to compute the joint randomness. This field is optional
+ /// because not every [`Type`] requires joint randomness.
+ joint_rand_blind: Option<Seed<SEED_SIZE>>,
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3InputShare<F, SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3InputShare<F, SEED_SIZE> {}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3InputShare<F, SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the presence or absence of the joint_rand_blind.
+ option_ct_eq(
+ self.joint_rand_blind.as_ref(),
+ other.joint_rand_blind.as_ref(),
+ ) & self.measurement_share.ct_eq(&other.measurement_share)
+ & self.proof_share.ct_eq(&other.proof_share)
+ }
+}
+
+impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode for Prio3InputShare<F, SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if matches!(
+ (&self.measurement_share, &self.proof_share),
+ (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_))
+ ) {
+ panic!("tried to encode input share with ambiguous encoding")
+ }
+
+ self.measurement_share.encode(bytes);
+ self.proof_share.encode(bytes);
+ if let Some(ref blind) = self.joint_rand_blind {
+ blind.encode(bytes);
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ let mut len = self.measurement_share.encoded_len()? + self.proof_share.encoded_len()?;
+ if let Some(ref blind) = self.joint_rand_blind {
+ len += blind.encoded_len()?;
+ }
+ Some(len)
+ }
+}
+
+impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
+ for Prio3InputShare<T::Field, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+ let (input_decoder, proof_decoder) = if agg_id == 0 {
+ (
+ ShareDecodingParameter::Leader(prio3.typ.input_len()),
+ ShareDecodingParameter::Leader(prio3.typ.proof_len()),
+ )
+ } else {
+ (
+ ShareDecodingParameter::Helper,
+ ShareDecodingParameter::Helper,
+ )
+ };
+
+ let measurement_share = Share::decode_with_param(&input_decoder, bytes)?;
+ let proof_share = Share::decode_with_param(&proof_decoder, bytes)?;
+ let joint_rand_blind = if prio3.typ.joint_rand_len() > 0 {
+ let blind = Seed::decode(bytes)?;
+ Some(blind)
+ } else {
+ None
+ };
+
+ Ok(Prio3InputShare {
+ measurement_share,
+ proof_share,
+ joint_rand_blind,
+ })
+ }
+}
+
+#[derive(Clone, Debug)]
+/// Message broadcast by each [`Aggregator`] in each round of the Preparation phase.
+pub struct Prio3PrepareShare<F, const SEED_SIZE: usize> {
+ /// A share of the FLP verifier message. (See [`Type`].)
+ verifier: Vec<F>,
+
+ /// A part of the joint randomness seed.
+ joint_rand_part: Option<Seed<SEED_SIZE>>,
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareShare<F, SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareShare<F, SEED_SIZE> {}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareShare<F, SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the presence or absence of the joint_rand_part.
+ option_ct_eq(
+ self.joint_rand_part.as_ref(),
+ other.joint_rand_part.as_ref(),
+ ) & self.verifier.ct_eq(&other.verifier)
+ }
+}
+
+impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
+ for Prio3PrepareShare<F, SEED_SIZE>
+{
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ for x in &self.verifier {
+ x.encode(bytes);
+ }
+ if let Some(ref seed) = self.joint_rand_part {
+ seed.encode(bytes);
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ // Each element of the verifier has the same size.
+ let mut len = F::ENCODED_SIZE * self.verifier.len();
+ if let Some(ref seed) = self.joint_rand_part {
+ len += seed.encoded_len()?;
+ }
+ Some(len)
+ }
+}
+
+impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
+ ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareShare<F, SEED_SIZE>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len);
+ for _ in 0..decoding_parameter.verifier_len {
+ verifier.push(F::decode(bytes)?);
+ }
+
+ let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareShare {
+ verifier,
+ joint_rand_part,
+ })
+ }
+}
+
+#[derive(Clone, Debug)]
+/// Result of combining a round of [`Prio3PrepareShare`] messages.
+pub struct Prio3PrepareMessage<const SEED_SIZE: usize> {
+ /// The joint randomness seed computed by the Aggregators.
+ joint_rand_seed: Option<Seed<SEED_SIZE>>,
+}
+
+impl<const SEED_SIZE: usize> PartialEq for Prio3PrepareMessage<SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<const SEED_SIZE: usize> Eq for Prio3PrepareMessage<SEED_SIZE> {}
+
+impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareMessage<SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the presnce or absence of the joint_rand_seed.
+ option_ct_eq(
+ self.joint_rand_seed.as_ref(),
+ other.joint_rand_seed.as_ref(),
+ )
+ }
+}
+
+impl<const SEED_SIZE: usize> Encode for Prio3PrepareMessage<SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encoded_len()
+ } else {
+ Some(0)
+ }
+ }
+}
+
+impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
+ ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareMessage<SEED_SIZE>
+{
+ fn decode_with_param(
+ decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>,
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+}
+
+impl<T, P, const SEED_SIZE: usize> Client<16> for Prio3<T, P, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ #[allow(clippy::type_complexity)]
+ fn shard(
+ &self,
+ measurement: &T::Measurement,
+ nonce: &[u8; 16],
+ ) -> Result<(Self::PublicShare, Vec<Prio3InputShare<T::Field, SEED_SIZE>>), VdafError> {
+ let mut random = vec![0u8; self.random_size()];
+ getrandom::getrandom(&mut random)?;
+ self.shard_with_random(measurement, nonce, &random)
+ }
+}
+
+/// State of each [`Aggregator`] during the Preparation phase.
+#[derive(Clone)]
+pub struct Prio3PrepareState<F, const SEED_SIZE: usize> {
+ measurement_share: Share<F, SEED_SIZE>,
+ joint_rand_seed: Option<Seed<SEED_SIZE>>,
+ agg_id: u8,
+ verifier_len: usize,
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareState<F, SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareState<F, SEED_SIZE> {}
+
+impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareState<F, SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as
+ // the aggregator ID & verifier length parameters.
+ if self.agg_id != other.agg_id || self.verifier_len != other.verifier_len {
+ return Choice::from(0);
+ }
+
+ option_ct_eq(
+ self.joint_rand_seed.as_ref(),
+ other.joint_rand_seed.as_ref(),
+ ) & self.measurement_share.ct_eq(&other.measurement_share)
+ }
+}
+
+impl<F, const SEED_SIZE: usize> Debug for Prio3PrepareState<F, SEED_SIZE> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Prio3PrepareState")
+ .field("measurement_share", &"[redacted]")
+ .field(
+ "joint_rand_seed",
+ match self.joint_rand_seed {
+ Some(_) => &"Some([redacted])",
+ None => &"None",
+ },
+ )
+ .field("agg_id", &self.agg_id)
+ .field("verifier_len", &self.verifier_len)
+ .finish()
+ }
+}
+
+impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode
+ for Prio3PrepareState<F, SEED_SIZE>
+{
+ /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ self.measurement_share.encode(bytes);
+ if let Some(ref seed) = self.joint_rand_seed {
+ seed.encode(bytes);
+ }
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ let mut len = self.measurement_share.encoded_len()?;
+ if let Some(ref seed) = self.joint_rand_seed {
+ len += seed.encoded_len()?;
+ }
+ Some(len)
+ }
+}
+
+impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)>
+ for Prio3PrepareState<T::Field, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ fn decode_with_param(
+ (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ let agg_id = prio3
+ .role_try_from(*agg_id)
+ .map_err(|e| CodecError::Other(Box::new(e)))?;
+
+ let share_decoder = if agg_id == 0 {
+ ShareDecodingParameter::Leader(prio3.typ.input_len())
+ } else {
+ ShareDecodingParameter::Helper
+ };
+ let measurement_share = Share::decode_with_param(&share_decoder, bytes)?;
+
+ let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 {
+ Some(Seed::decode(bytes)?)
+ } else {
+ None
+ };
+
+ Ok(Self {
+ measurement_share,
+ joint_rand_seed,
+ agg_id,
+ verifier_len: prio3.typ.verifier_len(),
+ })
+ }
+}
+
+impl<T, P, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16> for Prio3<T, P, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ type PrepareState = Prio3PrepareState<T::Field, SEED_SIZE>;
+ type PrepareShare = Prio3PrepareShare<T::Field, SEED_SIZE>;
+ type PrepareMessage = Prio3PrepareMessage<SEED_SIZE>;
+
+ /// Begins the Prep process with the other aggregators. The result of this process is
+ /// the aggregator's output share.
+ #[allow(clippy::type_complexity)]
+ fn prepare_init(
+ &self,
+ verify_key: &[u8; SEED_SIZE],
+ agg_id: usize,
+ _agg_param: &Self::AggregationParam,
+ nonce: &[u8; 16],
+ public_share: &Self::PublicShare,
+ msg: &Prio3InputShare<T::Field, SEED_SIZE>,
+ ) -> Result<
+ (
+ Prio3PrepareState<T::Field, SEED_SIZE>,
+ Prio3PrepareShare<T::Field, SEED_SIZE>,
+ ),
+ VdafError,
+ > {
+ let agg_id = self.role_try_from(agg_id)?;
+ let mut query_rand_xof = P::init(
+ verify_key,
+ &Self::domain_separation_tag(DST_QUERY_RANDOMNESS),
+ );
+ query_rand_xof.update(nonce);
+ let query_rand = query_rand_xof
+ .into_seed_stream()
+ .into_field_vec(self.typ.query_rand_len());
+
+ // Create a reference to the (expanded) measurement share.
+ let expanded_measurement_share: Option<Vec<T::Field>> = match msg.measurement_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => Some(
+ P::seed_stream(
+ seed,
+ &Self::domain_separation_tag(DST_MEASUREMENT_SHARE),
+ &[agg_id],
+ )
+ .into_field_vec(self.typ.input_len()),
+ ),
+ };
+ let measurement_share = match msg.measurement_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_measurement_share.as_ref().unwrap(),
+ };
+
+ // Create a reference to the (expanded) proof share.
+ let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share {
+ Share::Leader(_) => None,
+ Share::Helper(ref seed) => Some(
+ P::seed_stream(
+ seed,
+ &Self::domain_separation_tag(DST_PROOF_SHARE),
+ &[agg_id],
+ )
+ .into_field_vec(self.typ.proof_len()),
+ ),
+ };
+ let proof_share = match msg.proof_share {
+ Share::Leader(ref data) => data,
+ Share::Helper(_) => expanded_proof_share.as_ref().unwrap(),
+ };
+
+ // Compute the joint randomness.
+ let (joint_rand_seed, joint_rand_part, joint_rand) = if self.typ.joint_rand_len() > 0 {
+ let mut joint_rand_part_xof = P::init(
+ msg.joint_rand_blind.as_ref().unwrap().as_ref(),
+ &Self::domain_separation_tag(DST_JOINT_RAND_PART),
+ );
+ joint_rand_part_xof.update(&[agg_id]);
+ joint_rand_part_xof.update(nonce);
+ let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE);
+ for x in measurement_share {
+ x.encode(&mut encoding_buffer);
+ joint_rand_part_xof.update(&encoding_buffer);
+ encoding_buffer.clear();
+ }
+ let own_joint_rand_part = joint_rand_part_xof.into_seed();
+
+ // Make an iterator over the joint randomness parts, but use this aggregator's
+ // contribution, computed from the input share, in lieu of the the corresponding part
+ // from the public share.
+ //
+ // The locally computed part should match the part from the public share for honestly
+ // generated reports. If they do not match, the joint randomness seed check during the
+ // next round of preparation should fail.
+ let corrected_joint_rand_parts = public_share
+ .joint_rand_parts
+ .iter()
+ .flatten()
+ .take(agg_id as usize)
+ .chain(iter::once(&own_joint_rand_part))
+ .chain(
+ public_share
+ .joint_rand_parts
+ .iter()
+ .flatten()
+ .skip(agg_id as usize + 1),
+ );
+
+ let joint_rand_seed = Self::derive_joint_rand_seed(corrected_joint_rand_parts);
+
+ let joint_rand = P::seed_stream(
+ &joint_rand_seed,
+ &Self::domain_separation_tag(DST_JOINT_RANDOMNESS),
+ &[],
+ )
+ .into_field_vec(self.typ.joint_rand_len());
+ (Some(joint_rand_seed), Some(own_joint_rand_part), joint_rand)
+ } else {
+ (None, None, Vec::new())
+ };
+
+ // Run the query-generation algorithm.
+ let verifier_share = self.typ.query(
+ measurement_share,
+ proof_share,
+ &query_rand,
+ &joint_rand,
+ self.num_aggregators as usize,
+ )?;
+
+ Ok((
+ Prio3PrepareState {
+ measurement_share: msg.measurement_share.clone(),
+ joint_rand_seed,
+ agg_id,
+ verifier_len: verifier_share.len(),
+ },
+ Prio3PrepareShare {
+ verifier: verifier_share,
+ joint_rand_part,
+ },
+ ))
+ }
+
+ fn prepare_shares_to_prepare_message<
+ M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>,
+ >(
+ &self,
+ _: &Self::AggregationParam,
+ inputs: M,
+ ) -> Result<Prio3PrepareMessage<SEED_SIZE>, VdafError> {
+ let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
+ let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators());
+ let mut count = 0;
+ for share in inputs.into_iter() {
+ count += 1;
+
+ if share.verifier.len() != verifier.len() {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected verifier share length: got {}; want {}",
+ share.verifier.len(),
+ verifier.len(),
+ )));
+ }
+
+ if self.typ.joint_rand_len() > 0 {
+ let joint_rand_seed_part = share.joint_rand_part.unwrap();
+ joint_rand_parts.push(joint_rand_seed_part);
+ }
+
+ for (x, y) in verifier.iter_mut().zip(share.verifier) {
+ *x += y;
+ }
+ }
+
+ if count != self.num_aggregators {
+ return Err(VdafError::Uncategorized(format!(
+ "unexpected message count: got {}; want {}",
+ count, self.num_aggregators,
+ )));
+ }
+
+ // Check the proof verifier.
+ match self.typ.decide(&verifier) {
+ Ok(true) => (),
+ Ok(false) => {
+ return Err(VdafError::Uncategorized(
+ "proof verifier check failed".into(),
+ ))
+ }
+ Err(err) => return Err(VdafError::from(err)),
+ };
+
+ let joint_rand_seed = if self.typ.joint_rand_len() > 0 {
+ Some(Self::derive_joint_rand_seed(joint_rand_parts.iter()))
+ } else {
+ None
+ };
+
+ Ok(Prio3PrepareMessage { joint_rand_seed })
+ }
+
+ fn prepare_next(
+ &self,
+ step: Prio3PrepareState<T::Field, SEED_SIZE>,
+ msg: Prio3PrepareMessage<SEED_SIZE>,
+ ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> {
+ if self.typ.joint_rand_len() > 0 {
+ // Check that the joint randomness was correct.
+ if step
+ .joint_rand_seed
+ .as_ref()
+ .unwrap()
+ .ct_ne(msg.joint_rand_seed.as_ref().unwrap())
+ .into()
+ {
+ return Err(VdafError::Uncategorized(
+ "joint randomness mismatch".to_string(),
+ ));
+ }
+ }
+
+ // Compute the output share.
+ let measurement_share = match step.measurement_share {
+ Share::Leader(data) => data,
+ Share::Helper(seed) => {
+ let dst = Self::domain_separation_tag(DST_MEASUREMENT_SHARE);
+ P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len())
+ }
+ };
+
+ let output_share = match self.typ.truncate(measurement_share) {
+ Ok(data) => OutputShare(data),
+ Err(err) => {
+ return Err(VdafError::from(err));
+ }
+ };
+
+ Ok(PrepareTransition::Finish(output_share))
+ }
+
+ /// Aggregates a sequence of output shares into an aggregate share.
+ fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>(
+ &self,
+ _agg_param: &(),
+ output_shares: It,
+ ) -> Result<AggregateShare<T::Field>, VdafError> {
+ let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for output_share in output_shares.into_iter() {
+ agg_share.accumulate(&output_share)?;
+ }
+
+ Ok(agg_share)
+ }
+}
+
+#[cfg(feature = "experimental")]
+impl<T, P, S, const SEED_SIZE: usize> AggregatorWithNoise<SEED_SIZE, 16, S>
+ for Prio3<T, P, SEED_SIZE>
+where
+ T: TypeWithNoise<S>,
+ P: Xof<SEED_SIZE>,
+ S: DifferentialPrivacyStrategy,
+{
+ fn add_noise_to_agg_share(
+ &self,
+ dp_strategy: &S,
+ _agg_param: &Self::AggregationParam,
+ agg_share: &mut Self::AggregateShare,
+ num_measurements: usize,
+ ) -> Result<(), VdafError> {
+ self.typ
+ .add_noise_to_result(dp_strategy, &mut agg_share.0, num_measurements)?;
+ Ok(())
+ }
+}
+
+impl<T, P, const SEED_SIZE: usize> Collector for Prio3<T, P, SEED_SIZE>
+where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ /// Combines aggregate shares into the aggregate result.
+ fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>(
+ &self,
+ _agg_param: &Self::AggregationParam,
+ agg_shares: It,
+ num_measurements: usize,
+ ) -> Result<T::AggregateResult, VdafError> {
+ let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]);
+ for agg_share in agg_shares.into_iter() {
+ agg.merge(&agg_share)?;
+ }
+
+ Ok(self.typ.decode_result(&agg.0, num_measurements)?)
+ }
+}
+
+#[derive(Clone)]
+struct HelperShare<const SEED_SIZE: usize> {
+ measurement_share: Seed<SEED_SIZE>,
+ proof_share: Seed<SEED_SIZE>,
+ joint_rand_blind: Option<Seed<SEED_SIZE>>,
+}
+
+impl<const SEED_SIZE: usize> HelperShare<SEED_SIZE> {
+ fn from_seeds(
+ measurement_share: [u8; SEED_SIZE],
+ proof_share: [u8; SEED_SIZE],
+ joint_rand_blind: Option<[u8; SEED_SIZE]>,
+ ) -> Self {
+ HelperShare {
+ measurement_share: Seed::from_bytes(measurement_share),
+ proof_share: Seed::from_bytes(proof_share),
+ joint_rand_blind: joint_rand_blind.map(Seed::from_bytes),
+ }
+ }
+}
+
+fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> {
+ if num_aggregators == 0 {
+ return Err(VdafError::Uncategorized(format!(
+ "at least one aggregator is required; got {num_aggregators}"
+ )));
+ } else if num_aggregators > 254 {
+ return Err(VdafError::Uncategorized(format!(
+ "number of aggregators must not exceed 254; got {num_aggregators}"
+ )));
+ }
+
+ Ok(())
+}
+
+impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
+ for OutputShare<F>
+where
+ F: FieldElement,
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ fn decode_with_param(
+ (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ decode_fieldvec(vdaf.output_len(), bytes).map(Self)
+ }
+}
+
+impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())>
+ for AggregateShare<F>
+where
+ F: FieldElement,
+ T: Type,
+ P: Xof<SEED_SIZE>,
+{
+ fn decode_with_param(
+ (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()),
+ bytes: &mut Cursor<&[u8]>,
+ ) -> Result<Self, CodecError> {
+ decode_fieldvec(vdaf.output_len(), bytes).map(Self)
+ }
+}
+
+// This function determines equality between two optional, constant-time comparable values. It
+// short-circuits on the existence (but not contents) of the values -- a timing side-channel may
+// reveal whether the values match on Some or None.
+#[inline]
+fn option_ct_eq<T>(left: Option<&T>, right: Option<&T>) -> Choice
+where
+ T: ConstantTimeEq + ?Sized,
+{
+ match (left, right) {
+ (Some(left), Some(right)) => left.ct_eq(right),
+ (None, None) => Choice::from(1),
+ _ => Choice::from(0),
+ }
+}
+
+/// This is a polyfill for `usize::ilog2()`, which is only available in Rust 1.67 and later. It is
+/// based on the implementation in the standard library. It can be removed when the MSRV has been
+/// advanced past 1.67.
+///
+/// # Panics
+///
+/// This function will panic if `input` is zero.
+fn ilog2(input: usize) -> u32 {
+ if input == 0 {
+ panic!("Tried to take the logarithm of zero");
+ }
+ (usize::BITS - 1) - input.leading_zeros()
+}
+
+/// Finds the optimal choice of chunk length for [`Prio3Histogram`] or [`Prio3SumVec`], given its
+/// encoded measurement length. For [`Prio3Histogram`], the measurement length is equal to the
+/// length parameter. For [`Prio3SumVec`], the measurement length is equal to the product of the
+/// length and bits parameters.
+pub fn optimal_chunk_length(measurement_length: usize) -> usize {
+ if measurement_length <= 1 {
+ return 1;
+ }
+
+ /// Candidate set of parameter choices for the parallel sum optimization.
+ struct Candidate {
+ gadget_calls: usize,
+ chunk_length: usize,
+ }
+
+ let max_log2 = ilog2(measurement_length + 1);
+ let best_opt = (1..=max_log2)
+ .rev()
+ .map(|log2| {
+ let gadget_calls = (1 << log2) - 1;
+ let chunk_length = (measurement_length + gadget_calls - 1) / gadget_calls;
+ Candidate {
+ gadget_calls,
+ chunk_length,
+ }
+ })
+ .min_by_key(|candidate| {
+ // Compute the proof length, in field elements, for either Prio3Histogram or Prio3SumVec
+ (candidate.chunk_length * 2)
+ + 2 * ((1 + candidate.gadget_calls).next_power_of_two() - 1)
+ });
+ // Unwrap safety: max_log2 must be at least 1, because smaller measurement_length inputs are
+ // dealt with separately. Thus, the range iterator that the search is over will be nonempty,
+ // and min_by_key() will always return Some.
+ best_opt.unwrap().chunk_length
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[cfg(feature = "experimental")]
+ use crate::flp::gadgets::ParallelSumGadget;
+ use crate::vdaf::{
+ equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare,
+ };
+ use assert_matches::assert_matches;
+ #[cfg(feature = "experimental")]
+ use fixed::{
+ types::extra::{U15, U31, U63},
+ FixedI16, FixedI32, FixedI64,
+ };
+ #[cfg(feature = "experimental")]
+ use fixed_macro::fixed;
+ use rand::prelude::*;
+
+ #[test]
+ fn test_prio3_count() {
+ let prio3 = Prio3::new_count(2).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3);
+
+ let mut nonce = [0; 16];
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ thread_rng().fill(&mut nonce[..]);
+
+ let (public_share, input_shares) = prio3.shard(&0, &nonce).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();
+
+ let (public_share, input_shares) = prio3.shard(&1, &nonce).unwrap();
+ run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap();
+
+ test_serialization(&prio3, &1, &nonce).unwrap();
+
+ let prio3_extra_helper = Prio3::new_count(3).unwrap();
+ assert_eq!(
+ run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(),
+ 3,
+ );
+ }
+
+ #[test]
+ fn test_prio3_sum() {
+ let prio3 = Prio3::new_sum(3, 16).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(),
+ (1 << 16) + 1
+ );
+
+ let mut verify_key = [0; 16];
+ thread_rng().fill(&mut verify_key[..]);
+ let nonce = [0; 16];
+
+ let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
+ input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
+ assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap();
+ assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ test_serialization(&prio3, &1, &nonce).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_sum_vec() {
+ let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
+ vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
+ vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
+ ]
+ )
+ .unwrap(),
+ vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
+ );
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_prio3_sum_vec_multithreaded() {
+ let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap();
+ assert_eq!(
+ run_vdaf(
+ &prio3,
+ &(),
+ [
+ vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1],
+ vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0],
+ vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
+ ]
+ )
+ .unwrap(),
+ vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2],
+ );
+ }
+
+ #[test]
+ #[cfg(feature = "experimental")]
+ fn test_prio3_bounded_fpvec_sum_unaligned() {
+ type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
+ #[cfg(feature = "multithreaded")]
+ type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
+ let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
+ #[cfg(feature = "multithreaded")]
+ let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
+
+ {
+ const SIZE: usize = 5;
+ let fp32_0 = fixed!(0: I1F31);
+
+ // 32 bit fixedpoint, non-power-of-2 vector, single-threaded
+ {
+ let prio3_32 = ctor_32(2, SIZE).unwrap();
+ test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_32);
+ }
+
+ // 32 bit fixedpoint, non-power-of-2 vector, multi-threaded
+ #[cfg(feature = "multithreaded")]
+ {
+ let prio3_mt_32 = ctor_mt_32(2, SIZE).unwrap();
+ test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_mt_32);
+ }
+ }
+
+ fn test_fixed_vec<Fx, PE, M, const SIZE: usize>(
+ fp_0: Fx,
+ prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofShake128, 16>,
+ ) where
+ Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
+ PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
+ M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
+ {
+ let fp_vec = vec![fp_0; SIZE];
+
+ let measurements = [fp_vec.clone(), fp_vec];
+ assert_eq!(
+ run_vdaf(&prio3, &(), measurements).unwrap(),
+ vec![0.0; SIZE]
+ );
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "experimental")]
+ fn test_prio3_bounded_fpvec_sum() {
+ type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>;
+ let ctor_16 = P::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum;
+ let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum;
+ let ctor_64 = P::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum;
+
+ #[cfg(feature = "multithreaded")]
+ type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>;
+ #[cfg(feature = "multithreaded")]
+ let ctor_mt_16 = PM::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
+ #[cfg(feature = "multithreaded")]
+ let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
+ #[cfg(feature = "multithreaded")]
+ let ctor_mt_64 = PM::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum_multithreaded;
+
+ {
+ // 16 bit fixedpoint
+ let fp16_4_inv = fixed!(0.25: I1F15);
+ let fp16_8_inv = fixed!(0.125: I1F15);
+ let fp16_16_inv = fixed!(0.0625: I1F15);
+
+ // two aggregators, three entries per vector.
+ {
+ let prio3_16 = ctor_16(2, 3).unwrap();
+ test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16);
+ }
+
+ #[cfg(feature = "multithreaded")]
+ {
+ let prio3_16_mt = ctor_mt_16(2, 3).unwrap();
+ test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt);
+ }
+ }
+
+ {
+ // 32 bit fixedpoint
+ let fp32_4_inv = fixed!(0.25: I1F31);
+ let fp32_8_inv = fixed!(0.125: I1F31);
+ let fp32_16_inv = fixed!(0.0625: I1F31);
+
+ {
+ let prio3_32 = ctor_32(2, 3).unwrap();
+ test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32);
+ }
+
+ #[cfg(feature = "multithreaded")]
+ {
+ let prio3_32_mt = ctor_mt_32(2, 3).unwrap();
+ test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt);
+ }
+ }
+
+ {
+ // 64 bit fixedpoint
+ let fp64_4_inv = fixed!(0.25: I1F63);
+ let fp64_8_inv = fixed!(0.125: I1F63);
+ let fp64_16_inv = fixed!(0.0625: I1F63);
+
+ {
+ let prio3_64 = ctor_64(2, 3).unwrap();
+ test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64);
+ }
+
+ #[cfg(feature = "multithreaded")]
+ {
+ let prio3_64_mt = ctor_mt_64(2, 3).unwrap();
+ test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt);
+ }
+ }
+
+ fn test_fixed<Fx, PE, M>(
+ fp_4_inv: Fx,
+ fp_8_inv: Fx,
+ fp_16_inv: Fx,
+ prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofShake128, 16>,
+ ) where
+ Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>,
+ PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static,
+ M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static,
+ {
+ let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
+ let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
+
+ let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
+ let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv];
+
+ let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv];
+ let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv];
+
+ // positive entries
+ let fp_list = [fp_vec1, fp_vec2];
+ assert_eq!(
+ run_vdaf(&prio3, &(), fp_list).unwrap(),
+ vec!(0.5, 0.25, 0.125),
+ );
+
+ // negative entries
+ let fp_list2 = [fp_vec3, fp_vec4];
+ assert_eq!(
+ run_vdaf(&prio3, &(), fp_list2).unwrap(),
+ vec!(-0.5, -0.25, -0.125),
+ );
+
+ // both
+ let fp_list3 = [fp_vec5, fp_vec6];
+ assert_eq!(
+ run_vdaf(&prio3, &(), fp_list3).unwrap(),
+ vec!(0.5, 0.0, 0.0),
+ );
+
+ let mut verify_key = [0; 16];
+ let mut nonce = [0; 16];
+ thread_rng().fill(&mut verify_key);
+ thread_rng().fill(&mut nonce);
+
+ let (public_share, mut input_shares) = prio3
+ .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
+ .unwrap();
+ input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255;
+ let result =
+ run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3
+ .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
+ .unwrap();
+ assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result =
+ run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ let (public_share, mut input_shares) = prio3
+ .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce)
+ .unwrap();
+ assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => {
+ data[0] += Field128::one();
+ });
+ let result =
+ run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares);
+ assert_matches!(result, Err(VdafError::Uncategorized(_)));
+
+ test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap();
+ }
+ }
+
+ #[test]
+ fn test_prio3_histogram() {
+ let prio3 = Prio3::new_histogram(2, 4, 2).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
+ vec![1, 1, 1, 1]
+ );
+ assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
+ test_serialization(&prio3, &3, &[0; 16]).unwrap();
+ }
+
+ #[test]
+ #[cfg(feature = "multithreaded")]
+ fn test_prio3_histogram_multithreaded() {
+ let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap();
+
+ assert_eq!(
+ run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(),
+ vec![1, 1, 1, 1]
+ );
+ assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]);
+ assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]);
+ test_serialization(&prio3, &3, &[0; 16]).unwrap();
+ }
+
+ #[test]
+ fn test_prio3_average() {
+ let prio3 = Prio3::new_average(2, 64).unwrap();
+
+ assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64);
+ assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64);
+ assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64);
+ assert_eq!(
+ run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(),
+ 207.5f64
+ );
+ }
+
+ #[test]
+ fn test_prio3_input_share() {
+ let prio3 = Prio3::new_sum(5, 16).unwrap();
+ let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap();
+
+ // Check that seed shares are distinct.
+ for (i, x) in input_shares.iter().enumerate() {
+ for (j, y) in input_shares.iter().enumerate() {
+ if i != j {
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.measurement_share, &y.measurement_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ if let (Share::Helper(left), Share::Helper(right)) =
+ (&x.proof_share, &y.proof_share)
+ {
+ assert_ne!(left, right);
+ }
+
+ assert_ne!(x.joint_rand_blind, y.joint_rand_blind);
+ }
+ }
+ }
+ }
+
+ fn test_serialization<T, P, const SEED_SIZE: usize>(
+ prio3: &Prio3<T, P, SEED_SIZE>,
+ measurement: &T::Measurement,
+ nonce: &[u8; 16],
+ ) -> Result<(), VdafError>
+ where
+ T: Type,
+ P: Xof<SEED_SIZE>,
+ {
+ let mut verify_key = [0; SEED_SIZE];
+ thread_rng().fill(&mut verify_key[..]);
+ let (public_share, input_shares) = prio3.shard(measurement, nonce)?;
+
+ let encoded_public_share = public_share.get_encoded();
+ let decoded_public_share =
+ Prio3PublicShare::get_decoded_with_param(prio3, &encoded_public_share)
+ .expect("failed to decode public share");
+ assert_eq!(decoded_public_share, public_share);
+ assert_eq!(
+ public_share.encoded_len().unwrap(),
+ encoded_public_share.len()
+ );
+
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let encoded_input_share = input_share.get_encoded();
+ let decoded_input_share =
+ Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), &encoded_input_share)
+ .expect("failed to decode input share");
+ assert_eq!(&decoded_input_share, input_share);
+ assert_eq!(
+ input_share.encoded_len().unwrap(),
+ encoded_input_share.len()
+ );
+ }
+
+ let mut prepare_shares = Vec::new();
+ let mut last_prepare_state = None;
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (prepare_state, prepare_share) =
+ prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?;
+
+ let encoded_prepare_state = prepare_state.get_encoded();
+ let decoded_prepare_state =
+ Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded_prepare_state)
+ .expect("failed to decode prepare state");
+ assert_eq!(decoded_prepare_state, prepare_state);
+ assert_eq!(
+ prepare_state.encoded_len().unwrap(),
+ encoded_prepare_state.len()
+ );
+
+ let encoded_prepare_share = prepare_share.get_encoded();
+ let decoded_prepare_share =
+ Prio3PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
+ .expect("failed to decode prepare share");
+ assert_eq!(decoded_prepare_share, prepare_share);
+ assert_eq!(
+ prepare_share.encoded_len().unwrap(),
+ encoded_prepare_share.len()
+ );
+
+ prepare_shares.push(prepare_share);
+ last_prepare_state = Some(prepare_state);
+ }
+
+ let prepare_message = prio3
+ .prepare_shares_to_prepare_message(&(), prepare_shares)
+ .unwrap();
+
+ let encoded_prepare_message = prepare_message.get_encoded();
+ let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param(
+ &last_prepare_state.unwrap(),
+ &encoded_prepare_message,
+ )
+ .expect("failed to decode prepare message");
+ assert_eq!(decoded_prepare_message, prepare_message);
+ assert_eq!(
+ prepare_message.encoded_len().unwrap(),
+ encoded_prepare_message.len()
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn roundtrip_output_share() {
+ let vdaf = Prio3::new_count(2).unwrap();
+ fieldvec_roundtrip_test::<Field64, Prio3Count, OutputShare<Field64>>(&vdaf, &(), 1);
+
+ let vdaf = Prio3::new_sum(2, 17).unwrap();
+ fieldvec_roundtrip_test::<Field128, Prio3Sum, OutputShare<Field128>>(&vdaf, &(), 1);
+
+ let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
+ fieldvec_roundtrip_test::<Field128, Prio3Histogram, OutputShare<Field128>>(&vdaf, &(), 12);
+ }
+
+ #[test]
+ fn roundtrip_aggregate_share() {
+ let vdaf = Prio3::new_count(2).unwrap();
+ fieldvec_roundtrip_test::<Field64, Prio3Count, AggregateShare<Field64>>(&vdaf, &(), 1);
+
+ let vdaf = Prio3::new_sum(2, 17).unwrap();
+ fieldvec_roundtrip_test::<Field128, Prio3Sum, AggregateShare<Field128>>(&vdaf, &(), 1);
+
+ let vdaf = Prio3::new_histogram(2, 12, 3).unwrap();
+ fieldvec_roundtrip_test::<Field128, Prio3Histogram, AggregateShare<Field128>>(
+ &vdaf,
+ &(),
+ 12,
+ );
+ }
+
+ #[test]
+ fn public_share_equality_test() {
+ equality_comparison_test(&[
+ Prio3PublicShare {
+ joint_rand_parts: Some(Vec::from([Seed([0])])),
+ },
+ Prio3PublicShare {
+ joint_rand_parts: Some(Vec::from([Seed([1])])),
+ },
+ Prio3PublicShare {
+ joint_rand_parts: None,
+ },
+ ])
+ }
+
+ #[test]
+ fn input_share_equality_test() {
+ equality_comparison_test(&[
+ // Default.
+ Prio3InputShare {
+ measurement_share: Share::Leader(Vec::from([0])),
+ proof_share: Share::Leader(Vec::from([1])),
+ joint_rand_blind: Some(Seed([2])),
+ },
+ // Modified measurement share.
+ Prio3InputShare {
+ measurement_share: Share::Leader(Vec::from([100])),
+ proof_share: Share::Leader(Vec::from([1])),
+ joint_rand_blind: Some(Seed([2])),
+ },
+ // Modified proof share.
+ Prio3InputShare {
+ measurement_share: Share::Leader(Vec::from([0])),
+ proof_share: Share::Leader(Vec::from([101])),
+ joint_rand_blind: Some(Seed([2])),
+ },
+ // Modified joint_rand_blind.
+ Prio3InputShare {
+ measurement_share: Share::Leader(Vec::from([0])),
+ proof_share: Share::Leader(Vec::from([1])),
+ joint_rand_blind: Some(Seed([102])),
+ },
+ // Missing joint_rand_blind.
+ Prio3InputShare {
+ measurement_share: Share::Leader(Vec::from([0])),
+ proof_share: Share::Leader(Vec::from([1])),
+ joint_rand_blind: None,
+ },
+ ])
+ }
+
+ #[test]
+ fn prepare_share_equality_test() {
+ equality_comparison_test(&[
+ // Default.
+ Prio3PrepareShare {
+ verifier: Vec::from([0]),
+ joint_rand_part: Some(Seed([1])),
+ },
+ // Modified verifier.
+ Prio3PrepareShare {
+ verifier: Vec::from([100]),
+ joint_rand_part: Some(Seed([1])),
+ },
+ // Modified joint_rand_part.
+ Prio3PrepareShare {
+ verifier: Vec::from([0]),
+ joint_rand_part: Some(Seed([101])),
+ },
+ // Missing joint_rand_part.
+ Prio3PrepareShare {
+ verifier: Vec::from([0]),
+ joint_rand_part: None,
+ },
+ ])
+ }
+
+ #[test]
+ fn prepare_message_equality_test() {
+ equality_comparison_test(&[
+ // Default.
+ Prio3PrepareMessage {
+ joint_rand_seed: Some(Seed([0])),
+ },
+ // Modified joint_rand_seed.
+ Prio3PrepareMessage {
+ joint_rand_seed: Some(Seed([100])),
+ },
+ // Missing joint_rand_seed.
+ Prio3PrepareMessage {
+ joint_rand_seed: None,
+ },
+ ])
+ }
+
+ #[test]
+ fn prepare_state_equality_test() {
+ equality_comparison_test(&[
+ // Default.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([0])),
+ joint_rand_seed: Some(Seed([1])),
+ agg_id: 2,
+ verifier_len: 3,
+ },
+ // Modified measurement share.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([100])),
+ joint_rand_seed: Some(Seed([1])),
+ agg_id: 2,
+ verifier_len: 3,
+ },
+ // Modified joint_rand_seed.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([0])),
+ joint_rand_seed: Some(Seed([101])),
+ agg_id: 2,
+ verifier_len: 3,
+ },
+ // Missing joint_rand_seed.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([0])),
+ joint_rand_seed: None,
+ agg_id: 2,
+ verifier_len: 3,
+ },
+ // Modified agg_id.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([0])),
+ joint_rand_seed: Some(Seed([1])),
+ agg_id: 102,
+ verifier_len: 3,
+ },
+ // Modified verifier_len.
+ Prio3PrepareState {
+ measurement_share: Share::Leader(Vec::from([0])),
+ joint_rand_seed: Some(Seed([1])),
+ agg_id: 2,
+ verifier_len: 103,
+ },
+ ])
+ }
+
+ #[test]
+ fn test_optimal_chunk_length() {
+ // nonsense argument, but make sure it doesn't panic.
+ optimal_chunk_length(0);
+
+ // edge cases on either side of power-of-two jumps
+ assert_eq!(optimal_chunk_length(1), 1);
+ assert_eq!(optimal_chunk_length(2), 2);
+ assert_eq!(optimal_chunk_length(3), 1);
+ assert_eq!(optimal_chunk_length(18), 6);
+ assert_eq!(optimal_chunk_length(19), 3);
+
+ // additional arbitrary test cases
+ assert_eq!(optimal_chunk_length(40), 6);
+ assert_eq!(optimal_chunk_length(10_000), 79);
+ assert_eq!(optimal_chunk_length(100_000), 393);
+
+ // confirm that the chunk lengths are truly optimal
+ for measurement_length in [2, 3, 4, 5, 18, 19, 40] {
+ let optimal_chunk_length = optimal_chunk_length(measurement_length);
+ let optimal_proof_length = Histogram::<Field128, ParallelSum<_, _>>::new(
+ measurement_length,
+ optimal_chunk_length,
+ )
+ .unwrap()
+ .proof_len();
+ for chunk_length in 1..=measurement_length {
+ let proof_length =
+ Histogram::<Field128, ParallelSum<_, _>>::new(measurement_length, chunk_length)
+ .unwrap()
+ .proof_len();
+ assert!(proof_length >= optimal_proof_length);
+ }
+ }
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs
new file mode 100644
index 0000000000..372a2c8560
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prio3_test.rs
@@ -0,0 +1,251 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::{
+ codec::{Encode, ParameterizedDecode},
+ flp::Type,
+ vdaf::{
+ prio3::{Prio3, Prio3InputShare, Prio3PrepareShare, Prio3PublicShare},
+ xof::Xof,
+ Aggregator, Collector, OutputShare, PrepareTransition, Vdaf,
+ },
+};
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, convert::TryInto, fmt::Debug};
+
+#[derive(Debug, Deserialize, Serialize)]
+struct TEncoded(#[serde(with = "hex")] Vec<u8>);
+
+impl AsRef<[u8]> for TEncoded {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3Prep<M> {
+ measurement: M,
+ #[serde(with = "hex")]
+ nonce: Vec<u8>,
+ #[serde(with = "hex")]
+ rand: Vec<u8>,
+ public_share: TEncoded,
+ input_shares: Vec<TEncoded>,
+ prep_shares: Vec<Vec<TEncoded>>,
+ prep_messages: Vec<TEncoded>,
+ out_shares: Vec<Vec<TEncoded>>,
+}
+
+#[derive(Deserialize, Serialize)]
+struct TPrio3<M> {
+ verify_key: TEncoded,
+ shares: u8,
+ prep: Vec<TPrio3Prep<M>>,
+ agg_shares: Vec<TEncoded>,
+ agg_result: serde_json::Value,
+ #[serde(flatten)]
+ other_params: HashMap<String, serde_json::Value>,
+}
+
+macro_rules! err {
+ (
+ $test_num:ident,
+ $error:expr,
+ $msg:expr
+ ) => {
+ panic!("test #{} failed: {} err: {}", $test_num, $msg, $error)
+ };
+}
+
+// TODO Generalize this method to work with any VDAF. To do so we would need to add
+// `shard_with_random()` to traits. (There may be a less invasive alternative.)
+fn check_prep_test_vec<M, T, P, const SEED_SIZE: usize>(
+ prio3: &Prio3<T, P, SEED_SIZE>,
+ verify_key: &[u8; SEED_SIZE],
+ test_num: usize,
+ t: &TPrio3Prep<M>,
+) -> Vec<OutputShare<T::Field>>
+where
+ T: Type<Measurement = M>,
+ P: Xof<SEED_SIZE>,
+{
+ let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap();
+ let (public_share, input_shares) = prio3
+ .shard_with_random(&t.measurement, &nonce, &t.rand)
+ .expect("failed to generate input shares");
+
+ assert_eq!(
+ public_share,
+ Prio3PublicShare::get_decoded_with_param(prio3, t.public_share.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (public share)")),
+ );
+ for (agg_id, want) in t.input_shares.iter().enumerate() {
+ assert_eq!(
+ input_shares[agg_id],
+ Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
+ "#{test_num}"
+ );
+ assert_eq!(
+ input_shares[agg_id].get_encoded(),
+ want.as_ref(),
+ "#{test_num}"
+ )
+ }
+
+ let mut states = Vec::new();
+ let mut prep_shares = Vec::new();
+ for (agg_id, input_share) in input_shares.iter().enumerate() {
+ let (state, prep_share) = prio3
+ .prepare_init(verify_key, agg_id, &(), &nonce, &public_share, input_share)
+ .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
+ states.push(state);
+ prep_shares.push(prep_share);
+ }
+
+ assert_eq!(1, t.prep_shares.len(), "#{test_num}");
+ for (i, want) in t.prep_shares[0].iter().enumerate() {
+ assert_eq!(
+ prep_shares[i],
+ Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref())
+ .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
+ "#{test_num}"
+ );
+ assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{test_num}");
+ }
+
+ let inbound = prio3
+ .prepare_shares_to_prepare_message(&(), prep_shares)
+ .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
+ assert_eq!(t.prep_messages.len(), 1);
+ assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());
+
+ let mut out_shares = Vec::new();
+ for state in states.iter_mut() {
+ match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() {
+ PrepareTransition::Finish(out_share) => {
+ out_shares.push(out_share);
+ }
+ _ => panic!("unexpected transition"),
+ }
+ }
+
+ for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
+ let got: Vec<Vec<u8>> = got.as_ref().iter().map(|x| x.get_encoded()).collect();
+ assert_eq!(got.len(), want.len());
+ for (got_elem, want_elem) in got.iter().zip(want.iter()) {
+ assert_eq!(got_elem.as_slice(), want_elem.as_ref());
+ }
+ }
+
+ out_shares
+}
+
+#[must_use]
+fn check_aggregate_test_vec<M, T, P, const SEED_SIZE: usize>(
+ prio3: &Prio3<T, P, SEED_SIZE>,
+ t: &TPrio3<M>,
+) -> T::AggregateResult
+where
+ T: Type<Measurement = M>,
+ P: Xof<SEED_SIZE>,
+{
+ let verify_key = t.verify_key.as_ref().try_into().unwrap();
+
+ let mut all_output_shares = vec![Vec::new(); prio3.num_aggregators()];
+ for (test_num, p) in t.prep.iter().enumerate() {
+ let output_shares = check_prep_test_vec(prio3, verify_key, test_num, p);
+ for (aggregator_output_shares, output_share) in
+ all_output_shares.iter_mut().zip(output_shares.into_iter())
+ {
+ aggregator_output_shares.push(output_share);
+ }
+ }
+
+ let aggregate_shares = all_output_shares
+ .into_iter()
+ .map(|aggregator_output_shares| prio3.aggregate(&(), aggregator_output_shares).unwrap())
+ .collect::<Vec<_>>();
+
+ for (got, want) in aggregate_shares.iter().zip(t.agg_shares.iter()) {
+ let got = got.get_encoded();
+ assert_eq!(got.as_slice(), want.as_ref());
+ }
+
+ prio3.unshard(&(), aggregate_shares, 1).unwrap()
+}
+
+#[test]
+fn test_vec_prio3_count() {
+ for test_vector_str in [
+ include_str!("test_vec/07/Prio3Count_0.json"),
+ include_str!("test_vec/07/Prio3Count_1.json"),
+ ] {
+ let t: TPrio3<u64> = serde_json::from_str(test_vector_str).unwrap();
+ let prio3 = Prio3::new_count(t.shares).unwrap();
+
+ let aggregate_result = check_aggregate_test_vec(&prio3, &t);
+ assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap());
+ }
+}
+
+#[test]
+fn test_vec_prio3_sum() {
+ for test_vector_str in [
+ include_str!("test_vec/07/Prio3Sum_0.json"),
+ include_str!("test_vec/07/Prio3Sum_1.json"),
+ ] {
+ let t: TPrio3<u128> = serde_json::from_str(test_vector_str).unwrap();
+ let bits = t.other_params["bits"].as_u64().unwrap() as usize;
+ let prio3 = Prio3::new_sum(t.shares, bits).unwrap();
+
+ let aggregate_result = check_aggregate_test_vec(&prio3, &t);
+ assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap() as u128);
+ }
+}
+
+#[test]
+fn test_vec_prio3_sum_vec() {
+ for test_vector_str in [
+ include_str!("test_vec/07/Prio3SumVec_0.json"),
+ include_str!("test_vec/07/Prio3SumVec_1.json"),
+ ] {
+ let t: TPrio3<Vec<u128>> = serde_json::from_str(test_vector_str).unwrap();
+ let bits = t.other_params["bits"].as_u64().unwrap() as usize;
+ let length = t.other_params["length"].as_u64().unwrap() as usize;
+ let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
+ let prio3 = Prio3::new_sum_vec(t.shares, bits, length, chunk_length).unwrap();
+
+ let aggregate_result = check_aggregate_test_vec(&prio3, &t);
+ let expected_aggregate_result = t
+ .agg_result
+ .as_array()
+ .unwrap()
+ .iter()
+ .map(|val| val.as_u64().unwrap() as u128)
+ .collect::<Vec<u128>>();
+ assert_eq!(aggregate_result, expected_aggregate_result);
+ }
+}
+
+#[test]
+fn test_vec_prio3_histogram() {
+ for test_vector_str in [
+ include_str!("test_vec/07/Prio3Histogram_0.json"),
+ include_str!("test_vec/07/Prio3Histogram_1.json"),
+ ] {
+ let t: TPrio3<usize> = serde_json::from_str(test_vector_str).unwrap();
+ let length = t.other_params["length"].as_u64().unwrap() as usize;
+ let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize;
+ let prio3 = Prio3::new_histogram(t.shares, length, chunk_length).unwrap();
+
+ let aggregate_result = check_aggregate_test_vec(&prio3, &t);
+ let expected_aggregate_result = t
+ .agg_result
+ .as_array()
+ .unwrap()
+ .iter()
+ .map(|val| val.as_u64().unwrap() as u128)
+ .collect::<Vec<u128>>();
+ assert_eq!(aggregate_result, expected_aggregate_result);
+ }
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json
new file mode 100644
index 0000000000..2ff7aa7ffd
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json
@@ -0,0 +1,52 @@
+{
+ "alpha": "0",
+ "beta_inner": [
+ [
+ "0",
+ "0"
+ ],
+ [
+ "1",
+ "1"
+ ],
+ [
+ "2",
+ "2"
+ ],
+ [
+ "3",
+ "3"
+ ],
+ [
+ "4",
+ "4"
+ ],
+ [
+ "5",
+ "5"
+ ],
+ [
+ "6",
+ "6"
+ ],
+ [
+ "7",
+ "7"
+ ],
+ [
+ "8",
+ "8"
+ ]
+ ],
+ "beta_leaf": [
+ "9",
+ "9"
+ ],
+ "binder": "736f6d65206e6f6e6365",
+ "bits": 10,
+ "keys": [
+ "000102030405060708090a0b0c0d0e0f",
+ "101112131415161718191a1b1c1d1e1f"
+ ],
+ "public_share": "921909356f44964d29c537aeeaeba92e573e4298c88dcc35bd3ae6acb4367236226b1af3151d5814f308f04e208fde2110c72523338563bc1c5fb47d22b5c34ae102e1e82fa250c7e23b95e985f91d7d91887fa7fb301ec20a06b1d4408d9a594754dcd86ec00c91f40f17c1ff52ed99fcd59965fe243a6cec7e672fefc5e3a29e653d5dcca8917e8af2c4f19d122c6dd30a3e2a80fb809383ced9d24fcd86516025174f5183fddfc6d74dde3b78834391c785defc8e4fbff92214df4c8322ee433a8eaeed7369419e0d6037a536e081df333aaab9e8e4d207d846961f015d96d57e3b59e24927773d6e0d66108955c1da134baab4eacd363c8e452b8c3845d5fb5c0ff6c27d7423a73d32742ccc3c750a17cd1f6026dd98a2cf6d2bff2dd339017b25af23d6db00ae8975e3f7e6aaef4af71f3e8cd14eb5c4373db9c3a76fc04659b761e650a97cb873df894064ecb2043a4317ef237ffe8f130eb5c2ca2a132c16f14943cd7e462568c8544b82e29329eb2a"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json
new file mode 100644
index 0000000000..79fadca3df
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json
@@ -0,0 +1,56 @@
+{
+ "agg_param": [
+ 0,
+ [
+ 0,
+ 1
+ ]
+ ],
+ "agg_result": [
+ 0,
+ 1
+ ],
+ "agg_shares": [
+ "70f1cb8dc03c9eea88d270d6211a8667",
+ "910e34723ec361157a2d8f29dde57998"
+ ],
+ "bits": 4,
+ "prep": [
+ {
+ "input_shares": [
+ "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930",
+ "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c"
+ ],
+ "measurement": 13,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "70f1cb8dc03c9eea",
+ "88d270d6211a8667"
+ ],
+ [
+ "910e34723ec36115",
+ "7a2d8f29dde57998"
+ ]
+ ],
+ "prep_messages": [
+ "d4cd54eb29f676c2d10fab848e6e85ebd51804e3562cf23b",
+ ""
+ ],
+ "prep_shares": [
+ [
+ "bd68d28c9fff9a30f84122278759025501b83270bf27b41d",
+ "1765825e8af6db91d9cd885d07158396d460d17297043e1e"
+ ],
+ [
+ "7c9659b7c681b4a4",
+ "8569a648387e4b5b"
+ ]
+ ],
+ "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json
new file mode 100644
index 0000000000..a566fe8b4d
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json
@@ -0,0 +1,64 @@
+{
+ "agg_param": [
+ 1,
+ [
+ 0,
+ 1,
+ 2,
+ 3
+ ]
+ ],
+ "agg_result": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "agg_shares": [
+ "d83fbcbf13566502f5849058b8b089e568a4e8aab8565425f69a56f809fc4527",
+ "29c04340eba99afd0c7b6fa7464f761a995b175546a9abda0c65a907f503bad8"
+ ],
+ "bits": 4,
+ "prep": [
+ {
+ "input_shares": [
+ "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930",
+ "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c"
+ ],
+ "measurement": 13,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "d83fbcbf13566502",
+ "f5849058b8b089e5",
+ "68a4e8aab8565425",
+ "f69a56f809fc4527"
+ ],
+ [
+ "29c04340eba99afd",
+ "0c7b6fa7464f761a",
+ "995b175546a9abda",
+ "0c65a907f503bad8"
+ ]
+ ],
+ "prep_messages": [
+ "d45c0eabcc906acfb8239f3d0ef2b69a0f465979b04e355c",
+ ""
+ ],
+ "prep_shares": [
+ [
+ "5d1b91841835491251436306076eaaa674d4b95b84b2a084",
+ "77417d26b45b21bd68e03b3706840cf49c719f1d2b9c94d7"
+ ],
+ [
+ "6e703a28b5960604",
+ "938fc5d74969f9fb"
+ ]
+ ],
+ "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json
new file mode 100644
index 0000000000..8141bc942e
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json
@@ -0,0 +1,64 @@
+{
+ "agg_param": [
+ 2,
+ [
+ 0,
+ 2,
+ 4,
+ 6
+ ]
+ ],
+ "agg_result": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "agg_shares": [
+ "7ea47022f22f6be9bce8e0ee2eb522bcbc2d246c17704beed7043426b646fe26",
+ "835b8fdd0cd0941645171f11d04add4345d2db93e78fb4112bfbcbd948b901d9"
+ ],
+ "bits": 4,
+ "prep": [
+ {
+ "input_shares": [
+ "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930",
+ "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c"
+ ],
+ "measurement": 13,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "7ea47022f22f6be9",
+ "bce8e0ee2eb522bc",
+ "bc2d246c17704bee",
+ "d7043426b646fe26"
+ ],
+ [
+ "835b8fdd0cd09416",
+ "45171f11d04add43",
+ "45d2db93e78fb411",
+ "2bfbcbd948b901d9"
+ ]
+ ],
+ "prep_messages": [
+ "6fb240ce8b8a2a8ce62112240f676105e0398515599f04b4",
+ ""
+ ],
+ "prep_shares": [
+ [
+ "ca0f02c7c61655263bf76d954b8abd16eb6e5ce2b26911b2",
+ "a5a23e07c573d565ac2aa48ec2dca3eef5ca2833a635f301"
+ ],
+ [
+ "f5171a3cc9d49422",
+ "0ce8e5c3352b6bdd"
+ ]
+ ],
+ "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json
new file mode 100644
index 0000000000..1741ec0ebc
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json
@@ -0,0 +1,76 @@
+{
+ "agg_param": [
+ 3,
+ [
+ 1,
+ 3,
+ 5,
+ 7,
+ 9,
+ 13,
+ 15
+ ]
+ ],
+ "agg_result": [
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0
+ ],
+ "agg_shares": [
+ "ec2be80f01fd1ded599b1a18d6ef112c400f421cca2c080d4ccc5cdd09562b3e556c1aaabe9dd47e8bc25979394c7bb5c61fd1db34b8dfdcc3eff4a5304fb7706b5462025bb400e644f2e0752f38098702491691494a2b498176ef41c4e6a962f716473c53087a3e80db0b9acb50cb15081b5ea4b50c48093f67a8c75875422dfd64ab2fa71fa3f3b55ec708ba4086672aff514d0cffe6f1c07f117c22af9b2c67b2a0c7ec1366ce474721174edb8b9eb33faef5f9c9d0c956e4407a86473120cfa46e8c634c1bc66c63a2009911f82c8426a45013e637aaba0e471b03f0a67a",
+ "01d417f0fe02e212a664e5e72910eed3bff0bde335d3f7f2b333a322f6a9d4419893e55541622b81743da686c6b3844a39e02e24cb4720233c100b5acfb0480f82ab9dfda44bff19bb0d1f8ad0c7f678fdb6e96eb6b5d4b67e8910be3b19561df6e8b8c3acf785c17f24f46534af34eaf7e4a15b4af3b7f6c0985738a78abd52f09a54d058e05c0c4aa138f745bf7998d500aeb2f300190e3f80ee83dd506453874d5f3813ec9931b8b8dee8b12474614cc0510a06362f36a91bbf8579b8ce5f1e5b91739cb3e439939c5dff66ee07d37bd95bafec19c85545f1b8e4fc0f5905"
+ ],
+ "bits": 4,
+ "prep": [
+ {
+ "input_shares": [
+ "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930",
+ "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c"
+ ],
+ "measurement": 13,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "ec2be80f01fd1ded599b1a18d6ef112c400f421cca2c080d4ccc5cdd09562b3e",
+ "556c1aaabe9dd47e8bc25979394c7bb5c61fd1db34b8dfdcc3eff4a5304fb770",
+ "6b5462025bb400e644f2e0752f38098702491691494a2b498176ef41c4e6a962",
+ "f716473c53087a3e80db0b9acb50cb15081b5ea4b50c48093f67a8c75875422d",
+ "fd64ab2fa71fa3f3b55ec708ba4086672aff514d0cffe6f1c07f117c22af9b2c",
+ "67b2a0c7ec1366ce474721174edb8b9eb33faef5f9c9d0c956e4407a86473120",
+ "cfa46e8c634c1bc66c63a2009911f82c8426a45013e637aaba0e471b03f0a67a"
+ ],
+ [
+ "01d417f0fe02e212a664e5e72910eed3bff0bde335d3f7f2b333a322f6a9d441",
+ "9893e55541622b81743da686c6b3844a39e02e24cb4720233c100b5acfb0480f",
+ "82ab9dfda44bff19bb0d1f8ad0c7f678fdb6e96eb6b5d4b67e8910be3b19561d",
+ "f6e8b8c3acf785c17f24f46534af34eaf7e4a15b4af3b7f6c0985738a78abd52",
+ "f09a54d058e05c0c4aa138f745bf7998d500aeb2f300190e3f80ee83dd506453",
+ "874d5f3813ec9931b8b8dee8b12474614cc0510a06362f36a91bbf8579b8ce5f",
+ "1e5b91739cb3e439939c5dff66ee07d37bd95bafec19c85545f1b8e4fc0f5905"
+ ]
+ ],
+ "prep_messages": [
+ "4a2b97cf17e54b126a86c6791c50d6507ee8b74b3d9903bcf3881121bc6e0975c4efb2d8b8a132b8a6caa4eb39ac2bbb5bdc351604fa9e78d1a6f5a5f615bb0c8819f485d8b24a4e48da47d3b7458a9cfde1e85c66453319a3f6d43dc40a0135",
+ ""
+ ],
+ "prep_shares": [
+ [
+ "4e64e5ed76c69ef68d3e144918a719986e40ab82f34bd30298b0085a3265d16988b8f646731ef47cb2fb1598e4cb817747623f1cc70ee7843ce1a9d6e3cf5c456801c9a3ae0c7c7663349a3daaf8fb51d165085c751e5bdd4e800df9e1e0193e",
+ "fcc6b1e1a01ead1bdc47b23004a9bcb80fa80cc9494d30b95bd808c78909380b2937bc9145833e3bf4ce8e5355e0a943147af6f93cebb7f394c54bcf12465e470d182be229a6ced7e4a5ad950d4d8e4a2c7ce000f126d83b5476c744e229e776"
+ ],
+ [
+ "003c39f76240f6f9bcc6065a247b4432a651d5d72a35aff45928eec28c8a9d07",
+ "edc3c6089dbf09064339f9a5db84bbcd59ae2a28d5ca500ba6d7113d73756278"
+ ]
+ ],
+ "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json
new file mode 100644
index 0000000000..c27ad93435
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json
@@ -0,0 +1,39 @@
+{
+ "agg_param": null,
+ "agg_result": 1,
+ "agg_shares": [
+ "afead111dacc0c7e",
+ "53152eee2433f381"
+ ],
+ "prep": [
+ {
+ "input_shares": [
+ "afead111dacc0c7ec08c411babd6e2404df512ddfa0a81736b7607f4ccb3f39e414fdb4bc89a63569702c92aed6a6a96",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"
+ ],
+ "measurement": 1,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "afead111dacc0c7e"
+ ],
+ [
+ "53152eee2433f381"
+ ]
+ ],
+ "prep_messages": [
+ ""
+ ],
+ "prep_shares": [
+ [
+ "123f23c117b7ed6099be9e6a31a42a9caa60882a3b4aa50303f8b588c9efe60b",
+ "efc0dc3ee748129f2da661f47a625a57d64a5b62ab38647c34bb161c7576d721"
+ ]
+ ],
+ "public_share": "",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json
new file mode 100644
index 0000000000..148fe6df58
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json
@@ -0,0 +1,45 @@
+{
+ "agg_param": null,
+ "agg_result": 1,
+ "agg_shares": [
+ "c5647e016eea69f6",
+ "53152eee2433f381",
+ "eb8553106be2a287"
+ ],
+ "prep": [
+ {
+ "input_shares": [
+ "c5647e016eea69f6d10e90d05e2ad8b402b8580f394a719b371ae8f1a364b280d08ca7177946a1a0b9643e2469b0a2e9",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
+ "202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
+ ],
+ "measurement": 1,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "c5647e016eea69f6"
+ ],
+ [
+ "53152eee2433f381"
+ ],
+ [
+ "eb8553106be2a287"
+ ]
+ ],
+ "prep_messages": [
+ ""
+ ],
+ "prep_shares": [
+ [
+ "5c8d00fd24e449d375581d6adbeaf9cf4bdface6d368fd7b1562e5bf47b9fa68",
+ "efc0dc3ee748129f2da661f47a625a57d64a5b62ab38647c34bb161c7576d721",
+ "b7b122c4f1d2a38df764c623c266f02f7b5178c3d64735ec06037585d643f528"
+ ]
+ ],
+ "public_share": "",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 3,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json
new file mode 100644
index 0000000000..099f786669
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json
@@ -0,0 +1,52 @@
+{
+ "agg_param": null,
+ "agg_result": [
+ 0,
+ 0,
+ 1,
+ 0
+ ],
+ "agg_shares": [
+ "14be9c4ef7a6e12e963fdeac21cebdd4d36e13f4bc25306322e56303c62c90afd73f6b4aa9fdf33cb0afb55426d645ff8cd7e78cebf9d4f1087f6d4a033c8eae",
+ "ed4163b108591ed14dc02153de31422b2e91ec0b43dacf9cc11a9cfc39d36f502bc094b556020cc333504aabd929ba007528187314062b0edb8092b5fcc37151"
+ ],
+ "chunk_length": 2,
+ "length": 4,
+ "prep": [
+ {
+ "input_shares": [
+ "14be9c4ef7a6e12e963fdeac21cebdd4d36e13f4bc25306322e56303c62c90afd73f6b4aa9fdf33cb0afb55426d645ff8cd7e78cebf9d4f1087f6d4a033c8eaeec786d3b212d968c939de66318dbacafe73c1f5aa3e9078ba2f63ec5179e6b4694612c36f5d4d539d46dab1ac20e43963978d9dd36f19f31c83e58c903c2cd94215c68b15f5d6071e9e19fa973829dc71b536351b0db1072e77b7570e3e06c65fac248d21dd970f29640050e901d06775f05a897850cab5707ac25543ed6ce7061b9cd70c783e0483727236d0cbb05dafefd78ec4e6419efe93d6f82cdadbfd4e860661238040229f60205bbba983790303132333435363738393a3b3c3d3e3f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ ],
+ "measurement": 2,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "14be9c4ef7a6e12e963fdeac21cebdd4",
+ "d36e13f4bc25306322e56303c62c90af",
+ "d73f6b4aa9fdf33cb0afb55426d645ff",
+ "8cd7e78cebf9d4f1087f6d4a033c8eae"
+ ],
+ [
+ "ed4163b108591ed14dc02153de31422b",
+ "2e91ec0b43dacf9cc11a9cfc39d36f50",
+ "2bc094b556020cc333504aabd929ba00",
+ "7528187314062b0edb8092b5fcc37151"
+ ]
+ ],
+ "prep_messages": [
+ "7556ccbddbd14d509ee89124d31d1feb"
+ ],
+ "prep_shares": [
+ [
+ "806b1f8537500ce0b4b501b0ae5ed8f82679ba11ad995d6f605b000e32c41afb6d0070287fe7b99b8304d264cba1e3c6f4456e1c06f3b9d3d4947b2041c86b020d26c74d7663817e6a91960489806931b304fcd3755b43b96c806d2bbeb0166bbec7c61c35f886f3f539890522388f43",
+ "8194e07ac8aff31f2f4afe4f51a12707b692a56a1745315a1022b4eb257b2a8725c610416af7b0d1a296f409cdb3fbf4f4c0d488206d794254e4755fd124cdc9a67364ddc7865afe3554de5f52f1ac910f3f8e110cfbad4113861316dc73ec60de4f6c512adaa41de631eda8d6d8c189"
+ ]
+ ],
+ "public_share": "bec7c61c35f886f3f539890522388f43de4f6c512adaa41de631eda8d6d8c189",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json
new file mode 100644
index 0000000000..0b9a9b4d5d
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json
@@ -0,0 +1,89 @@
+{
+ "agg_param": null,
+ "agg_result": [
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0
+ ],
+ "agg_shares": [
+ "a6d2b4756e63a10bf5f650e258c73b0ccb20bace98e225dea29d625d527fdd4ded86beb9a0a0ac5c4216f5add7a297cece34b479a568a327f1e259839f813df97b34de254be5b9b9c8d9e56dbff50b7a6bf1e5967686755a1dc42e0ab170add8c88f8ca68f945e768a5007c775fd27cfecb4495e257a2f2f94ca48830aa16ec0decaeee645e295c5dc2ebe491aae1a7f17b2807fcb33ee08127db466067bf84ec613dac9c93adbe73dd262c1859b2865",
+ "ed4163b108591ed14dc02153de31422b2e91ec0b43dacf9cc11a9cfc39d36f502bc094b556020cc333504aabd929ba007528187314062b0edb8092b5fcc3715126e16ce274ad58caaa14d22608269a4c41a256d3c9e847c0a6ac1a4fbaf6309e9ccbe74a9442ca956d843d6bd5adf9797a84557597d9cc81ddfa281ae5048d686bdb289ec2f3c96cdfa79b6974e6d15aec047748636d4358226283e11a78e045f59db2dda566162a56c85936ac0f4696",
+ "6eebe7d888434023a1488dcac80682c8084e592524430a857f4701a673adb261eab8ac90085d47e06d99c0a64e33ae30bfa23313469131cafb9b13c763ba50b560eab4f73f6ded7b7011486b38e45939566cc395bf9042e5038fb6a6949821899ea48b0edc28d7f3cf2abbcdb454deb69cc6602c43ac034f563a8e62105a04d7b859e87af729a0cd2729a64c716b1326fe480838d15ece9eaf20c8b7de0c276b464e7358905e0eee4f654308ce549104"
+ ],
+ "chunk_length": 3,
+ "length": 11,
+ "prep": [
+ {
+ "input_shares": [
+ "a6d2b4756e63a10bf5f650e258c73b0ccb20bace98e225dea29d625d527fdd4ded86beb9a0a0ac5c4216f5add7a297cece34b479a568a327f1e259839f813df97b34de254be5b9b9c8d9e56dbff50b7a6bf1e5967686755a1dc42e0ab170add8c88f8ca68f945e768a5007c775fd27cfecb4495e257a2f2f94ca48830aa16ec0decaeee645e295c5dc2ebe491aae1a7f17b2807fcb33ee08127db466067bf84ec613dac9c93adbe73dd262c1859b2865508d344dda6c4339e650c401324c31481780ef7e7dcc07120ac004c05ab75ee5d22e2d0eb229dcdd3755fab49a1c2916e17c8ed2d975cfe76d576569bf05233c07f94417fccaf73d1cc33e17dae74650badffdd639a9b9f9e89de4b9fd13e258b90fbb2b3817b607dc14e6e5327746ca20d1f1918bce9714b135ffe01eb4e6aefab92b0462f7e676e26007e8c2e5a66e16f32f7c8457a6dfba39d9082f640006d560b4d64e86e2e2358c84e03b857c980f51b1a78b53f7cb44343ed184d8dc87ebf8698609eeefae5d8882224ebd28b9531015badea8ae9fe01c7495cafecdc4f13389ea4eb0bbce0a5ab85aa6fc06aabd96d28c84ecf039bfeb4c350049485f8a4c706a109164ff4c640edaedd0ad50820b1d1ed7ab08fc69c48b39aff1eebc02ef1ea40bd70784bfa50511c3dd64b107f4297842280c3cff8d94be202a0e2cb0090f3adb2189f445fcf291f452f162606162636465666768696a6b6c6d6e6f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+ "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f"
+ ],
+ "measurement": 2,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "a6d2b4756e63a10bf5f650e258c73b0c",
+ "cb20bace98e225dea29d625d527fdd4d",
+ "ed86beb9a0a0ac5c4216f5add7a297ce",
+ "ce34b479a568a327f1e259839f813df9",
+ "7b34de254be5b9b9c8d9e56dbff50b7a",
+ "6bf1e5967686755a1dc42e0ab170add8",
+ "c88f8ca68f945e768a5007c775fd27cf",
+ "ecb4495e257a2f2f94ca48830aa16ec0",
+ "decaeee645e295c5dc2ebe491aae1a7f",
+ "17b2807fcb33ee08127db466067bf84e",
+ "c613dac9c93adbe73dd262c1859b2865"
+ ],
+ [
+ "ed4163b108591ed14dc02153de31422b",
+ "2e91ec0b43dacf9cc11a9cfc39d36f50",
+ "2bc094b556020cc333504aabd929ba00",
+ "7528187314062b0edb8092b5fcc37151",
+ "26e16ce274ad58caaa14d22608269a4c",
+ "41a256d3c9e847c0a6ac1a4fbaf6309e",
+ "9ccbe74a9442ca956d843d6bd5adf979",
+ "7a84557597d9cc81ddfa281ae5048d68",
+ "6bdb289ec2f3c96cdfa79b6974e6d15a",
+ "ec047748636d4358226283e11a78e045",
+ "f59db2dda566162a56c85936ac0f4696"
+ ],
+ [
+ "6eebe7d888434023a1488dcac80682c8",
+ "084e592524430a857f4701a673adb261",
+ "eab8ac90085d47e06d99c0a64e33ae30",
+ "bfa23313469131cafb9b13c763ba50b5",
+ "60eab4f73f6ded7b7011486b38e45939",
+ "566cc395bf9042e5038fb6a694982189",
+ "9ea48b0edc28d7f3cf2abbcdb454deb6",
+ "9cc6602c43ac034f563a8e62105a04d7",
+ "b859e87af729a0cd2729a64c716b1326",
+ "fe480838d15ece9eaf20c8b7de0c276b",
+ "464e7358905e0eee4f654308ce549104"
+ ]
+ ],
+ "prep_messages": [
+ "4b7dc5c1b2a08aec5dcfc13de800559b"
+ ],
+ "prep_shares": [
+ [
+ "e80c098526d9321dd0801f97a648722016fa117f10cb2b062fc5fb1e55705894007f838333ef348c6306e141369bd88d123c66d2faeb132e330a73882c38765d425847bd86e5f784b3348ee4840c5df103b49f04c4dcca4667abb956187da58c91c946d9d5fdf496d95428f8a625dddfc8b7bb469397ebd4b177f902896febdaac39a8d9ec0aa1a24132036c2430929c",
+ "4f6d4137ddffd58b243b8f845a1684550b240ea3e91a68335f717b83056e9b45c5e62d7a24da54147fcb9260d023cb7f9c8d036f0100f5fea0ce22f49e3d7672bc83fd5c724f2684f3442e8c5291c41509151808d1da447cddc3fe11cf5cd8d7fe662cf035eff88b583f6b32499b332aa6dee37947ef482e15fcb3a7f04b20813162d162b9bf30eee4953b6fdabd10f1",
+ "ca85b543fc26f756ef4351e4fea0098a73787eeab8b613029af310882b7e87d28fc5502fd76bd704626d9f0f662e531feaf1fa912cc209de6541401d8508a4788d92e549f58241334cfa29abc1fb80a28ae61d7a4060d9582e3e20f182e4519ab1bb9547c545aafc21416e779856c80cf3690155119111aebf3800757989229e4966453f7aa269163b272848de80227f"
+ ]
+ ],
+ "public_share": "ac39a8d9ec0aa1a24132036c2430929c3162d162b9bf30eee4953b6fdabd10f14966453f7aa269163b272848de80227f",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f"
+ }
+ ],
+ "shares": 3,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json
new file mode 100644
index 0000000000..a7178fcfee
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json
@@ -0,0 +1,194 @@
+{
+ "agg_param": null,
+ "agg_result": [
+ 256,
+ 257,
+ 258,
+ 259,
+ 260,
+ 261,
+ 262,
+ 263,
+ 264,
+ 265
+ ],
+ "agg_shares": [
+ "cdeb52d4615d1c718ef21a6560939efcb5024c89ea9b0f0018302087c0e978b5b5c84c8c4f217b14584cc3939963f56a2718c17af81d1f85129347fc73b2548262e047a67c3e0583eeef7c4a98a5da4ccdb558491081a00a8de5a46b8c2f299bc69162a7411b3ecf2b47670db337083dc50b8fc6d5c5d21da52fc7f538166b0e4564edd8fbb75bc8c4fdd0c02be2b6d11bc4a159297b72e2c635a9250feb2445",
+ "3415ad2b9ea2e38e550de59a9f6c61034dfeb3761564f0ffcbcfdf783f16874a4e38b373b0de84eb8bb33c6c669c0a95dde83e8507e2e07ad16cb8038c4dab7da320b85983c1fa7cf50f83b5675a25b3394ba7b6ef7e5ff5561a5b9473d0d664416f9d58bee4c130b8b898f24cc8f7c243f570392a3a2de23ed0380ac7e994f1c49c12270448a4371f022f3fd41d492eef3c5ea6d6848d1d1dca56daf014dbba"
+ ],
+ "bits": 8,
+ "chunk_length": 9,
+ "length": 10,
+ "prep": [
+ {
+ "input_shares": [
+ "2451f59efba2edc493c9246ff1f0e0a7f8f6f22ee46e662c899e485d7ce288d6becdfee804a39618972fbaa595eeec25423e412cbe51d44a62747de627630302368ec3535a2545a2799e8a0b9a144c811158dda278865d834b34fbe77ad11dbb9fdcf0637c24e10d5ab36d03cdc5f6b95e400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938368d5621cff66585454ef124daa5f18efd7e791a4bcb11caf74b378e2c4feff3e5bad16e7c3fab987eb4d4a0c675bb4f4e70e1373fb00a5dd30a1118355c20e2e4c3700be3d3c1cf25d3e4a729836ba564aa074f99be0d23d4cc0dc9f263c986988e0d16a3d28c262d34f220b1ed127cddea3e2a1bd075c653d4b6f1c3d35e25d2804e7960250dea42dc4a52c9545bedc182ee8391b4c6849366af8e15f30bd06872e5ed651ef7db0b0c442886de32eeeeacc5f2dfe87f9375b4774153fc9e442105b5f8e452e80874c84131400d4d588a1a5d94bac9e68dbf917ef6405b0bc13fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b9d728fc0c96ff898ed46bd289abbec9917397552ebf6d1eb3f916f69ee9f80e9466512bff70af2d8f3a9ed599f24e33550a09304e1b4f51948e2d8cbf5a1bb14455b1786ae3af4670111bc3983293ad9ae029128efd86d0a05cb3f442b43f466cec5cc9c4989bf5a29eb5c2401bc8bba0d5b7487bc0bf010c968fe76e3a9924459dce6704528d56540081240ed0d2f301a8c9baca5c183b1b5c3a9c03dce5036926d06e1470c2e63d15fdc3a61056154fca9439c595098ff3794c7d7e62af5e3139b43e22a0f8864c254a069a083604762d77ea000177a7b908efe27f6e00db7ea25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311dcedf0ea031688b2cf8c5248f21be444f1c61b050a0ab7dca04992673afb737bd27526a72dda3b03dad4d3b0bb81f0887ee6f25ec4cf35d58ea5f085e97609cfb6a8e97d84fdf8755b8e81ff29614bf1b03bcd2b8d9ab06dc4d60785f83eb6ee4573859223214ecdad734d114e15e1971a8b82222910fd041a1123a4e792a9239f99252de3e3e8d5bb209e2c9bda506a79853c482546940364a8246392fb5e18e85847458445fe3a970b29db6d3d0e4a806cfec7c8538f24896d2d10669113f2b724161d2007ee75c0b651f4934046142b04b2015212997c609625bbeb81b9fa0249c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7c8b13bded6ba43ed8b92f6ba7879b39932468260c5768ca0909aae899ad1252c5dbcd741d971f179bc36e88a0a10981f73202cb25db324da405fdd5ca5331431afe362c5f933b3c1216c3e19140cd27f7c2ef67898887856a46a518a3afec78ee0d9dce778289a38d2df906932c40019afadab12fe7d0695316e5a3c1e38aa630a44bc8cc01a5a8cae060b7de435e54963b9354182d64e340ec9dc3e37f8b2bbaaab23608b86827991df4367839f443c160c1eb77f41159f69592c3eb37c21a521afcd34036a13a145e9cb1039704b8e523359ea5c3a50f705118ea7d8b1063eb85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce918a98514903c63a3476d6ebe68e2503e6bf255691fbd8a006e9c77f5a4ad9e3e8d21a56bc4f7bc90d61ebb31eaa4dce48eb9a8069a584ae35266a4bc4af970860d2e9a0df7b87e8fc8b597e73a85d8eeb91def6057d7a77e8f859ee9ee07ef2fb2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4d37d8c2a17e97468cf923a5388eb0d1de61358e9651a7e76b033d32d6c84e7ed5831a990b46e8228b6ef120643049645b82e100a7ed6ddd2ebfe2dcbd8b0e7ac1e5ee021d4279f164acc47875ade2c0acff5dbbf3a6eb0e8601632c926780be1660270420aa02c99fb39af1852b09904791e90cfa1f02aec0ab2de111524394819527819e52d495196ab3aff1e323dfec07af91e18b9a04e37552a23b13177bdcfc64ec7108e5e9b3679ccdf6b1e998e2bbcd5fbbbebf5ad8008e727cae6499cc06aa03809947e298683a4340f51d6eecad38d0a7a5437dd6e72bce6543b81fd3a438d71e232845cdb403f1011295f9aee5e33352b86e92343985884284c9646da13545f37b9d6da7d0cf902c19a5ca1f4f1818a2c2644807fcc54be35c29f96fb4fea5efdc88b270f1c5504bd8ba558834786020cc2f03ab5c56eaea38532b9faf6208f57d970b2e5ff92872713c9e0ad07b26e72dca6f9a9c02bad6c9db4d1d738f306292f14415d2856c2b073c5d8faf89e9713ceb375b6eefabc240bf6c6bf39cafb99993767dbaf5ee5f4b3f93e638e904fb55f443312c145b809fd203b5b3a16bd229b952e100bbfc0e49bbd05d54c3e5fa1a44fe55de16cfa52f3b169e0bfe95b1b8b6367f9309adfe3df079104fd720d46d772def3c0534d73615071fa22a79af875f796478d2f599dbb4c1ed303132333435363738393a3b3c3d3e3f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ ],
+ "measurement": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "9aa31b9c201fb42526a6b32120318aa9",
+ "e8551983a3deafaafe0f608295f8d291",
+ "3eed6ed96f607eb1be6e96868876fc78",
+ "65b295d3525f0ad74886c2fe7b3b1cd6",
+ "794a6d37d41457d6f04fd418888cf36e",
+ "483cc86d052be058d0a1e12384ba0d89",
+ "9c85cb376b5ebfeffb6c22af3bbd02bf",
+ "9c0385979cecf00983ba97fc12b2235a",
+ "c7cbf9f2533dc942eca94540b9a0e745",
+ "0f418bc80d2926f6ec11e3615a4e0c17"
+ ],
+ [
+ "675ce463dfe04bdabd594cdedfce7556",
+ "1aaae67c5c215055e5ef9f7d6a072d6e",
+ "c5129126909f814e2591697977890387",
+ "9f4d6a2cada0f5289b793d0184c4e329",
+ "8cb592c82beba829f3af2be777730c91",
+ "bec33792fad41fa7135e1edc7b45f276",
+ "6b7a34c894a14010e892dd50c442fd40",
+ "6cfc7a6863130ff660456803ed4ddca5",
+ "4234060dacc236bdf755babf465f18ba",
+ "fbbe7437f2d6d909f7ed1c9ea5b1f3e8"
+ ]
+ ],
+ "prep_messages": [
+ "db085315822777376b4d0f962d8f06d9"
+ ],
+ "prep_shares": [
+ [
+ "e6a4fe7264f95c384446bd51db14f78e7f4133afed0604aeb3b87125fc076c7447d795723adfe93d85f9fe2993c52420e45694fd2ec164a54a7267ee5efc8cb40b6659ac81f2e850786218bcec469ec4f7bb28e875d75ee98d54d566186c61c35448a50cb11e195d886622861a78bbb74325b7972e7b4c47f0e2e10a15d7a33c3daecb2dfc507b1b6676c1e9bfc52a4873f408a5788e7b77ce6943e67f3f457280544d93b81b08e427f699ba54adcbb0ffab83366d9b336846c0c989f0bc25bdd14683f1a85e844b9dbac26daae84cc8d57ef6b0c340798ac5ade63150e8d7a9673b64d798a97cf2715f399fd371e342c1ad50e28431f54180ef63ad7dd21f3e5d8d67159cacfd56f5d99c39d53047c8d7bf11ad83a2e3e569e1393b12d87d01701fa71b50b51e092ca6b797bb97890efb6327f1c4e488663dca5f00675c2af7368a9ab95b3c4e9e1a8dd5430d336833",
+ "1b5b018d9b06a3c79fb942ae24eb087131693625114418115cb3e54a45056eabd0f6e371501957e00796db78ea8f3388eb8345e938f5fbdd8b24c3a968276d7c457ff43ce93631942f823f5bb9c6d1335b8022e804072711cb8fa5fb3afb209e696c9cac47da44cdc3eb3874eb0d8c89692408b463df12bb2d6e8193d5829cce221e486a579b91cd10a0fb38fec7214a9008d574ba32615f6215aef827a2962a31df892814bd8b8f828d029f07f6acf490e7ecd3377f10b86ec2d5741ba37a3a9522b897e840e315a614a89d0bcf8296bdd45e330eaf3f34b3ce4e1dd41306eae92147fc6676eedff2cb239581f46750df341390e066dabb01ef6362694f923d5a65dbc5252a8da8702a979aca3e211af7124485c1c7dc68f6fb1bdfb9a4d0d993cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb"
+ ]
+ ],
+ "public_share": "368a9ab95b3c4e9e1a8dd5430d3368330a3dfc6f55bb428773ce15071cb720fb",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ },
+ {
+ "input_shares": [
+ "2551f59efba2edc493c9246ff1f0e0a7f8f6f22ee46e662c899e485d7ce288d6becdfee804a39618972fbaa595eeec25423e412cbe51d44a62747de627630302368ec3535a2545a2799e8a0b9a144c811158dda278865d834b34fbe77ad11dbb9fdcf0637c24e10d5ab36d03cdc5f6b95e400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938368d5621cff66585454ef124daa5f18efd7e791a4bcb11caf74b378e2c4feff3e5bad16e7c3fab987eb4d4a0c675bb4f4e70e1373fb00a5dd30a1118355c20e2e4c3700be3d3c1cf25d3e4a729836ba564aa074f99be0d23d4cc0dc9f263c986988e0d16a3d28c262d34f220b1ed127cedea3e2a1bd075c653d4b6f1c3d35e25c2804e7960250dea42dc4a52c9545bedc182ee8391b4c6849366af8e15f30bd06872e5ed651ef7db0b0c442886de32eeeeacc5f2dfe87f9375b4774153fc9e442105b5f8e452e80874c84131400d4d588a1a5d94bac9e68dbf917ef6405b0bc13fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b8d728fc0c96ff898ed46bd289abbec9917397552ebf6d1eb3f916f69ee9f80e9466512bff70af2d8f3a9ed599f24e33550a09304e1b4f51948e2d8cbf5a1bb14455b1786ae3af4670111bc3983293ad9ae029128efd86d0a05cb3f442b43f466cec5cc9c4989bf5a29eb5c2401bc8bba1d5b7487bc0bf010c968fe76e3a9924459dce6704528d56540081240ed0d2f300a8c9baca5c183b1b5c3a9c03dce5036926d06e1470c2e63d15fdc3a61056154fca9439c595098ff3794c7d7e62af5e3139b43e22a0f8864c254a069a083604762d77ea000177a7b908efe27f6e00db7ea25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311dcedf0ea031688b2cf8c5248f21be444f0c61b050a0ab7dca04992673afb737bd27526a72dda3b03dad4d3b0bb81f0887ee6f25ec4cf35d58ea5f085e97609cfb6a8e97d84fdf8755b8e81ff29614bf1b03bcd2b8d9ab06dc4d60785f83eb6ee4573859223214ecdad734d114e15e1971b8b82222910fd041a1123a4e792a9239e99252de3e3e8d5bb209e2c9bda506a78853c482546940364a8246392fb5e18e85847458445fe3a970b29db6d3d0e4a806cfec7c8538f24896d2d10669113f2b724161d2007ee75c0b651f4934046142b04b2015212997c609625bbeb81b9fa0249c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7b8b13bded6ba43ed8b92f6ba7879b39922468260c5768ca0909aae899ad1252c5dbcd741d971f179bc36e88a0a10981f73202cb25db324da405fdd5ca5331431afe362c5f933b3c1216c3e19140cd27f7c2ef67898887856a46a518a3afec78ee0d9dce778289a38d2df906932c40019bfadab12fe7d0695316e5a3c1e38aa630a44bc8cc01a5a8cae060b7de435e54963b9354182d64e340ec9dc3e37f8b2bb9aab23608b86827991df4367839f443c160c1eb77f41159f69592c3eb37c21a521afcd34036a13a145e9cb1039704b8e523359ea5c3a50f705118ea7d8b1063eb85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce918a98514903c63a3476d6ebe68e2503e6bf255691fbd8a006e9c77f5a4ad9e3e7d21a56bc4f7bc90d61ebb31eaa4dce48eb9a8069a584ae35266a4bc4af970860d2e9a0df7b87e8fc8b597e73a85d8eeb91def6057d7a77e8f859ee9ee07ef2fb2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4d42a21c0a30addfe4b4176740f9a418eca631cda9a8b94d20c7f9b834ed87751464dc5e5b446d920312003e4673b48c6c12b407af1ed90002507883f78d166f0b90bc14ed77d4aec6220cdd51948cdad29ab70513aadccd0e3c8c8108d3b0722602d9612aa6feb323a4ff3e8fe0e3d5701467491acdd3c71c34bc019047647779922216ccd61c47958461e3017adf446c4bd2ab7fbf70e41419679f6a9b3fa4c9aa5e9ef8469ace0d88bc35a3374f462573d2ba24b712359ef36e413006a9883bfa4fad43d89c7f1732725e3cad482d17a9499e1fb0f57d1ca93cafa7fd6d654a70cd7318bd7ace30e981217317105bcfe5e33352b86e92343985884284c9646d966beb8aca87d44e15dcce24aeb08312091cd98e6b52e2525409c58438f00131d33ce09fde0343f84db73369954f2d77a3a559189bc4dfd7e7c043b1364b36550595f624483c4eccccb1c4958a9284e43522dcc72ad9b01162d964605eab990dd1ddd25796f55991e1201f22526117662c0cad518f191effe5608b444b9e8973f5a11a8c154bc501bc47bb4fb5832d67f4c5ea5c221e64ff88ff5d5117aadac704a8b94beb036e87fc9a9462c355231b9bbe8a9122f12390073600f2d7f6262f1758eced79619900cad1910286c5bae553c6525c63c52ca97bf452957c5fc1a7d695dcb5ed0cf0004454c528d63960cd303132333435363738393a3b3c3d3e3f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ ],
+ "measurement": [
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "9ba31b9c201fb42526a6b32120318aa9",
+ "e8551983a3deafaafe0f608295f8d291",
+ "3ded6ed96f607eb1be6e96868876fc78",
+ "63b295d3525f0ad74886c2fe7b3b1cd6",
+ "764a6d37d41457d6f04fd418888cf36e",
+ "443cc86d052be058d0a1e12384ba0d89",
+ "9785cb376b5ebfeffb6c22af3bbd02bf",
+ "960385979cecf00983ba97fc12b2235a",
+ "c0cbf9f2533dc942eca94540b9a0e745",
+ "07418bc80d2926f6ec11e3615a4e0c17"
+ ],
+ [
+ "675ce463dfe04bdabd594cdedfce7556",
+ "1aaae67c5c215055e5ef9f7d6a072d6e",
+ "c5129126909f814e2591697977890387",
+ "9f4d6a2cada0f5289b793d0184c4e329",
+ "8cb592c82beba829f3af2be777730c91",
+ "bec33792fad41fa7135e1edc7b45f276",
+ "6b7a34c894a14010e892dd50c442fd40",
+ "6cfc7a6863130ff660456803ed4ddca5",
+ "4234060dacc236bdf755babf465f18ba",
+ "fbbe7437f2d6d909f7ed1c9ea5b1f3e8"
+ ]
+ ],
+ "prep_messages": [
+ "fdecd6d197e824492dd550fcdd0aa3c7"
+ ],
+ "prep_shares": [
+ [
+ "e6a4fe7264f95c384446bd51db14f78eec654b5beb7d60677c0b8385ca51a5bf896c79a069b04f339fdcb675885a26f533ec0d6e805beef6d7683a8ecdfa46cb87fdbc532efa5b1041347ed94f54fbca15041d013729d5eb78dcb1a1c293e0035448a50cb11e195d886622861a78bbb7bccc68d521fac002ece757cc5d82afb3b0f9f8766a1714047b50417a8e7f63eacf222c675bdf1d3e8362806ef5f9c16c446a1e5e06ce539aa300bc68c837c9207c5dbc7d85f896fb1be725e461ceacd303716a2cd8005e370a8ac1062d966b5c1499813d0dc63d697c4fd44dcfe81b3cfc37952de43649650c52ca9f2044fd3dfc42cd08bf0659e6a7facee2468ad1b07903a933ec3cbd06a5f81461e434449c32ef6678f3c8fd250c0b3318e80235144a96f07af2169dfddb503b38a2039f80f6bdfbbec6b3ad1025a43a90249221d20e627a73e2ed92bf1920574f42775675",
+ "1b5b018d9b06a3c79fb942ae24eb0871aba638b677efec2418d7921020c44ff8d0f6e371501957e00796db78ea8f338813d41fd058418e6056a93bb2785cec1d457ff43ce93631942f823f5bb9c6d133fbfbf9d10fa7922dccd1e570b0246775696c9cac47da44cdc3eb3874eb0d8c89fcdfd722a57116825749c8972129304c221e486a579b91cd10a0fb38fec7214af0fe63f4a24e9e189e180a8aaed3b22931df892814bd8b8f828d029f07f6acf402eb7c22aea07b55e871ebdf6475ae509522b897e840e315a614a89d0bcf8296c0b0d5b745f277689643e3da46c09a1ce92147fc6676eedff2cb239581f46750fa1947870c81f9d565f51ed7e09d86dc5a65dbc5252a8da8702a979aca3e211ae9def20e5f2fa85d16e3ccb0a917510793cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb"
+ ]
+ ],
+ "public_share": "0e627a73e2ed92bf1920574f427756750a3dfc6f55bb428773ce15071cb720fb",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ },
+ {
+ "input_shares": [
+ "2551f59efba2edc493c9246ff1f0e0a7f9f6f22ee46e662c899e485d7ce288d6bfcdfee804a39618972fbaa595eeec25433e412cbe51d44a62747de627630302378ec3535a2545a2799e8a0b9a144c811258dda278865d834b34fbe77ad11dbba0dcf0637c24e10d5ab36d03cdc5f6b95f400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938468d5621cff66585454ef124daa5f18f0d7e791a4bcb11caf74b378e2c4feff3f5bad16e7c3fab987eb4d4a0c675bb4f5e70e1373fb00a5dd30a1118355c20e2f4c3700be3d3c1cf25d3e4a729836ba574aa074f99be0d23d4cc0dc9f263c986a88e0d16a3d28c262d34f220b1ed127cedea3e2a1bd075c653d4b6f1c3d35e25d2804e7960250dea42dc4a52c9545bedd182ee8391b4c6849366af8e15f30bd07872e5ed651ef7db0b0c442886de32eefeacc5f2dfe87f9375b4774153fc9e443105b5f8e452e80874c84131400d4d589a1a5d94bac9e68dbf917ef6405b0bc14fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b9d728fc0c96ff898ed46bd289abbec9927397552ebf6d1eb3f916f69ee9f80e9566512bff70af2d8f3a9ed599f24e33560a09304e1b4f51948e2d8cbf5a1bb14555b1786ae3af4670111bc3983293ad9be029128efd86d0a05cb3f442b43f466dec5cc9c4989bf5a29eb5c2401bc8bba1d5b7487bc0bf010c968fe76e3a9924469dce6704528d56540081240ed0d2f301a8c9baca5c183b1b5c3a9c03dce5036a26d06e1470c2e63d15fdc3a610561550ca9439c595098ff3794c7d7e62af5e3239b43e22a0f8864c254a069a083604772d77ea000177a7b908efe27f6e00db7fa25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311ddedf0ea031688b2cf8c5248f21be444f1c61b050a0ab7dca04992673afb737bd37526a72dda3b03dad4d3b0bb81f0887fe6f25ec4cf35d58ea5f085e97609cfb7a8e97d84fdf8755b8e81ff29614bf1b13bcd2b8d9ab06dc4d60785f83eb6ee4673859223214ecdad734d114e15e1971b8b82222910fd041a1123a4e792a9239f99252de3e3e8d5bb209e2c9bda506a79853c482546940364a8246392fb5e18e95847458445fe3a970b29db6d3d0e4a816cfec7c8538f24896d2d10669113f2b824161d2007ee75c0b651f4934046142c04b2015212997c609625bbeb81b9fa0349c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7c8b13bded6ba43ed8b92f6ba7879b39932468260c5768ca0909aae899ad1252c6dbcd741d971f179bc36e88a0a10981f83202cb25db324da405fdd5ca5331431bfe362c5f933b3c1216c3e19140cd27f8c2ef67898887856a46a518a3afec78ef0d9dce778289a38d2df906932c40019bfadab12fe7d0695316e5a3c1e38aa631a44bc8cc01a5a8cae060b7de435e54973b9354182d64e340ec9dc3e37f8b2bbaaab23608b86827991df4367839f443c260c1eb77f41159f69592c3eb37c21a531afcd34036a13a145e9cb1039704b8e623359ea5c3a50f705118ea7d8b1063ec85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce928a98514903c63a3476d6ebe68e2503e7bf255691fbd8a006e9c77f5a4ad9e3e8d21a56bc4f7bc90d61ebb31eaa4dce49eb9a8069a584ae35266a4bc4af970861d2e9a0df7b87e8fc8b597e73a85d8eec91def6057d7a77e8f859ee9ee07ef2fc2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4dc97f402ea5053f2aa068d42c371d327339bd8637472eb9be88e409688a53bff392a8891c96c7804998d761fcf34dbf8da8cb17567f75ab4be29ecb6c33c85bf572b645bb7b226e4b99ccc0d959e6d809ec45b70cefe5ada611ea14962c9788d3e15100872669baf3c07a424cc205db97c0e2f72880a40a48d3820f89a16c7ee9dbc44bea18d787e0b730093a7b1fe4f9fd3e50128c61f6d4179fde88b02629629bf9f56f15abe3860ad2c99eaee9eba2310ca898fee5103f7bcb11ec9f4154007cf7d2b51a173331576526665d9f90879a2122b2bdef3a2cff68e57b146ae6d99a4e247d0daa693ef0a07aadc1a4b25ce5e33352b86e92343985884284c9646d0f8ec766552f75092a8b613870386a8b77901f01cddd76b4761e74519b24b851a570b5de8ca954b2c7df0fb314b6fa550e8e49713a28358e399afb3b9199496b229bc55644ee8e4772f1e00dc53886ade4932acee5cfd079707bd1d204c58360f26434fb158b53c1c4a51b65703f123f8090fe42dc48dbd3469a7d4bf1958203adffe46dd39084b66c789517b4438ed9415946ca552d523fa6c71e3302c3552f140d62d41cf3580e5e8500674cbb7d9ddd849d1ddb1d48ef7fd92f363e5e5b6a95b0c67b37e7e5e6a4dec9d8d56e577562eecec955cb6f9925c81cc165634018ab142c519ddd54f358356cee2ba50840303132333435363738393a3b3c3d3e3f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ ],
+ "measurement": [
+ 255,
+ 255,
+ 255,
+ 255,
+ 255,
+ 255,
+ 255,
+ 255,
+ 255,
+ 255
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "99a41b9c201fb42526a6b32120318aa9",
+ "e6561983a3deafaafe0f608295f8d291",
+ "3bee6ed96f607eb1be6e96868876fc78",
+ "61b395d3525f0ad74886c2fe7b3b1cd6",
+ "744b6d37d41457d6f04fd418888cf36e",
+ "423dc86d052be058d0a1e12384ba0d89",
+ "9586cb376b5ebfeffb6c22af3bbd02bf",
+ "940485979cecf00983ba97fc12b2235a",
+ "beccf9f2533dc942eca94540b9a0e745",
+ "05428bc80d2926f6ec11e3615a4e0c17"
+ ],
+ [
+ "675ce463dfe04bdabd594cdedfce7556",
+ "1aaae67c5c215055e5ef9f7d6a072d6e",
+ "c5129126909f814e2591697977890387",
+ "9f4d6a2cada0f5289b793d0184c4e329",
+ "8cb592c82beba829f3af2be777730c91",
+ "bec33792fad41fa7135e1edc7b45f276",
+ "6b7a34c894a14010e892dd50c442fd40",
+ "6cfc7a6863130ff660456803ed4ddca5",
+ "4234060dacc236bdf755babf465f18ba",
+ "fbbe7437f2d6d909f7ed1c9ea5b1f3e8"
+ ]
+ ],
+ "prep_messages": [
+ "190c0ef07d6f2cbd1bed12d71b5f118d"
+ ],
+ "prep_shares": [
+ [
+ "e6a4fe7264f95c384446bd51db14f78ef1ffdfc96013014baff7bd0684b8ff3360f6a2a23b1d7b71796b845eeb21e7d1682128dc8b87837803fec0e3bd6c548d5cc9c041a482cbfc38120743a2f1a0985053eaeccc339c56c2edf76451e22c9a1bca6ab2181850f44d904964d1f227a970e70b59328028ddafdb1649a8c4f1f7cbeb366e42f8603a7b627f7519f25617a33726ede06ab438714b4bd3cda025dc2bcab64974eb02b9e2a23bf0cab4e5ef1e87bb1a72767098c768a20e1090a712ed38c1e8803fd18181cc0069355b40f5f98ff3cbd2b31022df719d9c660e7d5bab239819730ad9165a38641379444093ba148996166ccf5e2826bddb7a7dc8eda4dd5928b12e1ea715538788804b5ce231dfc98ff45027ff8cb1f92007f2339621201a7dc483c83bb6df082105cb5f5d9eda1f847b3d43b4e16502062d5a4a158f651439d7666c683a6283d308c91b25",
+ "1b5b018d9b06a3c79fb942ae24eb08714d70b996c916bc83063fda0f9ec82780d0f6e371501957e00796db78ea8f338874640cdcd2ca1496ea55d62aa3709498457ff43ce93631942f823f5bb9c6d133b1f04db02cc745245af6a31e4dfe912a696c9cac47da44cdc3eb3874eb0d8c89cc603d2ceec1678996fe311424b56e65221e486a579b91cd10a0fb38fec7214a46c9ce7f890afed4efa31fb5ca8766ae31df892814bd8b8f828d029f07f6acf47620063e9fa98ab947b376e068b9be689522b897e840e315a614a89d0bcf8296332affc2e6b52204d1d7994260e415c6e92147fc6676eedff2cb239581f4675066213f3b302497abbf10c1729131a0125a65dbc5252a8da8702a979aca3e211afb99a0edb76c92f2bd5941c6071db19d93cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb"
+ ]
+ ],
+ "public_share": "8f651439d7666c683a6283d308c91b250a3dfc6f55bb428773ce15071cb720fb",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json
new file mode 100644
index 0000000000..af95aac5a0
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json
@@ -0,0 +1,146 @@
+{
+ "agg_param": null,
+ "agg_result": [
+ 45328,
+ 76286,
+ 26980
+ ],
+ "agg_shares": [
+ "598a4207618eab7bd108c110106e84cd9498b18b79cace4b383274cef0ab280eae74b07f9a7f087a89a4f51f3adb97ed",
+ "ea61abdf14b8477f6de1b47a18ac778ad0149cb235e666ccce92a9246a2858403e5903013ab179dcf6719d10fccdf589",
+ "cfc412198ab90c0589158a74d7e503a89b7cb3c1504fcae7dc3ae20ca52b7fb17a9b4c7f2bcf7da947e96ccfc9567288"
+ ],
+ "bits": 16,
+ "chunk_length": 7,
+ "length": 3,
+ "prep": [
+ {
+ "input_shares": [
+ "db3fe5357f56f6cfe9a0a34ed08e231765b423d5670741c151f7bc28cb15c4c8678df838b6ef52d74c5bb5e8c8c607b2674c024da705b558163e3127ae09f4a5c061dc6129d71755ec77b5d5a2fc088541fa612bf1273d94718e0d28b654ea9524148bf910a5c6d55b596e323d4f55ab62b0851e0985987ea39cf500bfd4d5faed37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0f5d8a2e157f9c1ef6eedcb845a19fffd3f9c6e705b035c95f3ebe51219a91c6ffff33c52b8fd7e90417636bb443c98c3e6933b1599c09151095fff99d7af43d327fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa239333c6d1e2bec2132a3e3cf5fd11f7e823fbf275071a71535740ac64de88b83ba002bd4490de5c7540e454331d46ebb825725b964ecfcc5c3bb076bf57fd819c0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b0fbc219f5524eec22dc91aa56ed87a71d2b6e6b857ee03a700fb5af5b23513fb73023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e7d7fdf14abf76c605d42e523a540f40e816db420d599cea502baa131d6b61b8aed86a09b9f7587e107c6fc9043d72644153b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b12007c81a9b3c72a84522014f4cdcb949bb9850b62b1c9c5d35f0fc3dbf4f637ae85d3082ffd15890041a455a5dd38efca02eb415afcd2c93a2b6ab46c88847f380affbc3a9af4745718044c9f23978026af66871e610e1beee21360fc5ae4f05dd32689328216db3c512fddacefb2713187f4f5069170db246ae0cb13b3b88f8c832d11ae6b26fe6f0f89b9a9056c28ac6160134ab1af4be606a2ca7d1256cfc223c823ae4921ba11aab9e4f8104885b2d17ac7826ad06b0fd0e54397e67179755b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf7beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa9249c7a0530623084f3a7c8bdddf0cf0123597127f9416dcad20f71bb98bf0078d9d20d4ae09c46fc1882cae650f48bebfcc443dfd7963c9944f0ce0a9bda91fe8b8f307fb3da48e573224324c35e867599262fa281ee6bc537928747dd4a1370440bda92e28eb5ba088481e476f19e2a03a7c7619eb79e2184afff7bb585319fa6a58dbfb8862f5d193ccfe0e4aeb58c1633d9e983861d4976615b11514160e5d77ed9a3e2179893c65d9de03813d27aec3485d96098764ee1e7779d47850e6b96ba064b4b913da8390416afb16719a38b725d5b27db351049bdb322cfc93a905a108d07e49764eb5f3f66fe8f3aec7580606162636465666768696a6b6c6d6e6f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+ "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f"
+ ],
+ "measurement": [
+ 10000,
+ 32000,
+ 9
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "7e6f6b02cb848e7e3c5840b05acfd699",
+ "dda1902ed398efc35ebb269aa58e0d5a",
+ "cd03902ade7fad281b8cfc5f1349ddf9"
+ ],
+ [
+ "f9758e4a5c3d6d2a1b4b3c7e5d397d83",
+ "9bb1de90bc4c2244e630e3b6780dc86a",
+ "15735600bee57d499ed0890554ef5183"
+ ],
+ [
+ "9b4106b3d83d0457705c83d147f7abe2",
+ "89299140701aeef79e13f6aee1632a3b",
+ "298919d5639ad48d0ea3799a98c7d082"
+ ]
+ ],
+ "prep_messages": [
+ "429a80b8b2f73ce00d066d0c0e540cfb"
+ ],
+ "prep_shares": [
+ [
+ "574ce7d79da173f0652f7d1e699d5fe756c90210a093e0c893b2b7510374465c35014cdcef24d6044e63ece20db245a982c5a00f16227c6530ade1e5fd4e925f8505da8add8b97fa3216790b11b46ad76bec9b6b3f0c1e9b8284f81f0c83bb690033a6c4a8e962d34b51e27aeaa33a0d41e852a78af96c6dfba1d36364996df4eaaacc993e5633528fb0d106753112ecbfe7fec74ea75d2d1b16267ee1b3d6145e746f19d0840be93d38baa3a5c7d42c6a05311092b4fe05e38aa9e24fb9c99a5aa4bc485752d049703409ccb6656150eeb34048030f19b5353df1e23b376b1c4be2925034d8b7a847690ca7f0c4f338c30b461d9e8c1a1de094fea67efdc4d3da6e96a43cd51241f828f01028e5b0e9",
+ "21c1bcbdd38fa9e8d86272f8c66c2b16768d00a79945c6014bb55374174299e4e7967f80f8df630af6f5aac52647e36437d08d0aabfd406b0984557a1a69b28f3706033ed34cde0d6bfd39a81a0fca04117cd7ac37f06815b6dc1a356454d669b0e7871238066b280f7c09e0365214235dcd39dfdf6e46ad632d50101029ff490910f79b754d99088059cf1d5da8a791b1a51d8d96fe428bac8f547dc2e9632c4e00b3a6e95d36d03d83969de2989c120aa3c54040bc7f49f910df6c5e4003721a03a3f1d743c457d0cbadbc0bd25ca7f464f404a1f39f852872314ca44ea35a32c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e",
+ "89f25b6a8ecee226a56d10e9cff57402c3198e7629d8c4152d9319172f0df07727e57ad2e170a0a32cf748f9dc0fe45597d26743a7254aebceb106f0f172a5f2aecf563787471334096bb202a07696b7e0cf22d7dda51e1cf0fcc3d0d5d0a9f4c7ea8999a3f4edb4bd7d3ecb26f2137fe9d443f34248265c25a12290866130c9ccd2dd263bce8b78c14ddbc0c94ed85da0c401ac614c789711c0e2db27ad6ef59aabfdbb6bf174fff2358c4a8d6b7e6cd2a0d41883f3563506755a920aaa7695cf35ffecb0b195d01d1d86ee1a015df5de2f145b61ef3bc5fff1634d911a69540c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a"
+ ]
+ ],
+ "public_share": "da6e96a43cd51241f828f01028e5b0e96639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f"
+ },
+ {
+ "input_shares": [
+ "db3fe5357f56f6cfe9a0a34ed08e231766b423d5670741c151f7bc28cb15c4c8688df838b6ef52d74c5bb5e8c8c607b2684c024da705b558163e3127ae09f4a5bf61dc6129d71755ec77b5d5a2fc088541fa612bf1273d94718e0d28b654ea9524148bf910a5c6d55b596e323d4f55ab63b0851e0985987ea39cf500bfd4d5faed37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0e5d8a2e157f9c1ef6eedcb845a19fffd4f9c6e705b035c95f3ebe51219a91c6ffff33c52b8fd7e90417636bb443c98c3d6933b1599c09151095fff99d7af43d337fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa249333c6d1e2bec2132a3e3cf5fd11f7e923fbf275071a71535740ac64de88b83ca002bd4490de5c7540e454331d46ebb925725b964ecfcc5c3bb076bf57fd819d0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b0fbc219f5524eec22dc91aa56ed87a71e2b6e6b857ee03a700fb5af5b23513fb63023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e7d7fdf14abf76c605d42e523a540f40e816db420d599cea502baa131d6b61b8add86a09b9f7587e107c6fc9043d72644053b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b12007c81a9b3c72a84522014f4cdcb949bb9850b62b1c9c5d35f0fc3dbf4f637af85d3082ffd15890041a455a5dd38efc902eb415afcd2c93a2b6ab46c88847f390affbc3a9af4745718044c9f23978027af66871e610e1beee21360fc5ae4f05ed32689328216db3c512fddacefb2713287f4f5069170db246ae0cb13b3b88f8d832d11ae6b26fe6f0f89b9a9056c28ad6160134ab1af4be606a2ca7d1256cfc223c823ae4921ba11aab9e4f8104885b3d17ac7826ad06b0fd0e54397e67179755b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf7beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa924d95673e4d904cc4215bfda1c9c0ca55fa38af6b696ef4f1a452d6870c1d82fa311cee716b00a11313a0ba7c4038e5e195b5e1d3502a44d9a9ab1636394821eaf157e6c77316f6e26b98ca81220752396890cce205581aead60bfcc3916a6c2d6d360c7e845e70c8c68e4acb5eb2877a9a7c7619eb79e2184afff7bb585319fa669b151040f5b15cab2d8c3a50379e9d9e8bf1ac6319bc32e489f64793f882d0e3e1906ac04d47eaec15c20c503d007d09d6a9b032d0f9a8b3d95447fcb1d4b7334b95d873a171f876dcc2a62a62af58e10802f88742027d3d27b9d72fea73dc84906d3dde03299dc3e033651406229da606162636465666768696a6b6c6d6e6f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+ "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f"
+ ],
+ "measurement": [
+ 19342,
+ 19615,
+ 3061
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "fc936b02cb848e7e3c5840b05acfd699",
+ "7c71902ed398efc35ebb269aa58e0d5a",
+ "b90f902ade7fad281b8cfc5f1349ddf9"
+ ],
+ [
+ "f9758e4a5c3d6d2a1b4b3c7e5d397d83",
+ "9bb1de90bc4c2244e630e3b6780dc86a",
+ "15735600bee57d499ed0890554ef5183"
+ ],
+ [
+ "9b4106b3d83d0457705c83d147f7abe2",
+ "89299140701aeef79e13f6aee1632a3b",
+ "298919d5639ad48d0ea3799a98c7d082"
+ ]
+ ],
+ "prep_messages": [
+ "e129a25d4bb45ae8f5ff9d97d72f6d86"
+ ],
+ "prep_shares": [
+ [
+ "574ce7d79da173f0652f7d1e699d5fe7865e23ddbe20c3646ca1d41f8d0d123a831f3b90d5b917779d25ca7aafd5c903b4b996399d29d3af9fcb1c6549c36a639e45759c53b3bf6a1014fbdf7ea4d5d3726af036968d80aed52a7e48dc0a06ee66a6e34bf38ce1083cb56057cfbe98ae9138b635f2b64b60daf151ce214d33b99062b201463033142fc734e4c163fd8a5d2ea4503567c0a9739027883ab0b37ece24259b5ddcc88b3e272bbad93857a52a48abf5eb9fb135ccaf27fc0379f49f203cc0f5a4f14183ef3b935076893f0bc56dcd9b0ffbf8f711039bfd3fe894c0c007d8702db35ded99e2605fa3d9d4b70f6dfcdea9883c839e2a91aa358a38359e2d3cb20272a3ecb52e1edc17a48940",
+ "21c1bcbdd38fa9e8d86272f8c66c2b167ab67945b007b46536ece1073c4ac681e7967f80f8df630af6f5aac52647e36454b87d9c2d6aa2d4c2186b47a6bef8153706033ed34cde0d6bfd39a81a0fca04618cc495486eccfbcba29099fb101871b0e7871238066b280f7c09e0365214230a741f21625c730fae5e37a7a3cbe0c00910f79b754d99088059cf1d5da8a7911221c38f8a6f79877dd34cb4fd2990744e00b3a6e95d36d03d83969de2989c12e44eb18d466e8954631aa1dbe84e78111a03a3f1d743c457d0cbadbc0bd25ca7de998de032e28d1a8ca88d584ad7a64032c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e",
+ "89f25b6a8ecee226a56d10e9cff57402a204a4e83b6b78a5aa32317f68443eac27e57ad2e170a0a32cf748f9dc0fe45524ab48c23361b4b327d498fd702f23b6aecf563787471334096bb202a07696b7eb8d79193777673f93547e14b15adb2ac7ea8999a3f4edb4bd7d3ecb26f2137f5e2881b52fd4d3ec538b978ebba54312ccd2dd263bce8b78c14ddbc0c94ed85d38fdf5fb7ee2eeb21ef066fdde8a94f89aabfdbb6bf174fff2358c4a8d6b7e6cdb080672d15b2ae06d6ca64387b56c00cf35ffecb0b195d01d1d86ee1a015df53e771738ec1faca3fcdfe4f033977dd50c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a"
+ ]
+ ],
+ "public_share": "9e2d3cb20272a3ecb52e1edc17a489406639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f"
+ },
+ {
+ "input_shares": [
+ "db3fe5357f56f6cfe9a0a34ed08e231766b423d5670741c151f7bc28cb15c4c8678df838b6ef52d74c5bb5e8c8c607b2674c024da705b558163e3127ae09f4a5c061dc6129d71755ec77b5d5a2fc088542fa612bf1273d94718e0d28b654ea9525148bf910a5c6d55b596e323d4f55ab62b0851e0985987ea39cf500bfd4d5faec37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0f5d8a2e157f9c1ef6eedcb845a19fffd4f9c6e705b035c95f3ebe51219a91c6000034c52b8fd7e90417636bb443c98c3e6933b1599c09151095fff99d7af43d327fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa249333c6d1e2bec2132a3e3cf5fd11f7e923fbf275071a71535740ac64de88b83ca002bd4490de5c7540e454331d46ebb925725b964ecfcc5c3bb076bf57fd819d0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b1fbc219f5524eec22dc91aa56ed87a71d2b6e6b857ee03a700fb5af5b23513fb63023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e6d7fdf14abf76c605d42e523a540f40e716db420d599cea502baa131d6b61b8add86a09b9f7587e107c6fc9043d72644153b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b11007c81a9b3c72a84522014f4cdcb949cb9850b62b1c9c5d35f0fc3dbf4f637af85d3082ffd15890041a455a5dd38efc902eb415afcd2c93a2b6ab46c88847f380affbc3a9af4745718044c9f23978027af66871e610e1beee21360fc5ae4f05ed32689328216db3c512fddacefb2713187f4f5069170db246ae0cb13b3b88f8d832d11ae6b26fe6f0f89b9a9056c28ac6160134ab1af4be606a2ca7d1256cfc323c823ae4921ba11aab9e4f8104885b3d17ac7826ad06b0fd0e54397e67179765b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf8beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa924c67468654e6880c69828616e9d011132b27aa457c6dcbd554ec4374ec882357626d6f323bfadc3b23ac6d390e40f8f2008e462458194a1f9d63e0b977f4be907f34aecb0a12ad13b95bfc858c412abab064106faf149e9b25cb557c12a54e6ae957097b8b11f174094d134cc213bf98da7c7619eb79e2184afff7bb585319fa67b935c839af760464b6f3d5402847d07d9cf6c2502ae55f33e08959b38de273b2911fa9ef530cc2cc1a1f3f8224ed7c8efe455f3ad1e462c1d089d4be054801a56ecdd4dca5bbc7191990a1c028d6d79934bf7aed757eccdd68512ebe9f919f087f6020e75fa8e281316ae3a0a50a7f5606162636465666768696a6b6c6d6e6f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+ "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f"
+ ],
+ "measurement": [
+ 15986,
+ 24671,
+ 23910
+ ],
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "e0866b02cb848e7e3c5840b05acfd699",
+ "3c85902ed398efc35ebb269aa58e0d5a",
+ "2a61902ade7fad281b8cfc5f1349ddf9"
+ ],
+ [
+ "f9758e4a5c3d6d2a1b4b3c7e5d397d83",
+ "9bb1de90bc4c2244e630e3b6780dc86a",
+ "15735600bee57d499ed0890554ef5183"
+ ],
+ [
+ "9b4106b3d83d0457705c83d147f7abe2",
+ "89299140701aeef79e13f6aee1632a3b",
+ "298919d5639ad48d0ea3799a98c7d082"
+ ]
+ ],
+ "prep_messages": [
+ "62b225fa36c2cbc5896a40b1360f0ce0"
+ ],
+ "prep_shares": [
+ [
+ "574ce7d79da173f0652f7d1e699d5fe74186663960507c6ac01fafa7046bc88a5c4146feb2861bcb9f3ad8bb294028812c631ce293cde53b6fc09c6bdbcb839ede24393da7141c26a1a594ccdab1fba7ed2b2e38faf952343a2871b854699ff9533b4f88f9bc7e935cc62e3dea6efff7e2d396d8276281db4e789dc30d3fe01c397b64d685e87e7dccae1b7636b1a1486293f29e5c3433ee211e1f66ec5a7a0f1d79ed261735f53fdb51cbd51c1a6b057bf2845e34bfa091f0f8fd7565242c8b247064dfac759c1c646d085bca6ad0845b9cc2093fd49d493ebf2d48a90b99815f0bd165f56453c6b27dff04d3a161d9fc0bc4177e423791d4064faccfbbd3218c6815a24e38fb752efc839e59043cbb",
+ "21c1bcbdd38fa9e8d86272f8c66c2b16ce028e4f460022aed5c3e47011081a9ce7967f80f8df630af6f5aac52647e36481e9f91e23071ea7d98af1e76a9dd5823706033ed34cde0d6bfd39a81a0fca048271f4696baf9fd2e8a37e7b5e1be75eb0e7871238066b280f7c09e036521423abafe75c716165e86a573bab2e46efda0910f79b754d99088059cf1d5da8a79146cedb83d4b00d96b9871a66d73147984e00b3a6e95d36d03d83969de2989c12e222e1015cad12a219308643f8e1c1511a03a3f1d743c457d0cbadbc0bd25ca7cb9908430d0604df4468ab54b36f81ef32c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e",
+ "89f25b6a8ecee226a56d10e9cff5740294125c8939f32731654a11d38906f94a27e57ad2e170a0a32cf748f9dc0fe45503b71c971d945163a09158bc06295e77aecf563787471334096bb202a07696b7f7c98da9fca5e00ff5ce6eb4b9d52274c7ea8999a3f4edb4bd7d3ecb26f2137fbd6dc71ed450947219eef408d88af7dbccd2dd263bce8b78c14ddbc0c94ed85db3255ece317e82a0f3b4a873e9b72b949aabfdbb6bf174fff2358c4a8d6b7e6c545473085c7c9b5c2aa9d290b01670ffcf35ffecb0b195d01d1d86ee1a015df56465c0991dd00dfda6948958c3ff38510c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a"
+ ]
+ ],
+ "public_share": "8c6815a24e38fb752efc839e59043cbb6639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f"
+ }
+ ],
+ "shares": 3,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json
new file mode 100644
index 0000000000..4dd3798668
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json
@@ -0,0 +1,40 @@
+{
+ "agg_param": null,
+ "agg_result": 100,
+ "agg_shares": [
+ "0467d4fd6a7ec85386c9f3ef0790dc10",
+ "61992b02958137ac5d360c10f86f23ef"
+ ],
+ "bits": 8,
+ "prep": [
+ {
+ "input_shares": [
+ "1ac8a54e6804d575f85957d45c728487722ad069fc1ed413da970ea56ae482d81057f898a367319a14f402d072c24bb71aa13cf4f9cdcd731e779aaa4a5d561ff40e41797cf2d43308849ff9080f846c2e78e2736a1d3cdaa68f3ec890d633cc13f5bf96466d3f02f93612bc827ff53357d52ae00dd234b9758f2cbb7aa9682d06041da1507aa12446de554a945924f445d3715279ef00f4c55ae987cec4bb9f1316efdc8737b7f924d046a6b8ef222b0dc1205ce9b464076fa80d2dfe37af4d836d597ade7d51e18e9c95d13158942d249efd0a1a226759e4bc1d46d3a41bdb227703fe0a7554cf4769935bc99cd1f35b274ecec240816af4915c7abe3e16b7be5ab5e105f9ae7b2e683191c9400cf99ab0c687e4929f87e6e64f712ca02f07a1b29fcebdbfde7655797f9c1b6b3114420d8a19736ae614116782278b7a71f9ef6928ad44ce588644886523d6fbe0b7bbb47248edbaa0b5ce33f74a07005e2a6842eb2c05778e170112f6e6a5f206d7830aa122e29069dcb4a4c064e63c29b3c6e2b22dfb5ab344ca0f1be8e8ce36d26435413de2dc4f53e158ebb8478b4a98de014a688db9470106fd7e73a65c2e656b5a627b5584ca0594ba10cc39c5612bcef576625c37c5249ad5c04e42c66d6a9653c4ec47e2bcd860870bef64f812974654f17f77c08eaa395803d33bdf31db17d76dbb9d2407d7c4f9efbce274542ff6aa0dcf188803eb586108317db430ad517ce7cb0f56d225c835161eb348949ebe253bedc338c6b939ce837561f01d7f0304963eab2a28b38c36bb169a4ee0637635818bd5e4798a8319152a2678b0aa7b837cb0f24df6148ae2c84b78db8892f4415f90f3804e7a29cdcd32a0a8625fd20aca47ee0ef12ebd6138b3534a1b42303132333435363738393a3b3c3d3e3f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f"
+ ],
+ "measurement": 100,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "0467d4fd6a7ec85386c9f3ef0790dc10"
+ ],
+ [
+ "61992b02958137ac5d360c10f86f23ef"
+ ]
+ ],
+ "prep_messages": [
+ "0fd2bb14ac123437f6520fdc4a817934"
+ ],
+ "prep_shares": [
+ [
+ "61428b8d7e326827ac832bc4074ad61652efcfdb8d95b6f06b83dd9f5d55ce9f142d1a1fd437eb8c84581ad15dcd9a57417942e63a1a46e6b0ffc8b6d6300f7d",
+ "a0bd747281cd97d8377cd43bf8b529e9eb5e4b1153111bd6cd06aa3a5493a6da4470f696b9afff52ec10fc00040e4538470fdb8d3e05e188aba2b16e24c71b69"
+ ]
+ ],
+ "public_share": "417942e63a1a46e6b0ffc8b6d6300f7d470fdb8d3e05e188aba2b16e24c71b69",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"
+ }
+ ],
+ "shares": 2,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json
new file mode 100644
index 0000000000..8e7163ae2a
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json
@@ -0,0 +1,46 @@
+{
+ "agg_param": null,
+ "agg_result": 100,
+ "agg_shares": [
+ "b3916e4086c52aa0439356b05082885a",
+ "61992b02958137ac5d360c10f86f23ef",
+ "52d565bde4b89db326369d3fb70d54b6"
+ ],
+ "bits": 8,
+ "prep": [
+ {
+ "input_shares": [
+ "ec863a16981eee70b71dc7d396a85ba412dcb4e86769d6db0c60f668456f5d6fbb844d9503580fb7b662bbb2ed7d92002e6ab4a2d31a6f611cbb5c48ca6df69811d536f74a3ff61eb29bd9b9f1b64c35eddd5c4ac97376057883a317f2989b545a682775f948f28f80f366f36b4eb90f931bca79e229eae377102295d9c46da2e239f74f045084747039c0a955726b4258bc0d14da7474bea6cd136eb5e55e9531e6a68703003a64943a5650b16674c82d9c4b526a7ed3d69f8f13ae83609cf056f3fed8d6593fdad7b367d2d248413072651073ea91b8162d42af168698f0f0928c8238b2df218e26d004d2bdb5f9f20d0a43c0286d08cfc26971f282992f82ff14d51cee3e0f3fc7411869c2176cabc6b1a68e33ff5eb217490de9f0d85cb84e9115bb7e208a190d25bf9cc138485892802a50b790ba6f45804de487a3353e54b5471adb5ab612d9ee6416649e136456215503637e0daab367149bc5cdf02a2dabc2790f84cadec1510263fe6aa27df5df395b7a241777a8ed28da27276b48f599dd895a005746cfd1f3c874e6f52407f4c417934d7091685c0b38b1d76b398ad263ec73f4f811aed38febf67a19a001a2c7ab8071f986939713cccd146c7a049c5129783359fcf86410765028fbfbbe62c2474a6b75de0ba49c037e07946deae971207f4f74b8b1d6a7b225eb0b66ed1f3878bc14d9d7a38b2162247b7ed9ac3df6fd2a98a3e4bf2855c8fb13f39487481fe03f5b5cb5123d11aaef180ff8ae69709322459a01a72e9304295ae5721d6eac6dae140677d0dd60f192f0475bacfd131d4ff3393238caa00fe0847c3a43c97a31f84f58b3c7487c5c0a09e85b39ed4b69fcdfa071da15216fd5f1fad125328e40689acce1a6cb113c2a16f599606162636465666768696a6b6c6d6e6f",
+ "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+ "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f"
+ ],
+ "measurement": 100,
+ "nonce": "000102030405060708090a0b0c0d0e0f",
+ "out_shares": [
+ [
+ "b3916e4086c52aa0439356b05082885a"
+ ],
+ [
+ "61992b02958137ac5d360c10f86f23ef"
+ ],
+ [
+ "52d565bde4b89db326369d3fb70d54b6"
+ ]
+ ],
+ "prep_messages": [
+ "e385da3bc2246be76ff12a7093ecb45e"
+ ],
+ "prep_shares": [
+ [
+ "7b6a9ad01449ec86dc6736dced3ecd24b47ab2a3768908b10696d537f2b02c98cf3314686f94ac37c7d81b14fea51f784e037bbdd56b2ee8486757acad61db1e",
+ "40a478cd7376c1e9ea339ddcf96ab1a7eb5e4b1153111bd6cd06aa3a5493a6da4470f696b9afff52ec10fc00040e4538470fdb8d3e05e188aba2b16e24c71b69",
+ "46f1ec617740528f1c642c47185681331adc83aace0cc5ddd256cf295c93c64d207d424f5056a8d59748ba9e423c4cf5f560fb4c6505c9a773629e12f21ee230"
+ ]
+ ],
+ "public_share": "4e037bbdd56b2ee8486757acad61db1e470fdb8d3e05e188aba2b16e24c71b69f560fb4c6505c9a773629e12f21ee230",
+ "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f"
+ }
+ ],
+ "shares": 3,
+ "verify_key": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json b/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json
new file mode 100644
index 0000000000..ea76c50ff8
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json
@@ -0,0 +1,8 @@
+{
+ "binder": "62696e64657220737472696e67",
+ "derived_seed": "9cb53deb2feda2f9a7f34fde29a833f4",
+ "dst": "646f6d61696e2073657061726174696f6e20746167",
+ "expanded_vec_field128": "9cb53deb2feda2f9a7f34fde29a833f44ade288b2f55f2cd257e5f40595b5069543b40b740dfcf8ab5c863924f4510716b625f633a2f7e55a50b24a5fec9155dec199170f9ebe46768e9d120f7e8f62840441ef53dd5d2ba2d3fd39032e2da99498f4abf815b09c667cef08f0882fa945ba3335d2c7407de1b1650a5f4fe52884caf3ef1f5802eabb8f238c4d9d419bed904dcf79b22da49f68fb547c287a9cd4a38d58017eb2031a6bf1b4defb8905c3b777e9287f62a7fb0f97e4d8a26c4e5b909958bc73a6f7512b5b845488d98a7fcedf711ada6972f4c06818d3c3e7a070a88af60dc0323b59f304935fbbbd3792e590c9b6bce7459deba3599c7f30fe64a638219dde4bde4b1a51df8d85c2f36604c44f5f188148e3ba1dca3fd8073240ee577ef322df19a13d9ffa486a6833f4eb2838a58746707b8bf531cc86098f43809276b5f02914b26cb75938ca16eafa73397920a2f5e607af30e62ff60b83e15699d4d0265affe185b307ed330941a41b2b628e44d9a19412f7d9513cacd7b1fd740b7708e3bc764a0cf2146bca7c94d1901c43f509d7dcc9dfec54476789284e53f3760610a0ac5fce205e9b9aa0355c29702a5c9395bf1de8c974c800e1037a6bf5e0bd2af7d96b7f000ff6ab93299966b6832c493b600f2595a3db99353d2f8889019cd3ec5a73fa457f5442ed5edf349e78c9cf0cbf4f65aea03754c381c3efc206b7f17447cc51ac68eceacab9d92b13b0bc700c99a26ce2b6c3271f7639aa72dc27bbd54984907abb10ef1047ef352d378ddae48bf381804c89aa1847f5027537cf6af1b30aa44cead6495e98ca5b3205d39beb49d2db6752a4e57158e8c83464002b0b4a9838bc381c1dbdc3e9a584554fb76671a15f907c0b395a5",
+ "length": 40,
+ "seed": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json b/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json
new file mode 100644
index 0000000000..edafb1bd4d
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json
@@ -0,0 +1,8 @@
+{
+ "binder": "62696e64657220737472696e67",
+ "derived_seed": "87c4d0dd654bf8eec8805c68b5eb0182",
+ "dst": "646f6d61696e2073657061726174696f6e20746167",
+ "expanded_vec_field128": "87c4d0dd654bf8eec8805c68b5eb0182b1b8ede598cfb8d8b234038fd0492eb14268bbb2ac15a55d463c8227c2d4fae8607631d13157c4935c5d2d56e4b1e2bdfe0f80b286d82e631704acee29ab6f7acaa316d3623cc3371297604caf57bc2eafe72056143971345901b9fb9f95b6a7384c6a88143124ff693ce9e453675f87a6b6338a1e1c9f72d19e32b51f60a1d7469de1fbe25407cc338d896b34c5fc437d2551297027eeefca9aaccdb78d655a6c220cbc2d76cc4a64b04806ae893c952996abb91f6ec32b6de27fe51be59514352d31af4967c0a85c5823ff73be7f15b9c0769321e4b69cb931a4e88f9da1fde1c5df9d84a7eadb41cf25681fc64a84a1c4accded794c1e6fec1fb26a286712425bfc29521273dcfc76cbab9b3c3c2b840ab6a4f9fd73ea434fc1c22a91943ed38fef0136f0f18f680c191978ab77c750d577c3526a327564da05cfc7bb9ef52c140d9e63b1f39761648772eaa61e2efb15890aed8340e6854b428f16dff5654c8a0852d46e817b49bbe91db3c46620adbd009a0d7d40843c1b6b7786833d3c1ae097b4fa35815dbcfca78e00a34f15936ed6d0f5bf50fc25adbecd3adfa55ba6bc7052f0662595cf7a933dfcc3d0ad5d825ec3bc191586a1c36a037d1c9e73c24777825d6afe59774abdb2918c2147a0436b17bafd967e07c46c3d6240c771f4fd4f9b3fff38b294508b8af5a1b71385f90f407620b7aa636fd2b55435b3688fc26ad3c23b2ad48158c4c475c07eb58569a8d1a906452b82d582397c4c69f5e79d3082d03b4dd85b5277a8b44c933d52d168caae8c602376f5487670a172d138364cb975c569c9c2d79506746090ea8102907c91b66764fd8740ca7bd3acb59173df29b5fa7542e51bce67b97c9ee2",
+ "length": 40,
+ "seed": "000102030405060708090a0b0c0d0e0f"
+}
diff --git a/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json b/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json
new file mode 100644
index 0000000000..674e682ac1
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json
@@ -0,0 +1,28 @@
+{
+ "dimension": 10,
+ "server_1_decrypted_shares": [
+ "x+y6F2RY3Y+toaLjU4a0WDqmNPYgHW1w9Z3svKirr+qcM9eDxhWSUPY4/N3A3PYGVFKa+i867MSiouE7Fq3iykBPKPMuNS4T8e1FA2uJ5PJHzPEobZNQWKG6ax5WYEbpmongi1kgK656OkATqMcXHnmkBC8=",
+ "u7Z0TCmpNSSrhCf27lo8qfmg7bUAAvcAbnntWm2lcoIa7yE8h1Mi4H7YDLv+t/9pzCHhyXC248VnyLWasFEfZ7wwzI2D+3U0XBjWZIXXufRvsfo8ZlUfbRtwYXYYwjxB5FV0jksJvZxaOFuNGP0IxIR1J1M=",
+ "iUazYtUedl8JNc8oSU+cLagEfJIBvQC7yUUjF5BhKyvDk54SWsCNbCTM8GigTWj5fobgebFqDlMbWaPceeO1F5S+8AwSpH8zSo20sckmYpeia8daPRDQu27e+ijLtxGkig82No/SniO+PPG96/xJ3e7wO74=",
+ "WUkmvNOpVZVb08GsZZuN7iqg6eUusF3wZY6U0NA8cAvzDiysSByNROP+cKCuy5l4r4JpkvP2n03WouKfyW4CWqQKC9vB+hCxVyWmNilyOTbACM0mxYU9lqvi43C3XEj4gflaOc7eUi8NTdAzvOThI/pmn3w=",
+ "Tp5mTet94KPUfzgs/XC8eal0iayw513glL8qUBAjAubTY97A/oO3Mevjm2Gn0uLFq4Add0A+O0Vo3l/Ar+etRUV9h4cEkKk5W7AEpY7JND224plOZ+3gsNiSUFm+hm6uOWXkanNNbPaSC3nyA81PF9fffAY=",
+ "2awOe4Bc1b7xsPvQxA9oTfadU3ueOw1EVbQ4poETuxmjPsvHbZhJXnYURfAvqfKCfz6eOUPhml7qcwWpht913JN9IQ5kpCM0So31xB99FkAuglwzZY7s4SLmWSrxjnvp9sjg8IDgiJAyCNQlIZtgjqq1TIE=",
+ "hLF/NXH1yZu4sf6m0jNfsZZS/cxVnoS5c1BtRSxkpeBuiXv+qnO3kko7A43CZ0b8C4kp1/Rfc5bnxn4+liYtsmu0ZGoW2AJdIfQTmkFfQlV/bi6Lcto121CDFY89/jrernWu3urzN1jeeNc1RmkpbmdtOrw=",
+ "xiufG50lmIHTVsvr6zpMXqTxhfVUlfe15eXfnwMkOEFfjMz/njV0HmqkvKtw9P/Z6gbN1Z0B0XnARknYz5OVNQRIaiu/AHTKWd19pnYu25VtFlHibbKhz32zHbfoHAoZeUXYr7x3j1vbJVNnkbwnz5WkNnc=",
+ "oEv/wE1+O6dVD/LKV0coeGq46zb8oDMFmL6GEEuZnC6REQdiJhHBa9fyWm1O+NtjUGu08r20R4r7f20dOqmYCOIMaNvejMTcXui9WXjrl8YQzQ86hkHSQyaXR0nkWOlCZsEF1kDX35TIGOWlCpjCjCwifPI=",
+ "OecAaoAKkmBgkmcVoCiq7NQkfXY/FS9M3zKZtE1pfbIwavWcqr9ucgFSs6zsas5aKVF/yJrNQwlbVE41YlClDuND6jD4NGuYKubzfT6I+saUAacdiDGp1fRA4RWcJFSxULG23XBP2b7D3l5wzeauGWDylkw="
+ ],
+ "server_2_decrypted_shares": [
+ "Kge/qYK/UnjrS8Q55v2K5Neg6iHORxEONcsaoLtX7wU=",
+ "rMt+kStRyxF2V7AWLLnWofVQPf6n7wrkhsTaMcqOZdo=",
+ "iMFcL/WOs264bO1dDOlyGjicF8/u6M5FdrPOFkkHuAE=",
+ "B1Dp6k0WFB252R4cAl2vtG7XGmvXmbDYM+iBn+U2QFU=",
+ "MO0aL0ZRKCfKkgaTmlsfn8Y/K6gRZoM9lOmOwBDxgPA=",
+ "Flz9wOjYAGLPjo+c89ve8fjaAhm+w1LqngA33ro+Q/M=",
+ "dfD3DWQfzkjubvNucxMRSbaaAz1O+4VLg2najHfPaxA=",
+ "ZgUbOVCMvhejv4LzOu3oSi+PgQBHUghfwxp4vJ5Ig9A=",
+ "Cqig9M1LE6iEl1PLB9Qc8cRfiF81TM4EqMCXVPT0SwE=",
+ "ZUpeNfk32vpkPoSh5ZNauB1pw5QeDJ5CMmvsDV0F6eA="
+ ],
+ "reference_sum": "BgAAAAcAAAAEAAAABQAAAAkAAAAGAAAAAgAAAAYAAAAEAAAABQAAAA=="
+} \ No newline at end of file
diff --git a/third_party/rust/prio/src/vdaf/xof.rs b/third_party/rust/prio/src/vdaf/xof.rs
new file mode 100644
index 0000000000..b38d176467
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/xof.rs
@@ -0,0 +1,574 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementations of XOFs specified in [[draft-irtf-cfrg-vdaf-07]].
+//!
+//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+
+use crate::{
+ field::FieldElement,
+ prng::Prng,
+ vdaf::{CodecError, Decode, Encode},
+};
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+use aes::{
+ cipher::{generic_array::GenericArray, BlockEncrypt, KeyInit},
+ Block,
+};
+#[cfg(feature = "crypto-dependencies")]
+use aes::{
+ cipher::{KeyIvInit, StreamCipher},
+ Aes128,
+};
+#[cfg(feature = "crypto-dependencies")]
+use ctr::Ctr64BE;
+use rand_core::{
+ impls::{next_u32_via_fill, next_u64_via_fill},
+ RngCore, SeedableRng,
+};
+use sha3::{
+ digest::{ExtendableOutput, Update, XofReader},
+ Shake128, Shake128Core, Shake128Reader,
+};
+#[cfg(feature = "crypto-dependencies")]
+use std::fmt::Formatter;
+use std::{
+ fmt::Debug,
+ io::{Cursor, Read},
+};
+use subtle::{Choice, ConstantTimeEq};
+
+/// Input of [`Xof`].
+#[derive(Clone, Debug)]
+pub struct Seed<const SEED_SIZE: usize>(pub(crate) [u8; SEED_SIZE]);
+
+impl<const SEED_SIZE: usize> Seed<SEED_SIZE> {
+ /// Generate a uniform random seed.
+ pub fn generate() -> Result<Self, getrandom::Error> {
+ let mut seed = [0; SEED_SIZE];
+ getrandom::getrandom(&mut seed)?;
+ Ok(Self::from_bytes(seed))
+ }
+
+ /// Construct seed from a byte slice.
+ pub(crate) fn from_bytes(seed: [u8; SEED_SIZE]) -> Self {
+ Self(seed)
+ }
+}
+
+impl<const SEED_SIZE: usize> AsRef<[u8; SEED_SIZE]> for Seed<SEED_SIZE> {
+ fn as_ref(&self) -> &[u8; SEED_SIZE] {
+ &self.0
+ }
+}
+
+impl<const SEED_SIZE: usize> PartialEq for Seed<SEED_SIZE> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ct_eq(other).into()
+ }
+}
+
+impl<const SEED_SIZE: usize> Eq for Seed<SEED_SIZE> {}
+
+impl<const SEED_SIZE: usize> ConstantTimeEq for Seed<SEED_SIZE> {
+ fn ct_eq(&self, other: &Self) -> Choice {
+ self.0.ct_eq(&other.0)
+ }
+}
+
+impl<const SEED_SIZE: usize> Encode for Seed<SEED_SIZE> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&self.0[..]);
+ }
+
+ fn encoded_len(&self) -> Option<usize> {
+ Some(SEED_SIZE)
+ }
+}
+
+impl<const SEED_SIZE: usize> Decode for Seed<SEED_SIZE> {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut seed = [0; SEED_SIZE];
+ bytes.read_exact(&mut seed)?;
+ Ok(Seed(seed))
+ }
+}
+
+/// Trait for deriving a vector of field elements.
+pub trait IntoFieldVec: RngCore + Sized {
+ /// Generate a finite field vector from the seed stream.
+ fn into_field_vec<F: FieldElement>(self, length: usize) -> Vec<F>;
+}
+
+impl<S: RngCore> IntoFieldVec for S {
+ fn into_field_vec<F: FieldElement>(self, length: usize) -> Vec<F> {
+ Prng::from_seed_stream(self).take(length).collect()
+ }
+}
+
+/// An extendable output function (XOF) with the interface specified in [[draft-irtf-cfrg-vdaf-07]].
+///
+/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+pub trait Xof<const SEED_SIZE: usize>: Clone + Debug {
+ /// The type of stream produced by this XOF.
+ type SeedStream: RngCore + Sized;
+
+ /// Construct an instance of [`Xof`] with the given seed.
+ fn init(seed_bytes: &[u8; SEED_SIZE], dst: &[u8]) -> Self;
+
+ /// Update the XOF state by passing in the next fragment of the info string. The final info
+ /// string is assembled from the concatenation of sequence of fragments passed to this method.
+ fn update(&mut self, data: &[u8]);
+
+ /// Finalize the XOF state, producing a seed stream.
+ fn into_seed_stream(self) -> Self::SeedStream;
+
+ /// Finalize the XOF state, producing a seed.
+ fn into_seed(self) -> Seed<SEED_SIZE> {
+ let mut new_seed = [0; SEED_SIZE];
+ let mut seed_stream = self.into_seed_stream();
+ seed_stream.fill_bytes(&mut new_seed);
+ Seed(new_seed)
+ }
+
+ /// Construct a seed stream from the given seed and info string.
+ fn seed_stream(seed: &Seed<SEED_SIZE>, dst: &[u8], binder: &[u8]) -> Self::SeedStream {
+ let mut xof = Self::init(seed.as_ref(), dst);
+ xof.update(binder);
+ xof.into_seed_stream()
+ }
+}
+
+/// The key stream produced by AES128 in CTR-mode.
+#[cfg(feature = "crypto-dependencies")]
+#[cfg_attr(docsrs, doc(cfg(feature = "crypto-dependencies")))]
+pub struct SeedStreamAes128(Ctr64BE<Aes128>);
+
+#[cfg(feature = "crypto-dependencies")]
+impl SeedStreamAes128 {
+ pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self {
+ SeedStreamAes128(<Ctr64BE<Aes128> as KeyIvInit>::new(key.into(), iv.into()))
+ }
+
+ fn fill(&mut self, buf: &mut [u8]) {
+ buf.fill(0);
+ self.0.apply_keystream(buf);
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl RngCore for SeedStreamAes128 {
+ fn fill_bytes(&mut self, dest: &mut [u8]) {
+ self.fill(dest);
+ }
+
+ fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
+ self.fill(dest);
+ Ok(())
+ }
+
+ fn next_u32(&mut self) -> u32 {
+ next_u32_via_fill(self)
+ }
+
+ fn next_u64(&mut self) -> u64 {
+ next_u64_via_fill(self)
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl Debug for SeedStreamAes128 {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that
+ // with [`cipher::StreamCipherCoreWrapper::get_core`][2].
+ //
+ // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html
+ // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html
+ self.0.get_core().fmt(f)
+ }
+}
+
+/// The XOF based on SHA-3 as specified in [[draft-irtf-cfrg-vdaf-07]].
+///
+/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+#[derive(Clone, Debug)]
+pub struct XofShake128(Shake128);
+
+impl Xof<16> for XofShake128 {
+ type SeedStream = SeedStreamSha3;
+
+ fn init(seed_bytes: &[u8; 16], dst: &[u8]) -> Self {
+ let mut xof = Self(Shake128::from_core(Shake128Core::default()));
+ Update::update(
+ &mut xof.0,
+ &[dst.len().try_into().expect("dst must be at most 255 bytes")],
+ );
+ Update::update(&mut xof.0, dst);
+ Update::update(&mut xof.0, seed_bytes);
+ xof
+ }
+
+ fn update(&mut self, data: &[u8]) {
+ Update::update(&mut self.0, data);
+ }
+
+ fn into_seed_stream(self) -> SeedStreamSha3 {
+ SeedStreamSha3::new(self.0.finalize_xof())
+ }
+}
+
+/// The seed stream produced by SHAKE128.
+pub struct SeedStreamSha3(Shake128Reader);
+
+impl SeedStreamSha3 {
+ pub(crate) fn new(reader: Shake128Reader) -> Self {
+ Self(reader)
+ }
+}
+
+impl RngCore for SeedStreamSha3 {
+ fn fill_bytes(&mut self, dest: &mut [u8]) {
+ XofReader::read(&mut self.0, dest);
+ }
+
+ fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
+ XofReader::read(&mut self.0, dest);
+ Ok(())
+ }
+
+ fn next_u32(&mut self) -> u32 {
+ next_u32_via_fill(self)
+ }
+
+ fn next_u64(&mut self) -> u64 {
+ next_u64_via_fill(self)
+ }
+}
+
+/// A `rand`-compatible interface to construct XofShake128 seed streams, with the domain separation tag
+/// and binder string both fixed as the empty string.
+impl SeedableRng for SeedStreamSha3 {
+ type Seed = [u8; 16];
+
+ fn from_seed(seed: Self::Seed) -> Self {
+ XofShake128::init(&seed, b"").into_seed_stream()
+ }
+}
+
+/// Factory to produce multiple [`XofFixedKeyAes128`] instances with the same fixed key and
+/// different seeds.
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
+)]
+pub struct XofFixedKeyAes128Key {
+ cipher: Aes128,
+}
+
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+impl XofFixedKeyAes128Key {
+ /// Derive the fixed key from the domain separation tag and binder string.
+ pub fn new(dst: &[u8], binder: &[u8]) -> Self {
+ let mut fixed_key_deriver = Shake128::from_core(Shake128Core::default());
+ Update::update(
+ &mut fixed_key_deriver,
+ &[dst.len().try_into().expect("dst must be at most 255 bytes")],
+ );
+ Update::update(&mut fixed_key_deriver, dst);
+ Update::update(&mut fixed_key_deriver, binder);
+ let mut key = GenericArray::from([0; 16]);
+ XofReader::read(&mut fixed_key_deriver.finalize_xof(), key.as_mut());
+ Self {
+ cipher: Aes128::new(&key),
+ }
+ }
+
+ /// Combine a fixed key with a seed to produce a new stream of bytes.
+ pub fn with_seed(&self, seed: &[u8; 16]) -> SeedStreamFixedKeyAes128 {
+ SeedStreamFixedKeyAes128 {
+ cipher: self.cipher.clone(),
+ base_block: (*seed).into(),
+ length_consumed: 0,
+ }
+ }
+}
+
+/// XofFixedKeyAes128 as specified in [[draft-irtf-cfrg-vdaf-07]]. This XOF is NOT RECOMMENDED for
+/// general use; see Section 9 ("Security Considerations") for details.
+///
+/// This XOF combines SHA-3 and a fixed-key mode of operation for AES-128. The key is "fixed" in
+/// the sense that it is derived (using SHAKE128) from the domain separation tag and binder
+/// strings, and depending on the application, these strings can be hard-coded. The seed is used to
+/// construct each block of input passed to a hash function built from AES-128.
+///
+/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+#[derive(Clone, Debug)]
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
+)]
+pub struct XofFixedKeyAes128 {
+ fixed_key_deriver: Shake128,
+ base_block: Block,
+}
+
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+impl Xof<16> for XofFixedKeyAes128 {
+ type SeedStream = SeedStreamFixedKeyAes128;
+
+ fn init(seed_bytes: &[u8; 16], dst: &[u8]) -> Self {
+ let mut fixed_key_deriver = Shake128::from_core(Shake128Core::default());
+ Update::update(
+ &mut fixed_key_deriver,
+ &[dst.len().try_into().expect("dst must be at most 255 bytes")],
+ );
+ Update::update(&mut fixed_key_deriver, dst);
+ Self {
+ fixed_key_deriver,
+ base_block: (*seed_bytes).into(),
+ }
+ }
+
+ fn update(&mut self, data: &[u8]) {
+ Update::update(&mut self.fixed_key_deriver, data);
+ }
+
+ fn into_seed_stream(self) -> SeedStreamFixedKeyAes128 {
+ let mut fixed_key = GenericArray::from([0; 16]);
+ XofReader::read(
+ &mut self.fixed_key_deriver.finalize_xof(),
+ fixed_key.as_mut(),
+ );
+ SeedStreamFixedKeyAes128 {
+ base_block: self.base_block,
+ cipher: Aes128::new(&fixed_key),
+ length_consumed: 0,
+ }
+ }
+}
+
+/// Seed stream for [`XofFixedKeyAes128`].
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+#[cfg_attr(
+ docsrs,
+ doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
+)]
+pub struct SeedStreamFixedKeyAes128 {
+ cipher: Aes128,
+ base_block: Block,
+ length_consumed: u64,
+}
+
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+impl SeedStreamFixedKeyAes128 {
+ fn hash_block(&self, block: &mut Block) {
+ let sigma = Block::from([
+ // hi
+ block[8],
+ block[9],
+ block[10],
+ block[11],
+ block[12],
+ block[13],
+ block[14],
+ block[15],
+ // xor(hi, lo)
+ block[8] ^ block[0],
+ block[9] ^ block[1],
+ block[10] ^ block[2],
+ block[11] ^ block[3],
+ block[12] ^ block[4],
+ block[13] ^ block[5],
+ block[14] ^ block[6],
+ block[15] ^ block[7],
+ ]);
+ self.cipher.encrypt_block_b2b(&sigma, block);
+ for (b, s) in block.iter_mut().zip(sigma.iter()) {
+ *b ^= s;
+ }
+ }
+
+ fn fill(&mut self, buf: &mut [u8]) {
+ let next_length_consumed = self.length_consumed + u64::try_from(buf.len()).unwrap();
+ let mut offset = usize::try_from(self.length_consumed % 16).unwrap();
+ let mut index = 0;
+ let mut block = Block::from([0; 16]);
+
+ // NOTE(cjpatton) We might be able to speed this up by unrolling this loop and encrypting
+ // multiple blocks at the same time via `self.cipher.encrypt_blocks()`.
+ for block_counter in self.length_consumed / 16..(next_length_consumed + 15) / 16 {
+ block.clone_from(&self.base_block);
+ for (b, i) in block.iter_mut().zip(block_counter.to_le_bytes().iter()) {
+ *b ^= i;
+ }
+ self.hash_block(&mut block);
+ let read = std::cmp::min(16 - offset, buf.len() - index);
+ buf[index..index + read].copy_from_slice(&block[offset..offset + read]);
+ offset = 0;
+ index += read;
+ }
+
+ self.length_consumed = next_length_consumed;
+ }
+}
+
+#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
+impl RngCore for SeedStreamFixedKeyAes128 {
+ fn fill_bytes(&mut self, dest: &mut [u8]) {
+ self.fill(dest);
+ }
+
+ fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
+ self.fill(dest);
+ Ok(())
+ }
+
+ fn next_u32(&mut self) -> u32 {
+ next_u32_via_fill(self)
+ }
+
+ fn next_u64(&mut self) -> u64 {
+ next_u64_via_fill(self)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{field::Field128, vdaf::equality_comparison_test};
+ use serde::{Deserialize, Serialize};
+ use std::{convert::TryInto, io::Cursor};
+
+ #[derive(Deserialize, Serialize)]
+ struct XofTestVector {
+ #[serde(with = "hex")]
+ seed: Vec<u8>,
+ #[serde(with = "hex")]
+ dst: Vec<u8>,
+ #[serde(with = "hex")]
+ binder: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex")]
+ derived_seed: Vec<u8>,
+ #[serde(with = "hex")]
+ expanded_vec_field128: Vec<u8>,
+ }
+
+ // Test correctness of dervied methods.
+ fn test_xof<P, const SEED_SIZE: usize>()
+ where
+ P: Xof<SEED_SIZE>,
+ {
+ let seed = Seed::generate().unwrap();
+ let dst = b"algorithm and usage";
+ let binder = b"bind to artifact";
+
+ let mut xof = P::init(seed.as_ref(), dst);
+ xof.update(binder);
+
+ let mut want = Seed([0; SEED_SIZE]);
+ xof.clone().into_seed_stream().fill_bytes(&mut want.0[..]);
+ let got = xof.clone().into_seed();
+ assert_eq!(got, want);
+
+ let mut want = [0; 45];
+ xof.clone().into_seed_stream().fill_bytes(&mut want);
+ let mut got = [0; 45];
+ P::seed_stream(&seed, dst, binder).fill_bytes(&mut got);
+ assert_eq!(got, want);
+ }
+
+ #[test]
+ fn xof_shake128() {
+ let t: XofTestVector =
+ serde_json::from_str(include_str!("test_vec/07/XofShake128.json")).unwrap();
+ let mut xof = XofShake128::init(&t.seed.try_into().unwrap(), &t.dst);
+ xof.update(&t.binder);
+
+ assert_eq!(
+ xof.clone().into_seed(),
+ Seed(t.derived_seed.try_into().unwrap())
+ );
+
+ let mut bytes = Cursor::new(t.expanded_vec_field128.as_slice());
+ let mut want = Vec::with_capacity(t.length);
+ while (bytes.position() as usize) < t.expanded_vec_field128.len() {
+ want.push(Field128::decode(&mut bytes).unwrap())
+ }
+ let got: Vec<Field128> = xof.clone().into_seed_stream().into_field_vec(t.length);
+ assert_eq!(got, want);
+
+ test_xof::<XofShake128, 16>();
+ }
+
+ #[cfg(feature = "experimental")]
+ #[test]
+ fn xof_fixed_key_aes128() {
+ let t: XofTestVector =
+ serde_json::from_str(include_str!("test_vec/07/XofFixedKeyAes128.json")).unwrap();
+ let mut xof = XofFixedKeyAes128::init(&t.seed.try_into().unwrap(), &t.dst);
+ xof.update(&t.binder);
+
+ assert_eq!(
+ xof.clone().into_seed(),
+ Seed(t.derived_seed.try_into().unwrap())
+ );
+
+ let mut bytes = Cursor::new(t.expanded_vec_field128.as_slice());
+ let mut want = Vec::with_capacity(t.length);
+ while (bytes.position() as usize) < t.expanded_vec_field128.len() {
+ want.push(Field128::decode(&mut bytes).unwrap())
+ }
+ let got: Vec<Field128> = xof.clone().into_seed_stream().into_field_vec(t.length);
+ assert_eq!(got, want);
+
+ test_xof::<XofFixedKeyAes128, 16>();
+ }
+
+ #[cfg(feature = "experimental")]
+ #[test]
+ fn xof_fixed_key_aes128_incomplete_block() {
+ let seed = Seed::generate().unwrap();
+ let mut expected = [0; 32];
+ XofFixedKeyAes128::seed_stream(&seed, b"dst", b"binder").fill(&mut expected);
+
+ for len in 0..=32 {
+ let mut buf = vec![0; len];
+ XofFixedKeyAes128::seed_stream(&seed, b"dst", b"binder").fill(&mut buf);
+ assert_eq!(buf, &expected[..len]);
+ }
+ }
+
+ #[cfg(feature = "experimental")]
+ #[test]
+ fn xof_fixed_key_aes128_alternate_apis() {
+ let dst = b"domain separation tag";
+ let binder = b"AAAAAAAAAAAAAAAAAAAAAAAA";
+ let seed_1 = Seed::generate().unwrap();
+ let seed_2 = Seed::generate().unwrap();
+
+ let mut stream_1_trait_api = XofFixedKeyAes128::seed_stream(&seed_1, dst, binder);
+ let mut output_1_trait_api = [0u8; 32];
+ stream_1_trait_api.fill(&mut output_1_trait_api);
+ let mut stream_2_trait_api = XofFixedKeyAes128::seed_stream(&seed_2, dst, binder);
+ let mut output_2_trait_api = [0u8; 32];
+ stream_2_trait_api.fill(&mut output_2_trait_api);
+
+ let fixed_key = XofFixedKeyAes128Key::new(dst, binder);
+ let mut stream_1_alternate_api = fixed_key.with_seed(seed_1.as_ref());
+ let mut output_1_alternate_api = [0u8; 32];
+ stream_1_alternate_api.fill(&mut output_1_alternate_api);
+ let mut stream_2_alternate_api = fixed_key.with_seed(seed_2.as_ref());
+ let mut output_2_alternate_api = [0u8; 32];
+ stream_2_alternate_api.fill(&mut output_2_alternate_api);
+
+ assert_eq!(output_1_trait_api, output_1_alternate_api);
+ assert_eq!(output_2_trait_api, output_2_alternate_api);
+ }
+
+ #[test]
+ fn seed_equality_test() {
+ equality_comparison_test(&[Seed([1, 2, 3]), Seed([3, 2, 1])])
+ }
+}