summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf.rs
blob: e5f4e14c5a3c05ebe4d745375e18a77714236054 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
// SPDX-License-Identifier: MPL-2.0

//! Verifiable Distributed Aggregation Functions (VDAFs) as described in
//! [[draft-irtf-cfrg-vdaf-08]].
//!
//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/

#[cfg(feature = "experimental")]
use crate::dp::DifferentialPrivacyStrategy;
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
use crate::idpf::IdpfError;
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
use crate::vidpf::VidpfError;
use crate::{
    codec::{CodecError, Decode, Encode, ParameterizedDecode},
    field::{encode_fieldvec, merge_vector, FieldElement, FieldError},
    flp::FlpError,
    prng::PrngError,
    vdaf::xof::Seed,
};
use serde::{Deserialize, Serialize};
use std::{error::Error, fmt::Debug, io::Cursor};
use subtle::{Choice, ConstantTimeEq};

/// Specification draft number written into the first byte of every domain-separation tag, which
/// binds VDAF operations to a particular revision of the document. It is bumped whenever a new
/// draft introduces breaking changes.
pub(crate) const VERSION: u8 = 0x08;

/// Errors emitted by this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum VdafError {
    /// An uncategorized error, described by the contained message.
    #[error("vdaf error: {0}")]
    Uncategorized(String),

    /// Field error.
    #[error("field error: {0}")]
    Field(#[from] FieldError),

    /// An I/O error occurred while parsing a message.
    #[error("io error: {0}")]
    IoError(#[from] std::io::Error),

    /// FLP error.
    #[error("flp error: {0}")]
    Flp(#[from] FlpError),

    /// PRNG error.
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),

    /// Failure when calling getrandom().
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),

    /// IDPF error.
    #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
    #[error("idpf error: {0}")]
    Idpf(#[from] IdpfError),

    /// VIDPF error.
    #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
    #[error("vidpf error: {0}")]
    Vidpf(#[from] VidpfError),

    /// Errors from other VDAFs, passed through transparently.
    #[error(transparent)]
    Other(Box<dyn Error + 'static + Send + Sync>),
}

/// One party's additive share of a vector of field elements, held either in full or in
/// compressed (seed) form.
#[derive(Clone, Debug)]
pub enum Share<F, const SEED_SIZE: usize> {
    /// The full, uncompressed vector of field elements; typically sent to the leader.
    Leader(Vec<F>),

    /// A seed standing in for the vector (compressed form); typically sent to the helper.
    Helper(Seed<SEED_SIZE>),
}

impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Share<F, SEED_SIZE> {
    /// Equality is defined through [`ConstantTimeEq::ct_eq`].
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Share<F, SEED_SIZE> {}

impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Share<F, SEED_SIZE> {
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        // Shares of different variants (Leader vs Helper) may short-circuit to "not equal"; only
        // the contents of same-variant shares are compared in constant time.
        if let (Share::Leader(lhs), Share::Leader(rhs)) = (self, other) {
            return lhs.as_slice().ct_eq(rhs.as_slice());
        }
        if let (Share::Helper(lhs), Share::Helper(rhs)) = (self, other) {
            return lhs.ct_eq(rhs);
        }
        Choice::from(0)
    }
}

/// Parameters needed to decode a [`Share`]: which variant to expect and, for the leader
/// variant, how many field elements to read.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ShareDecodingParameter<const SEED_SIZE: usize> {
    Leader(usize),
    Helper,
}

impl<F: FieldElement, const SEED_SIZE: usize> ParameterizedDecode<ShareDecodingParameter<SEED_SIZE>>
    for Share<F, SEED_SIZE>
{
    fn decode_with_param(
        decoding_parameter: &ShareDecodingParameter<SEED_SIZE>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        match decoding_parameter {
            ShareDecodingParameter::Helper => Seed::decode(bytes).map(Self::Helper),
            ShareDecodingParameter::Leader(share_length) => {
                // Read exactly `share_length` field elements, stopping at the first failure.
                let elements = (0..*share_length)
                    .map(|_| F::decode(bytes))
                    .collect::<Result<Vec<F>, CodecError>>()?;
                Ok(Self::Leader(elements))
            }
        }
    }
}

impl<F: FieldElement, const SEED_SIZE: usize> Encode for Share<F, SEED_SIZE> {
    /// Writes a leader share as its field elements back to back, or a helper share as its seed.
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        match self {
            Share::Helper(seed) => seed.encode(bytes),
            Share::Leader(elements) => elements
                .iter()
                .try_for_each(|element| element.encode(bytes)),
        }
    }

    fn encoded_len(&self) -> Option<usize> {
        match self {
            Share::Helper(seed) => seed.encoded_len(),
            // Every field element encodes to `F::ENCODED_SIZE` bytes.
            Share::Leader(elements) => Some(F::ENCODED_SIZE * elements.len()),
        }
    }
}

/// Core trait shared by every VDAF scheme. The role-specific traits [`Client`], [`Aggregator`],
/// and [`Collector`] build on it to describe what each party in a VDAF execution does.
pub trait Vdaf: Clone + Debug {
    /// A single Client measurement, i.e. the value being aggregated.
    type Measurement: Clone + Debug;

    /// The result produced at the end of a VDAF execution.
    type AggregateResult: Clone + Debug;

    /// The parameter Aggregators use to map their input shares to output shares.
    type AggregationParam: Clone + Debug + Decode + Encode;

    /// The public share a Client sends alongside its input shares.
    type PublicShare: Clone + Debug + ParameterizedDecode<Self> + Encode;

    /// One of the input shares a Client sends. Decoding is parameterized by the VDAF and an
    /// Aggregator index.
    type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode;

    /// The output share an Aggregator recovers from its input share.
    type OutputShare: Clone
        + Debug
        + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)>
        + Encode;

    /// One Aggregator's share of the aggregate result.
    type AggregateShare: Aggregatable<OutputShare = Self::OutputShare>
        + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)>
        + Encode;

    /// The VDAF's algorithm identifier.
    fn algorithm_id(&self) -> u32;

    /// How many Aggregators take part; a Client emits one input share per Aggregator.
    fn num_aggregators(&self) -> usize;

    /// Builds the 8-byte domain-separation tag that the XOF uses for `usage`.
    ///
    /// Layout: version (1 byte) || algorithm class (1 byte, always 0) ||
    /// algorithm ID (4 bytes, big-endian) || usage (2 bytes, big-endian).
    fn domain_separation_tag(&self, usage: u16) -> [u8; 8] {
        let [id0, id1, id2, id3] = self.algorithm_id().to_be_bytes();
        let [usage_hi, usage_lo] = usage.to_be_bytes();
        [VERSION, 0, id0, id1, id2, id3, usage_hi, usage_lo]
    }
}

