summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prio2.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prio2.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prio2.rs30
1 files changed, 19 insertions, 11 deletions
diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs
index 4669c47d00..ba725d90d7 100644
--- a/third_party/rust/prio/src/vdaf/prio2.rs
+++ b/third_party/rust/prio/src/vdaf/prio2.rs
@@ -88,8 +88,13 @@ impl Prio2 {
)
.map_err(|e| VdafError::Uncategorized(e.to_string()))?;
+ let truncated_share = match input_share {
+ Share::Leader(data) => Share::Leader(data[..self.input_len].to_vec()),
+ Share::Helper(seed) => Share::Helper(seed.clone()),
+ };
+
Ok((
- Prio2PrepareState(input_share.truncated(self.input_len)),
+ Prio2PrepareState(truncated_share),
Prio2PrepareShare(verifier_share),
))
}
@@ -117,7 +122,6 @@ impl Prio2 {
}
impl Vdaf for Prio2 {
- const ID: u32 = 0xFFFF0000;
type Measurement = Vec<u32>;
type AggregateResult = Vec<u32>;
type AggregationParam = ();
@@ -126,6 +130,10 @@ impl Vdaf for Prio2 {
type OutputShare = OutputShare<FieldPrio2>;
type AggregateShare = AggregateShare<FieldPrio2>;
+ fn algorithm_id(&self) -> u32 {
+ 0xFFFF0000
+ }
+
fn num_aggregators(&self) -> usize {
// Prio2 can easily be extended to support more than two Aggregators.
2
@@ -184,8 +192,8 @@ impl ConstantTimeEq for Prio2PrepareState {
}
impl Encode for Prio2PrepareState {
- fn encode(&self, bytes: &mut Vec<u8>) {
- self.0.encode(bytes);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ self.0.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -213,10 +221,10 @@ impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);
impl Encode for Prio2PrepareShare {
- fn encode(&self, bytes: &mut Vec<u8>) {
- self.0.f_r.encode(bytes);
- self.0.g_r.encode(bytes);
- self.0.h_r.encode(bytes);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ self.0.f_r.encode(bytes)?;
+ self.0.g_r.encode(bytes)?;
+ self.0.h_r.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -388,7 +396,7 @@ mod tests {
use super::*;
use crate::vdaf::{
equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector,
- run_vdaf,
+ test_utils::run_vdaf,
};
use assert_matches::assert_matches;
use rand::prelude::*;
@@ -434,7 +442,7 @@ mod tests {
)
.unwrap();
- let encoded_prepare_state = prepare_state.get_encoded();
+ let encoded_prepare_state = prepare_state.get_encoded().unwrap();
let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param(
&(&prio2, agg_id),
&encoded_prepare_state,
@@ -446,7 +454,7 @@ mod tests {
encoded_prepare_state.len()
);
- let encoded_prepare_share = prepare_share.get_encoded();
+ let encoded_prepare_share = prepare_share.get_encoded().unwrap();
let decoded_prepare_share =
Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
.expect("failed to decode prepare share");