diff options
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio2.rs')
-rw-r--r-- | third_party/rust/prio/src/vdaf/prio2.rs | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs index 4669c47d00..ba725d90d7 100644 --- a/third_party/rust/prio/src/vdaf/prio2.rs +++ b/third_party/rust/prio/src/vdaf/prio2.rs @@ -88,8 +88,13 @@ impl Prio2 { ) .map_err(|e| VdafError::Uncategorized(e.to_string()))?; + let truncated_share = match input_share { + Share::Leader(data) => Share::Leader(data[..self.input_len].to_vec()), + Share::Helper(seed) => Share::Helper(seed.clone()), + }; + Ok(( - Prio2PrepareState(input_share.truncated(self.input_len)), + Prio2PrepareState(truncated_share), Prio2PrepareShare(verifier_share), )) } @@ -117,7 +122,6 @@ impl Prio2 { } impl Vdaf for Prio2 { - const ID: u32 = 0xFFFF0000; type Measurement = Vec<u32>; type AggregateResult = Vec<u32>; type AggregationParam = (); @@ -126,6 +130,10 @@ impl Vdaf for Prio2 { type OutputShare = OutputShare<FieldPrio2>; type AggregateShare = AggregateShare<FieldPrio2>; + fn algorithm_id(&self) -> u32 { + 0xFFFF0000 + } + fn num_aggregators(&self) -> usize { // Prio2 can easily be extended to support more than two Aggregators. 2 @@ -184,8 +192,8 @@ impl ConstantTimeEq for Prio2PrepareState { } impl Encode for Prio2PrepareState { - fn encode(&self, bytes: &mut Vec<u8>) { - self.0.encode(bytes); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + self.0.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -213,10 +221,10 @@ impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState { pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>); impl Encode for Prio2PrepareShare { - fn encode(&self, bytes: &mut Vec<u8>) { - self.0.f_r.encode(bytes); - self.0.g_r.encode(bytes); - self.0.h_r.encode(bytes); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + self.0.f_r.encode(bytes)?; + self.0.g_r.encode(bytes)?; + self.0.h_r.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -388,7 +396,7 @@ mod tests { use super::*; use crate::vdaf::{ equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, - run_vdaf, + test_utils::run_vdaf, }; use assert_matches::assert_matches; use rand::prelude::*; @@ -434,7 +442,7 @@ mod tests { ) .unwrap(); - let encoded_prepare_state = prepare_state.get_encoded(); + let encoded_prepare_state = prepare_state.get_encoded().unwrap(); let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param( &(&prio2, agg_id), &encoded_prepare_state, @@ -446,7 +454,7 @@ mod tests { encoded_prepare_state.len() ); - let encoded_prepare_share = prepare_share.get_encoded(); + let encoded_prepare_share = prepare_share.get_encoded().unwrap(); let decoded_prepare_share = Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share) .expect("failed to decode prepare share"); |