/// The Client's part in a VDAF execution: splitting a measurement into shares.
pub trait Client<const NONCE_SIZE: usize>: Vdaf {
    /// Splits `measurement` into a public share and a list of input shares, one per Aggregator,
    /// using `nonce` as the report nonce.
    ///
    /// Corresponds to `Vdaf::shard` in [VDAF].
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.1
    fn shard(
        &self,
        measurement: &Self::Measurement,
        nonce: &[u8; NONCE_SIZE],
    ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>;
}

/// The Aggregator's role in the execution of a VDAF.
pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vdaf {
    /// State of the Aggregator during the Prepare process.
    type PrepareState: Clone + Debug + PartialEq + Eq;

    /// The type of messages sent by each aggregator at each round of the Prepare process.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;

    /// Result of preprocessing a round of prepare shares. This is used by all aggregators as an
    /// input to the next round of the Prepare process.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareMessage: Clone
        + Debug
        + PartialEq
        + Eq
        + ParameterizedDecode<Self::PrepareState>
        + Encode;

    /// Begins the Prepare process with the other Aggregators.
    ///
    /// Returns this Aggregator's initial [`Self::PrepareState`] together with its first-round
    /// [`Self::PrepareShare`]. The prepare shares of all Aggregators are combined with
    /// [`Self::prepare_shares_to_prepare_message`], and the resulting [`Self::PrepareMessage`] is
    /// passed, along with the returned state, to [`Self::prepare_next`].
    ///
    /// Implements `Vdaf.prep_init` from [VDAF].
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2
    fn prepare_init(
        &self,
        verify_key: &[u8; VERIFY_KEY_SIZE],
        agg_id: usize,
        agg_param: &Self::AggregationParam,
        nonce: &[u8; NONCE_SIZE],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
    ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;

    /// Combines one round of prepare shares, one from each Aggregator, into the single
    /// [`Self::PrepareMessage`] that every Aggregator passes to [`Self::prepare_next`].
    ///
    /// Implements `Vdaf.prep_shares_to_prep` from [VDAF].
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2
    fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        inputs: M,
    ) -> Result<Self::PrepareMessage, VdafError>;

    /// Compute the next state transition from the current state and the previous round's
    /// [`Self::PrepareMessage`].
    ///
    /// If this returns [`PrepareTransition::Continue`], the returned [`Self::PrepareShare`] must
    /// be combined with the other Aggregators' prepare shares from this round using
    /// [`Self::prepare_shares_to_prepare_message`]. The resulting message is then passed, together
    /// with the returned state, into another call to this method. This repeats until this method
    /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
    /// aggregated. If the method returns an error, the aggregator should consider its input share
    /// invalid and not attempt to process it any further.
    ///
    /// Implements `Vdaf.prep_next` from [VDAF].
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2
    fn prepare_next(
        &self,
        state: Self::PrepareState,
        input: Self::PrepareMessage,
    ) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError>;

    /// Aggregates a sequence of output shares into an aggregate share.
    fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        output_shares: M,
    ) -> Result<Self::AggregateShare, VdafError>;
}

