summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/test_vector.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/prio/src/test_vector.rs
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prio/src/test_vector.rs')
-rw-r--r--third_party/rust/prio/src/test_vector.rs244
1 files changed, 244 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/test_vector.rs b/third_party/rust/prio/src/test_vector.rs
new file mode 100644
index 0000000000..1306bc26f9
--- /dev/null
+++ b/third_party/rust/prio/src/test_vector.rs
@@ -0,0 +1,244 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Module `test_vector` generates test vectors of serialized Prio inputs and
+//! support for working with test vectors, enabling backward compatibility
+//! testing.
+
+use crate::{
+ client::{Client, ClientError},
+ encrypt::{PrivateKey, PublicKey},
+ field::{FieldElement, FieldPrio2},
+ server::{Server, ServerError},
+};
+use rand::Rng;
+use serde::{Deserialize, Serialize};
+use std::fmt::Debug;
+
+/// Errors propagated by functions in this module.
+#[derive(Debug, thiserror::Error)]
+pub enum TestVectorError {
+ /// Error from Prio client
+ #[error("Prio client error {0}")]
+ Client(#[from] ClientError),
+ /// Error from Prio server
+ #[error("Prio server error {0}")]
+ Server(#[from] ServerError),
+ /// Error while converting primitive to FieldElement associated integer type
+ #[error("Integer conversion error {0}")]
+ IntegerConversion(String),
+}
+
+const SERVER_1_PRIVATE_KEY: &str =
+ "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBH\
+ fNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
+const SERVER_2_PRIVATE_KEY: &str =
+ "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rD\
+ ULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";
+
+/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
+/// test vectors and later to decrypt them.
+fn server_1_private_key() -> PrivateKey {
+ PrivateKey::from_base64(SERVER_1_PRIVATE_KEY).unwrap()
+}
+
+/// The public portion of [`server_1_private_key`].
+fn server_1_public_key() -> PublicKey {
+ PublicKey::from(&server_1_private_key())
+}
+
+/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
+/// test vectors and later to decrypt them.
+fn server_2_private_key() -> PrivateKey {
+ PrivateKey::from_base64(SERVER_2_PRIVATE_KEY).unwrap()
+}
+
+/// The public portion of [`server_2_private_key`].
+fn server_2_public_key() -> PublicKey {
+ PublicKey::from(&server_2_private_key())
+}
+
+/// A test vector of Prio inputs, serialized and encrypted in the Priov2 format,
+/// along with a reference sum. The field is always [`FieldPrio2`].
+#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub struct Priov2TestVector {
+ /// Base64 encoded private key for the "first" a.k.a. "PHA" server, which
+ /// may be used to decrypt `server_1_shares`.
+ pub server_1_private_key: String,
+ /// Base64 encoded private key for the non-"first" a.k.a. "facilitator"
+ /// server, which may be used to decrypt `server_2_shares`.
+ pub server_2_private_key: String,
+ /// Dimension (number of buckets) of the inputs
+ pub dimension: usize,
+ /// Encrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA"
+ /// server. The inner `Vec`s are encrypted bytes.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_1_shares: Vec<Vec<u8>>,
+ /// Encrypted share of Priov2 format inputs for the non-"first" a.k.a.
+ /// "facilitator" server.
+ #[serde(
+ serialize_with = "base64::serialize_bytes",
+ deserialize_with = "base64::deserialize_bytes"
+ )]
+ pub server_2_shares: Vec<Vec<u8>>,
+ /// The sum over the inputs.
+ #[serde(
+ serialize_with = "base64::serialize_field",
+ deserialize_with = "base64::deserialize_field"
+ )]
+ pub reference_sum: Vec<FieldPrio2>,
+ /// The version of the crate that generated this test vector
+ pub prio_crate_version: String,
+}
+
+impl Priov2TestVector {
+ /// Construct a test vector of `number_of_clients` inputs, each of which is a
+ /// `dimension`-dimension vector of random Boolean values encoded as
+ /// [`FieldPrio2`].
+ pub fn new(dimension: usize, number_of_clients: usize) -> Result<Self, TestVectorError> {
+ let mut client: Client<FieldPrio2> =
+ Client::new(dimension, server_1_public_key(), server_2_public_key())?;
+
+ let mut reference_sum = vec![FieldPrio2::zero(); dimension];
+ let mut server_1_shares = Vec::with_capacity(number_of_clients);
+ let mut server_2_shares = Vec::with_capacity(number_of_clients);
+
+ let mut rng = rand::thread_rng();
+
+ for _ in 0..number_of_clients {
+ // Generate a random vector of booleans
+ let data: Vec<FieldPrio2> = (0..dimension)
+ .map(|_| FieldPrio2::from(rng.gen_range(0..2)))
+ .collect();
+
+ // Update reference sum
+ for (r, d) in reference_sum.iter_mut().zip(&data) {
+ *r += *d;
+ }
+
+ let (server_1_share, server_2_share) = client.encode_simple(&data)?;
+
+ server_1_shares.push(server_1_share);
+ server_2_shares.push(server_2_share);
+ }
+
+ Ok(Self {
+ server_1_private_key: SERVER_1_PRIVATE_KEY.to_owned(),
+ server_2_private_key: SERVER_2_PRIVATE_KEY.to_owned(),
+ dimension,
+ server_1_shares,
+ server_2_shares,
+ reference_sum,
+ prio_crate_version: env!("CARGO_PKG_VERSION").to_owned(),
+ })
+ }
+
+ /// Construct a [`Client`] that can encrypt input shares to this test
+ /// vector's servers.
+ pub fn client(&self) -> Result<Client<FieldPrio2>, TestVectorError> {
+ Ok(Client::new(
+ self.dimension,
+ PublicKey::from(&PrivateKey::from_base64(&self.server_1_private_key).unwrap()),
+ PublicKey::from(&PrivateKey::from_base64(&self.server_2_private_key).unwrap()),
+ )?)
+ }
+
+ /// Construct a [`Server`] that can decrypt `server_1_shares`.
+ pub fn server_1(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
+ Ok(Server::new(
+ self.dimension,
+ true,
+ PrivateKey::from_base64(&self.server_1_private_key).unwrap(),
+ )?)
+ }
+
+ /// Construct a [`Server`] that can decrypt `server_2_shares`.
+ pub fn server_2(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
+ Ok(Server::new(
+ self.dimension,
+ false,
+ PrivateKey::from_base64(&self.server_2_private_key).unwrap(),
+ )?)
+ }
+}
+
+mod base64 {
+ //! Custom serialization module used for some members of struct
+ //! `Priov2TestVector` so that byte slices are serialized as base64 strings
+ //! instead of an array of an array of integers when serializing to JSON.
+ //
+ // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
+ use crate::field::{FieldElement, FieldPrio2};
+ use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
+
+ pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
+ let base64_vec = v.iter().map(base64::encode).collect();
+ <Vec<String>>::serialize(&base64_vec, s)
+ }
+
+ pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
+ <Vec<String>>::deserialize(d)?
+ .iter()
+ .map(|s| base64::decode(s.as_bytes()).map_err(Error::custom))
+ .collect()
+ }
+
+ pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
+ String::serialize(&base64::encode(FieldPrio2::slice_into_byte_vec(v)), s)
+ }
+
+ pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
+ let bytes = base64::decode(String::deserialize(d)?.as_bytes()).map_err(Error::custom)?;
+ FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::util::reconstruct_shares;
+
+ #[test]
+ fn roundtrip_test_vector_serialization() {
+ let test_vector = Priov2TestVector::new(123, 100).unwrap();
+ let serialized = serde_json::to_vec(&test_vector).unwrap();
+ let test_vector_again: Priov2TestVector = serde_json::from_slice(&serialized).unwrap();
+
+ assert_eq!(test_vector, test_vector_again);
+ }
+
+ #[test]
+ fn accumulation_field_priov2() {
+ let dimension = 123;
+ let test_vector = Priov2TestVector::new(dimension, 100).unwrap();
+
+ let mut server1 = test_vector.server_1().unwrap();
+ let mut server2 = test_vector.server_2().unwrap();
+
+ for (server_1_share, server_2_share) in test_vector
+ .server_1_shares
+ .iter()
+ .zip(&test_vector.server_2_shares)
+ {
+ let eval_at = server1.choose_eval_at();
+
+ let v1 = server1
+ .generate_verification_message(eval_at, server_1_share)
+ .unwrap();
+ let v2 = server2
+ .generate_verification_message(eval_at, server_2_share)
+ .unwrap();
+
+ assert!(server1.aggregate(server_1_share, &v1, &v2).unwrap());
+ assert!(server2.aggregate(server_2_share, &v1, &v2).unwrap());
+ }
+
+ let total1 = server1.total_shares();
+ let total2 = server2.total_shares();
+
+ let reconstructed = reconstruct_shares(total1, total2).unwrap();
+ assert_eq!(reconstructed, test_vector.reference_sum);
+ }
+}