// SPDX-License-Identifier: MPL-2.0 //! Module `test_vector` generates test vectors of serialized Prio inputs and //! support for working with test vectors, enabling backward compatibility //! testing. use crate::{ client::{Client, ClientError}, encrypt::{PrivateKey, PublicKey}, field::{FieldElement, FieldPrio2}, server::{Server, ServerError}, }; use rand::Rng; use serde::{Deserialize, Serialize}; use std::fmt::Debug; /// Errors propagated by functions in this module. #[derive(Debug, thiserror::Error)] pub enum TestVectorError { /// Error from Prio client #[error("Prio client error {0}")] Client(#[from] ClientError), /// Error from Prio server #[error("Prio server error {0}")] Server(#[from] ServerError), /// Error while converting primitive to FieldElement associated integer type #[error("Integer conversion error {0}")] IntegerConversion(String), } const SERVER_1_PRIVATE_KEY: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBH\ fNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw=="; const SERVER_2_PRIVATE_KEY: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rD\ ULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w=="; /// An ECDSA P-256 private key suitable for decrypting inputs, used to generate /// test vectors and later to decrypt them. fn server_1_private_key() -> PrivateKey { PrivateKey::from_base64(SERVER_1_PRIVATE_KEY).unwrap() } /// The public portion of [`server_1_private_key`]. fn server_1_public_key() -> PublicKey { PublicKey::from(&server_1_private_key()) } /// An ECDSA P-256 private key suitable for decrypting inputs, used to generate /// test vectors and later to decrypt them. fn server_2_private_key() -> PrivateKey { PrivateKey::from_base64(SERVER_2_PRIVATE_KEY).unwrap() } /// The public portion of [`server_2_private_key`]. fn server_2_public_key() -> PublicKey { PublicKey::from(&server_2_private_key()) } /// A test vector of Prio inputs, serialized and encrypted in the Priov2 format, /// along with a reference sum. The field is always [`FieldPrio2`]. #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct Priov2TestVector { /// Base64 encoded private key for the "first" a.k.a. "PHA" server, which /// may be used to decrypt `server_1_shares`. pub server_1_private_key: String, /// Base64 encoded private key for the non-"first" a.k.a. "facilitator" /// server, which may be used to decrypt `server_2_shares`. pub server_2_private_key: String, /// Dimension (number of buckets) of the inputs pub dimension: usize, /// Encrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA" /// server. The inner `Vec`s are encrypted bytes. #[serde( serialize_with = "base64::serialize_bytes", deserialize_with = "base64::deserialize_bytes" )] pub server_1_shares: Vec>, /// Encrypted share of Priov2 format inputs for the non-"first" a.k.a. /// "facilitator" server. #[serde( serialize_with = "base64::serialize_bytes", deserialize_with = "base64::deserialize_bytes" )] pub server_2_shares: Vec>, /// The sum over the inputs. #[serde( serialize_with = "base64::serialize_field", deserialize_with = "base64::deserialize_field" )] pub reference_sum: Vec, /// The version of the crate that generated this test vector pub prio_crate_version: String, } impl Priov2TestVector { /// Construct a test vector of `number_of_clients` inputs, each of which is a /// `dimension`-dimension vector of random Boolean values encoded as /// [`FieldPrio2`]. pub fn new(dimension: usize, number_of_clients: usize) -> Result { let mut client: Client = Client::new(dimension, server_1_public_key(), server_2_public_key())?; let mut reference_sum = vec![FieldPrio2::zero(); dimension]; let mut server_1_shares = Vec::with_capacity(number_of_clients); let mut server_2_shares = Vec::with_capacity(number_of_clients); let mut rng = rand::thread_rng(); for _ in 0..number_of_clients { // Generate a random vector of booleans let data: Vec = (0..dimension) .map(|_| FieldPrio2::from(rng.gen_range(0..2))) .collect(); // Update reference sum for (r, d) in reference_sum.iter_mut().zip(&data) { *r += *d; } let (server_1_share, server_2_share) = client.encode_simple(&data)?; server_1_shares.push(server_1_share); server_2_shares.push(server_2_share); } Ok(Self { server_1_private_key: SERVER_1_PRIVATE_KEY.to_owned(), server_2_private_key: SERVER_2_PRIVATE_KEY.to_owned(), dimension, server_1_shares, server_2_shares, reference_sum, prio_crate_version: env!("CARGO_PKG_VERSION").to_owned(), }) } /// Construct a [`Client`] that can encrypt input shares to this test /// vector's servers. pub fn client(&self) -> Result, TestVectorError> { Ok(Client::new( self.dimension, PublicKey::from(&PrivateKey::from_base64(&self.server_1_private_key).unwrap()), PublicKey::from(&PrivateKey::from_base64(&self.server_2_private_key).unwrap()), )?) } /// Construct a [`Server`] that can decrypt `server_1_shares`. pub fn server_1(&self) -> Result, TestVectorError> { Ok(Server::new( self.dimension, true, PrivateKey::from_base64(&self.server_1_private_key).unwrap(), )?) } /// Construct a [`Server`] that can decrypt `server_2_shares`. pub fn server_2(&self) -> Result, TestVectorError> { Ok(Server::new( self.dimension, false, PrivateKey::from_base64(&self.server_2_private_key).unwrap(), )?) } } mod base64 { //! Custom serialization module used for some members of struct //! `Priov2TestVector` so that byte slices are serialized as base64 strings //! instead of an array of an array of integers when serializing to JSON. // // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2 use crate::field::{FieldElement, FieldPrio2}; use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; pub fn serialize_bytes(v: &[Vec], s: S) -> Result { let base64_vec = v.iter().map(base64::encode).collect(); >::serialize(&base64_vec, s) } pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result>, D::Error> { >::deserialize(d)? .iter() .map(|s| base64::decode(s.as_bytes()).map_err(Error::custom)) .collect() } pub fn serialize_field(v: &[FieldPrio2], s: S) -> Result { String::serialize(&base64::encode(FieldPrio2::slice_into_byte_vec(v)), s) } pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { let bytes = base64::decode(String::deserialize(d)?.as_bytes()).map_err(Error::custom)?; FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom) } } #[cfg(test)] mod tests { use super::*; use crate::util::reconstruct_shares; #[test] fn roundtrip_test_vector_serialization() { let test_vector = Priov2TestVector::new(123, 100).unwrap(); let serialized = serde_json::to_vec(&test_vector).unwrap(); let test_vector_again: Priov2TestVector = serde_json::from_slice(&serialized).unwrap(); assert_eq!(test_vector, test_vector_again); } #[test] fn accumulation_field_priov2() { let dimension = 123; let test_vector = Priov2TestVector::new(dimension, 100).unwrap(); let mut server1 = test_vector.server_1().unwrap(); let mut server2 = test_vector.server_2().unwrap(); for (server_1_share, server_2_share) in test_vector .server_1_shares .iter() .zip(&test_vector.server_2_shares) { let eval_at = server1.choose_eval_at(); let v1 = server1 .generate_verification_message(eval_at, server_1_share) .unwrap(); let v2 = server2 .generate_verification_message(eval_at, server_2_share) .unwrap(); assert!(server1.aggregate(server_1_share, &v1, &v2).unwrap()); assert!(server2.aggregate(server_2_share, &v1, &v2).unwrap()); } let total1 = server1.total_shares(); let total2 = server2.total_shares(); let reconstructed = reconstruct_shares(total1, total2).unwrap(); assert_eq!(reconstructed, test_vector.reference_sum); } }