/// An [`Aggregator`] that supports differential privacy by adding noise to its own aggregate
/// share (Aggregator-side noise addition).
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub trait AggregatorWithNoise<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    DPStrategy: DifferentialPrivacyStrategy,
>: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>
{
    /// Adds noise, as dictated by `dp_strategy`, to `agg_share` in place. The result is that the
    /// aggregate result is differentially private as long as at least one Aggregator is honest.
    ///
    /// The aggregation parameter and the number of measurements are supplied as well.
    fn add_noise_to_agg_share(
        &self,
        dp_strategy: &DPStrategy,
        agg_param: &Self::AggregationParam,
        agg_share: &mut Self::AggregateShare,
        num_measurements: usize,
    ) -> Result<(), VdafError>;
}

/// The Collector's part in a VDAF execution: turning aggregate shares into the final result.
pub trait Collector: Vdaf {
    /// Combines the Aggregators' aggregate shares into the aggregate result.
    ///
    /// `num_measurements` is the number of measurements that were aggregated.
    fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        agg_shares: M,
        num_measurements: usize,
    ) -> Result<Self::AggregateResult, VdafError>;
}

/// The outcome of one step of an Aggregator's Prepare process.
#[derive(Clone, Debug)]
pub enum PrepareTransition<
    V: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
> {
    /// Preparation is not finished: carries the updated state and this round's prepare share.
    Continue(V::PrepareState, V::PrepareShare),

    /// Preparation is finished: carries the recovered output share.
    Finish(V::OutputShare),
}

/// An aggregate share, produced by aggregating output shares together, that can be merged with
/// other aggregate shares of the same type.
///
/// A single output share can be converted into an aggregate share via `From<Self::OutputShare>`.
pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
    /// Type of output shares that can be accumulated into an aggregate share.
    type OutputShare;

    /// Update this aggregate share by merging `agg_share` into it.
    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>;

    /// Update this aggregate share by adding `output_share` to it.
    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>;
}

/// An output share represented as a plain vector of field elements.
#[derive(Clone)]
pub struct OutputShare<F>(Vec<F>);

impl<F: ConstantTimeEq> PartialEq for OutputShare<F> {
    /// Equality is defined through [`ConstantTimeEq::ct_eq`].
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

impl<F: ConstantTimeEq> Eq for OutputShare<F> {}

impl<F: ConstantTimeEq> ConstantTimeEq for OutputShare<F> {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.as_slice().ct_eq(other.0.as_slice())
    }
}

impl<F> AsRef<[F]> for OutputShare<F> {
    fn as_ref(&self) -> &[F] {
        self.0.as_slice()
    }
}

impl<F> From<Vec<F>> for OutputShare<F> {
    fn from(elements: Vec<F>) -> Self {
        OutputShare(elements)
    }
}

