summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/server.rs')
-rw-r--r--third_party/rust/prio/src/server.rs469
1 files changed, 469 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/server.rs b/third_party/rust/prio/src/server.rs
new file mode 100644
index 0000000000..01f309e797
--- /dev/null
+++ b/third_party/rust/prio/src/server.rs
@@ -0,0 +1,469 @@
+// Copyright (c) 2020 Apple Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//! The Prio v2 server. Only 0 / 1 vectors are supported for now.
+use crate::{
+ encrypt::{decrypt_share, EncryptError, PrivateKey},
+ field::{merge_vector, FieldElement, FieldError},
+ polynomial::{poly_interpret_eval, PolyAuxMemory},
+ prng::{Prng, PrngError},
+ util::{proof_length, unpack_proof, SerializeError},
+ vdaf::prg::SeedStreamAes128,
+};
+use serde::{Deserialize, Serialize};
+use std::convert::TryInto;
+
+/// Possible errors from server operations
+#[derive(Debug, thiserror::Error)]
+pub enum ServerError {
+ /// Unexpected Share Length
+ #[error("unexpected share length")]
+ ShareLength,
+ /// Encryption/decryption error
+ #[error("encryption/decryption error")]
+ Encrypt(#[from] EncryptError),
+ /// Finite field operation error
+ #[error("finite field operation error")]
+ Field(#[from] FieldError),
+ /// Serialization/deserialization error
+ #[error("serialization/deserialization error")]
+ Serialize(#[from] SerializeError),
+ /// Failure when calling getrandom().
+ #[error("getrandom: {0}")]
+ GetRandom(#[from] getrandom::Error),
+ /// PRNG error.
+ #[error("prng error: {0}")]
+ Prng(#[from] PrngError),
+}
+
+/// Auxiliary memory for constructing a
+/// [`VerificationMessage`](struct.VerificationMessage.html)
+#[derive(Debug)]
+pub struct ValidationMemory<F> {
+ points_f: Vec<F>,
+ points_g: Vec<F>,
+ points_h: Vec<F>,
+ poly_mem: PolyAuxMemory<F>,
+}
+
+impl<F: FieldElement> ValidationMemory<F> {
+ /// Construct a new ValidationMemory object for validating proof shares of
+ /// length `dimension`.
+ pub fn new(dimension: usize) -> Self {
+ let n: usize = (dimension + 1).next_power_of_two();
+ ValidationMemory {
+ points_f: vec![F::zero(); n],
+ points_g: vec![F::zero(); n],
+ points_h: vec![F::zero(); 2 * n],
+ poly_mem: PolyAuxMemory::new(n),
+ }
+ }
+}
+
+/// Main workhorse of the server.
+#[derive(Debug)]
+pub struct Server<F> {
+ prng: Prng<F, SeedStreamAes128>,
+ dimension: usize,
+ is_first_server: bool,
+ accumulator: Vec<F>,
+ validation_mem: ValidationMemory<F>,
+ private_key: PrivateKey,
+}
+
+impl<F: FieldElement> Server<F> {
+ /// Construct a new server instance
+ ///
+ /// Params:
+ /// * `dimension`: the number of elements in the aggregation vector.
+ /// * `is_first_server`: only one of the servers should have this true.
+ /// * `private_key`: the private key for decrypting the share of the proof.
+ pub fn new(
+ dimension: usize,
+ is_first_server: bool,
+ private_key: PrivateKey,
+ ) -> Result<Server<F>, ServerError> {
+ Ok(Server {
+ prng: Prng::new()?,
+ dimension,
+ is_first_server,
+ accumulator: vec![F::zero(); dimension],
+ validation_mem: ValidationMemory::new(dimension),
+ private_key,
+ })
+ }
+
+ /// Decrypt and deserialize
+ fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> {
+ let len = proof_length(self.dimension);
+ let share = decrypt_share(encrypted_share, &self.private_key)?;
+ Ok(if self.is_first_server {
+ F::byte_slice_into_vec(&share)?
+ } else {
+ if share.len() != 32 {
+ return Err(ServerError::ShareLength);
+ }
+
+ Prng::from_prio2_seed(&share.try_into().unwrap())
+ .take(len)
+ .collect()
+ })
+ }
+
+ /// Generate verification message from an encrypted share
+ ///
+ /// This decrypts the share of the proof and constructs the
+ /// [`VerificationMessage`](struct.VerificationMessage.html).
+ /// The `eval_at` field should be generate by
+ /// [choose_eval_at](#method.choose_eval_at).
+ pub fn generate_verification_message(
+ &mut self,
+ eval_at: F,
+ share: &[u8],
+ ) -> Result<VerificationMessage<F>, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ generate_verification_message(
+ self.dimension,
+ eval_at,
+ &share_field,
+ self.is_first_server,
+ &mut self.validation_mem,
+ )
+ }
+
+ /// Add the content of the encrypted share into the accumulator
+ ///
+ /// This only changes the accumulator if the verification messages `v1` and
+ /// `v2` indicate that the share passed validation.
+ pub fn aggregate(
+ &mut self,
+ share: &[u8],
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+ ) -> Result<bool, ServerError> {
+ let share_field = self.deserialize_share(share)?;
+ let is_valid = is_valid_share(v1, v2);
+ if is_valid {
+ // Add to the accumulator. share_field also includes the proof
+ // encoding, so we slice off the first dimension fields, which are
+ // the actual data share.
+ merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
+ }
+
+ Ok(is_valid)
+ }
+
+ /// Return the current accumulated shares.
+ ///
+ /// These can be merged together using
+ /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html).
+ pub fn total_shares(&self) -> &[F] {
+ &self.accumulator
+ }
+
+ /// Merge shares from another server.
+ ///
+ /// This modifies the current accumulator.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if `other_total_shares.len()` is not equal to this
+ //// server's `dimension`.
+ pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> {
+ Ok(merge_vector(&mut self.accumulator, other_total_shares)?)
+ }
+
+ /// Choose a random point for polynomial evaluation
+ ///
+ /// The point returned is not one of the roots used for polynomial
+ /// evaluation.
+ pub fn choose_eval_at(&mut self) -> F {
+ loop {
+ let eval_at = self.prng.get();
+ if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) {
+ break eval_at;
+ }
+ }
+ }
+}
+
+/// Verification message for proof validation
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct VerificationMessage<F> {
+ /// f evaluated at random point
+ pub f_r: F,
+ /// g evaluated at random point
+ pub g_r: F,
+ /// h evaluated at random point
+ pub h_r: F,
+}
+
+/// Given a proof and evaluation point, this constructs the verification
+/// message.
+pub fn generate_verification_message<F: FieldElement>(
+ dimension: usize,
+ eval_at: F,
+ proof: &[F],
+ is_first_server: bool,
+ mem: &mut ValidationMemory<F>,
+) -> Result<VerificationMessage<F>, ServerError> {
+ let unpacked = unpack_proof(proof, dimension)?;
+ let proof_length = 2 * (dimension + 1).next_power_of_two();
+
+ // set zero terms
+ mem.points_f[0] = *unpacked.f0;
+ mem.points_g[0] = *unpacked.g0;
+ mem.points_h[0] = *unpacked.h0;
+
+ // set points_f and points_g
+ for (i, x) in unpacked.data.iter().enumerate() {
+ mem.points_f[i + 1] = *x;
+
+ if is_first_server {
+ // only one server needs to subtract one for point_g
+ mem.points_g[i + 1] = *x - F::one();
+ } else {
+ mem.points_g[i + 1] = *x;
+ }
+ }
+
+ // set points_h, skipping over elements that should be zero
+ let mut i = 1;
+ let mut j = 0;
+ while i < proof_length {
+ mem.points_h[i] = unpacked.points_h_packed[j];
+ j += 1;
+ i += 2;
+ }
+
+ // evaluate polynomials at random point
+ let f_r = poly_interpret_eval(
+ &mem.points_f,
+ &mem.poly_mem.roots_n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+ let g_r = poly_interpret_eval(
+ &mem.points_g,
+ &mem.poly_mem.roots_n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+ let h_r = poly_interpret_eval(
+ &mem.points_h,
+ &mem.poly_mem.roots_2n_inverted,
+ eval_at,
+ &mut mem.poly_mem.coeffs,
+ &mut mem.poly_mem.fft_memory,
+ );
+
+ Ok(VerificationMessage { f_r, g_r, h_r })
+}
+
+/// Decides if the distributed proof is valid
+pub fn is_valid_share<F: FieldElement>(
+ v1: &VerificationMessage<F>,
+ v2: &VerificationMessage<F>,
+) -> bool {
+ // reconstruct f_r, g_r, h_r
+ let f_r = v1.f_r + v2.f_r;
+ let g_r = v1.g_r + v2.g_r;
+ let h_r = v1.h_r + v2.h_r;
+ // validity check
+ f_r * g_r == h_r
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ encrypt::{encrypt_share, PublicKey},
+ field::{Field32, FieldPrio2},
+ test_vector::Priov2TestVector,
+ util::{self, unpack_proof_mut},
+ };
+ use serde_json;
+
+ #[test]
+ fn test_validation() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
+ let share2 = util::tests::secret_share(&mut proof);
+ let eval_at = Field32::from(12313);
+
+ let mut validation_mem = ValidationMemory::new(dim);
+
+ let v1 =
+ generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
+ .unwrap();
+ assert!(is_valid_share(&v1, &v2));
+ }
+
+ #[test]
+ fn test_verification_message_serde() {
+ let dim = 8;
+ let proof_u32: Vec<u32> = vec![
+ 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
+ 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
+ 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
+ ];
+
+ let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
+ let share2 = util::tests::secret_share(&mut proof);
+ let eval_at = Field32::from(12313);
+
+ let mut validation_mem = ValidationMemory::new(dim);
+
+ let v1 =
+ generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
+ let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
+ .unwrap();
+
+ // serialize and deserialize the first verification message
+ let serialized = serde_json::to_string(&v1).unwrap();
+ let deserialized: VerificationMessage<Field32> = serde_json::from_str(&serialized).unwrap();
+
+ assert!(is_valid_share(&deserialized, &v2));
+ }
+
+ #[derive(Debug, Clone, Copy, PartialEq)]
+ enum Tweak {
+ None,
+ WrongInput,
+ DataPartOfShare,
+ ZeroTermF,
+ ZeroTermG,
+ ZeroTermH,
+ PointsH,
+ VerificationF,
+ VerificationG,
+ VerificationH,
+ }
+
+ fn tweaks(tweak: Tweak) {
+ let dim = 123;
+
+ // We generate a test vector just to get a `Client` and `Server`s with
+ // encryption keys but construct and tweak inputs below.
+ let test_vector = Priov2TestVector::new(dim, 0).unwrap();
+ let mut server1 = test_vector.server_1().unwrap();
+ let mut server2 = test_vector.server_2().unwrap();
+ let mut client = test_vector.client().unwrap();
+
+ // all zero data
+ let mut data = vec![FieldPrio2::zero(); dim];
+
+ if let Tweak::WrongInput = tweak {
+ data[0] = FieldPrio2::from(2);
+ }
+
+ let (share1_original, share2) = client.encode_simple(&data).unwrap();
+
+ let decrypted_share1 = decrypt_share(&share1_original, &server1.private_key).unwrap();
+ let mut share1_field = FieldPrio2::byte_slice_into_vec(&decrypted_share1).unwrap();
+ let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap();
+
+ let one = FieldPrio2::from(1);
+
+ match tweak {
+ Tweak::DataPartOfShare => unpacked_share1.data[0] += one,
+ Tweak::ZeroTermF => *unpacked_share1.f0 += one,
+ Tweak::ZeroTermG => *unpacked_share1.g0 += one,
+ Tweak::ZeroTermH => *unpacked_share1.h0 += one,
+ Tweak::PointsH => unpacked_share1.points_h_packed[0] += one,
+ _ => (),
+ };
+
+ // reserialize altered share1
+ let share1_modified = encrypt_share(
+ &FieldPrio2::slice_into_byte_vec(&share1_field),
+ &PublicKey::from(&server1.private_key),
+ )
+ .unwrap();
+
+ let eval_at = server1.choose_eval_at();
+
+ let mut v1 = server1
+ .generate_verification_message(eval_at, &share1_modified)
+ .unwrap();
+ let v2 = server2
+ .generate_verification_message(eval_at, &share2)
+ .unwrap();
+
+ match tweak {
+ Tweak::VerificationF => v1.f_r += one,
+ Tweak::VerificationG => v1.g_r += one,
+ Tweak::VerificationH => v1.h_r += one,
+ _ => (),
+ }
+
+ let should_be_valid = matches!(tweak, Tweak::None);
+ assert_eq!(
+ server1.aggregate(&share1_modified, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ assert_eq!(
+ server2.aggregate(&share2, &v1, &v2).unwrap(),
+ should_be_valid
+ );
+ }
+
+ #[test]
+ fn tweak_none() {
+ tweaks(Tweak::None);
+ }
+
+ #[test]
+ fn tweak_input() {
+ tweaks(Tweak::WrongInput);
+ }
+
+ #[test]
+ fn tweak_data() {
+ tweaks(Tweak::DataPartOfShare);
+ }
+
+ #[test]
+ fn tweak_f_zero() {
+ tweaks(Tweak::ZeroTermF);
+ }
+
+ #[test]
+ fn tweak_g_zero() {
+ tweaks(Tweak::ZeroTermG);
+ }
+
+ #[test]
+ fn tweak_h_zero() {
+ tweaks(Tweak::ZeroTermH);
+ }
+
+ #[test]
+ fn tweak_h_points() {
+ tweaks(Tweak::PointsH);
+ }
+
+ #[test]
+ fn tweak_f_verif() {
+ tweaks(Tweak::VerificationF);
+ }
+
+ #[test]
+ fn tweak_g_verif() {
+ tweaks(Tweak::VerificationG);
+ }
+
+ #[test]
+ fn tweak_h_verif() {
+ tweaks(Tweak::VerificationH);
+ }
+}