impl<F: FieldElement> Encode for OutputShare<F> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_fieldvec(&self.0, bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        // Every field element encodes to `F::ENCODED_SIZE` bytes.
        Some(self.0.len() * F::ENCODED_SIZE)
    }
}

impl<F> Debug for OutputShare<F> {
    /// Writes only the type name; the field elements are not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("OutputShare")
    }
}

/// An aggregate share comprised of a vector of field elements.
///
/// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field
/// elements, and output shares need no special transformation to be merged into an aggregate share.
#[derive(Clone, Debug, Serialize, Deserialize)]

pub struct AggregateShare<F>(Vec<F>);

impl<F> From<Vec<F>> for AggregateShare<F> {
    fn from(other: Vec<F>) -> Self {
        Self(other)
    }
}

impl<F: ConstantTimeEq> PartialEq for AggregateShare<F> {
    fn eq(&self, other: &Self) -> bool {
        self.ct_eq(other).into()
    }
}

impl<F: ConstantTimeEq> Eq for AggregateShare<F> {}

impl<F: ConstantTimeEq> ConstantTimeEq for AggregateShare<F> {
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        self.0.ct_eq(&other.0)
    }
}

impl<F: FieldElement> AsRef<[F]> for AggregateShare<F> {
    fn as_ref(&self) -> &[F] {
        &self.0
    }
}

impl<F> From<OutputShare<F>> for AggregateShare<F> {
    fn from(other: OutputShare<F>) -> Self {
        Self(other.0)
    }
}

impl<F: FieldElement> Aggregatable for AggregateShare<F> {
    type OutputShare = OutputShare<F>;

    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
        self.sum(agg_share.as_ref())
    }

    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> {
        // For Poplar1, Prio2, and Prio3, no conversion is needed between output shares and
        // aggregate shares.
        self.sum(output_share.as_ref())
    }
}

impl<F: FieldElement> AggregateShare<F> {
    fn sum(&mut self, other: &[F]) -> Result<(), VdafError> {
        merge_vector(&mut self.0, other).map_err(Into::into)
    }
}

impl<F: FieldElement> Encode for AggregateShare<F> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_fieldvec(&self.0, bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(F::ENCODED_SIZE * self.0.len())
    }
}

/// Utilities for testing VDAFs.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_utils {
    use super::{Aggregatable, Aggregator, Client, Collector, PrepareTransition, VdafError};
    use crate::codec::{Encode, ParameterizedDecode};
    use rand::prelude::*;

    /// Execute the VDAF end-to-end and return the aggregate result.
    ///
    /// Each measurement is sharded under a freshly generated random nonce, and the resulting
    /// reports are processed by [`run_vdaf_sharded`].
    ///
    /// # Errors
    ///
    /// Propagates sharding errors and any error returned by [`run_vdaf_sharded`], including the
    /// error for an empty `measurements`.
    pub fn run_vdaf<V, M, const SEED_SIZE: usize>(
        vdaf: &V,
        agg_param: &V::AggregationParam,
        measurements: M,
    ) -> Result<V::AggregateResult, VdafError>
    where
        V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
        M: IntoIterator<Item = V::Measurement>,
    {
        let mut sharded_measurements = Vec::new();
        for measurement in measurements.into_iter() {
            let nonce = random();
            let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?;

            sharded_measurements.push((public_share, nonce, input_shares));
        }

        run_vdaf_sharded(vdaf, agg_param, sharded_measurements)
    }

    /// Execute the VDAF on sharded measurements and return the aggregate result.
    ///
    /// A random verify key is generated, and each report is prepared with [`run_vdaf_prepare`].
    /// The recovered output shares are merged into one aggregate share per Aggregator before
    /// unsharding. Output shares and aggregate shares are round-tripped through their encodings
    /// along the way.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if `sharded_measurements` is empty, since there are
    /// then no aggregate shares to unshard. Otherwise it propagates errors from preparation,
    /// merging, and unsharding.
    ///
    /// # Panics
    ///
    /// Panics if an encoding round trip fails or changes the encoded bytes.
    pub fn run_vdaf_sharded<V, M, I, const SEED_SIZE: usize>(
        vdaf: &V,
        agg_param: &V::AggregationParam,
        sharded_measurements: M,
    ) -> Result<V::AggregateResult, VdafError>
    where
        V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
        M: IntoIterator<Item = (V::PublicShare, [u8; 16], I)>,
        I: IntoIterator<Item = V::InputShare>,
    {
        let mut rng = thread_rng();
        let mut verify_key = [0; SEED_SIZE];
        rng.fill(&mut verify_key[..]);

        let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
        let mut num_measurements: usize = 0;
        for (public_share, nonce, input_shares) in sharded_measurements.into_iter() {
            num_measurements += 1;
            let out_shares = run_vdaf_prepare(
                vdaf,
                &verify_key,
                agg_param,
                &nonce,
                public_share,
                input_shares,
            )?;
            for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
                // Check serialization of output shares
                let encoded_out_share = out_share.get_encoded().unwrap();
                let round_trip_out_share =
                    V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share)
                        .unwrap();
                assert_eq!(
                    round_trip_out_share.get_encoded().unwrap(),
                    encoded_out_share
                );

                let this_agg_share = V::AggregateShare::from(out_share);
                if let Some(ref mut inner) = agg_share {
                    inner.merge(&this_agg_share)?;
                } else {
                    *agg_share = Some(this_agg_share);
                }
            }
        }

        // Without measurements no output share was ever merged: every entry of `agg_shares` is
        // still `None` and there is nothing to unshard.
        if num_measurements == 0 {
            return Err(VdafError::Uncategorized(
                "run_vdaf_sharded: at least one measurement is required".to_string(),
            ));
        }

        let agg_shares: Vec<V::AggregateShare> = agg_shares
            .into_iter()
            .map(|agg_share| agg_share.expect("an Aggregator produced no output share"))
            .collect();

        for agg_share in agg_shares.iter() {
            // Check serialization of aggregate shares
            let encoded_agg_share = agg_share.get_encoded().unwrap();
            let round_trip_agg_share =
                V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share)
                    .unwrap();
            assert_eq!(
                round_trip_agg_share.get_encoded().unwrap(),
                encoded_agg_share
            );
        }

        vdaf.unshard(agg_param, agg_shares, num_measurements)
    }

    /// Execute VDAF preparation for a single report and return the recovered output shares.
    ///
    /// The public share, input shares, prepare shares, and prepare messages are all passed
    /// through their encodings to exercise the codec.
    ///
    /// # Panics
    ///
    /// Panics if any message fails to encode or decode, or if the Aggregators do not finish the
    /// prepare phase in the same round.
    pub fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>(
        vdaf: &V,
        verify_key: &[u8; SEED_SIZE],
        agg_param: &V::AggregationParam,
        nonce: &[u8; 16],
        public_share: V::PublicShare,
        input_shares: M,
    ) -> Result<Vec<V::OutputShare>, VdafError>
    where
        V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
        M: IntoIterator<Item = V::InputShare>,
    {
        let public_share =
            V::PublicShare::get_decoded_with_param(vdaf, &public_share.get_encoded().unwrap())
                .unwrap();
        let input_shares = input_shares
            .into_iter()
            .map(|input_share| input_share.get_encoded().unwrap());

        let mut states = Vec::new();
        let mut outbound = Vec::new();
        for (agg_id, input_share) in input_shares.enumerate() {
            let (state, msg) = vdaf.prepare_init(
                verify_key,
                agg_id,
                agg_param,
                nonce,
                &public_share,
                &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
                    .expect("failed to decode input share"),
            )?;
            states.push(state);
            outbound.push(msg.get_encoded().unwrap());
        }

        let mut inbound =
            combine_prep_shares::<V, SEED_SIZE>(vdaf, agg_param, &states, &outbound)?;

        let mut out_shares = Vec::new();
        loop {
            let mut outbound = Vec::new();
            for state in states.iter_mut() {
                match vdaf.prepare_next(
                    state.clone(),
                    V::PrepareMessage::get_decoded_with_param(state, &inbound)
                        .expect("failed to decode prep message"),
                )? {
                    PrepareTransition::Continue(new_state, msg) => {
                        outbound.push(msg.get_encoded().unwrap());
                        *state = new_state
                    }
                    PrepareTransition::Finish(out_share) => {
                        out_shares.push(out_share);
                    }
                }
            }

            // `is_empty` is checked first so that a VDAF with zero Aggregators terminates
            // instead of looping forever (0 == num_aggregators() would otherwise always hold).
            if outbound.is_empty() {
                // Each Aggregator recovered an output share.
                break;
            } else if outbound.len() == vdaf.num_aggregators() {
                // Another round is required before output shares are computed.
                inbound =
                    combine_prep_shares::<V, SEED_SIZE>(vdaf, agg_param, &states, &outbound)?;
            } else {
                panic!("Aggregators did not finish the prepare phase at the same time");
            }
        }

        Ok(out_shares)
    }

    /// Decode each encoded prepare share (using the first Aggregator's state as the decoding
    /// parameter). It then combines the shares with
    /// [`Aggregator::prepare_shares_to_prepare_message`] and returns the encoded prepare message.
    fn combine_prep_shares<V, const SEED_SIZE: usize>(
        vdaf: &V,
        agg_param: &V::AggregationParam,
        states: &[V::PrepareState],
        encoded_shares: &[Vec<u8>],
    ) -> Result<Vec<u8>, VdafError>
    where
        V: Aggregator<SEED_SIZE, 16>,
    {
        let prep_msg = vdaf.prepare_shares_to_prepare_message(
            agg_param,
            encoded_shares.iter().map(|encoded| {
                V::PrepareShare::get_decoded_with_param(&states[0], encoded)
                    .expect("failed to decode prep share")
            }),
        )?;
        Ok(prep_msg.get_encoded().unwrap())
    }
}

#[cfg(test)]
fn fieldvec_roundtrip_test<F, V, T>(vdaf: &V, agg_param: &V::AggregationParam, length: usize)
where
    F: FieldElement,
    V: Vdaf,
    T: Encode,
    for<'a> T: ParameterizedDecode<(&'a V, &'a V::AggregationParam)>,
{
    // Encode `length` random field elements into a byte string.
    let elements: Vec<F> = crate::field::random_vector(length).unwrap();
    let original_bytes = {
        let mut buf = Vec::with_capacity(elements.len() * F::ENCODED_SIZE);
        encode_fieldvec(&elements, &mut buf).unwrap();
        buf
    };

    // Decoding those bytes as `T` and re-encoding must reproduce them exactly.
    let decoded = T::get_decoded_with_param(&(vdaf, agg_param), &original_bytes).unwrap();
    let reencoded = decoded.get_encoded().unwrap();
    assert_eq!(reencoded, original_bytes);
}

#[cfg(test)]
fn equality_comparison_test<T>(values: &[T])
where
    T: Debug + PartialEq,
{
    // Every entry of `values` must be distinct from every other entry. Each pair is checked in
    // both orders, (i, j) and (j, i), to build confidence that equality is symmetric.
    for (i, lhs) in values.iter().enumerate() {
        for (j, rhs) in values.iter().enumerate() {
            if i != j {
                assert_ne!(
                    lhs, rhs,
                    "Expected elements at indices {i} & {j} to not be equal, but they were"
                );
                continue;
            }
            assert!(std::ptr::eq(lhs, rhs)); // sanity
            assert_eq!(
                lhs, rhs,
                "Expected element at index {i} to be equal to itself, but it was not"
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::vdaf::{equality_comparison_test, xof::Seed, AggregateShare, OutputShare, Share};

    #[test]
    fn share_equality_test() {
        let shares = [
            Share::Leader(vec![1, 2, 3]),
            Share::Leader(vec![3, 2, 1]),
            Share::Helper(Seed([1, 2, 3])),
            Share::Helper(Seed([3, 2, 1])),
        ];
        equality_comparison_test(&shares);
    }

    #[test]
    fn output_share_equality_test() {
        let shares = [OutputShare(vec![1, 2, 3]), OutputShare(vec![3, 2, 1])];
        equality_comparison_test(&shares);
    }

    #[test]
    fn aggregate_share_equality_test() {
        let shares = [AggregateShare(vec![1, 2, 3]), AggregateShare(vec![3, 2, 1])];
        equality_comparison_test(&shares);
    }
}

#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod dummy;
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
)]
pub mod poplar1;
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "crypto-dependencies", feature = "experimental")))
)]
pub mod prio2;
pub mod prio3;
#[cfg(any(test, feature = "test-util"))]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod prio3_test;
pub mod xof;