summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf.rs
blob: f75a2c488bc9371bed25ff2e77ed77c2fe3c138a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
// SPDX-License-Identifier: MPL-2.0

//! Verifiable Distributed Aggregation Functions (VDAFs) as described in
//! [[draft-irtf-cfrg-vdaf-03]].
//!
//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/

use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
use crate::field::{FieldElement, FieldError};
use crate::flp::FlpError;
use crate::prng::PrngError;
use crate::vdaf::prg::Seed;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt::Debug;
use std::io::Cursor;

/// Document-version component of every domain-separation tag. It binds all VDAF operations to
/// the draft they implement and is bumped whenever a new draft makes breaking changes.
const VERSION: &[u8] = b"vdaf-03";
/// Total length of a domain-separation tag: the version string followed by the 4-byte (`u32`)
/// algorithm identifier.
const DST_LEN: usize = VERSION.len() + std::mem::size_of::<u32>();

/// Errors emitted by this module.
///
/// Every variant other than `Uncategorized` is generated with `#[from]`, so the wrapped error
/// types convert into a `VdafError` automatically (e.g. through `?`).
#[derive(Debug, thiserror::Error)]
pub enum VdafError {
    /// An error that fits no other category; the string describes what went wrong.
    #[error("vdaf error: {0}")]
    Uncategorized(String),

    /// Field error.
    #[error("field error: {0}")]
    Field(#[from] FieldError),

    /// An error occurred while parsing a message (wraps a `std::io::Error`).
    #[error("io error: {0}")]
    IoError(#[from] std::io::Error),

    /// FLP error.
    #[error("flp error: {0}")]
    Flp(#[from] FlpError),

    /// PRNG error.
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),

    /// Failure when calling `getrandom()`.
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),
}

/// An additive share of a vector of field elements.
///
/// A share is sent either in full, as the vector of `F` elements itself, or in compressed form as
/// an `L`-byte [`Seed`]. NOTE(review): the seed is presumably expanded back into the vector with
/// the PRG from the `prg` module; that expansion is not shown in this file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Share<F, const L: usize> {
    /// An uncompressed share, typically sent to the leader.
    Leader(Vec<F>),

    /// A compressed share, typically sent to the helper.
    Helper(Seed<L>),
}

impl<F: Clone, const L: usize> Share<F, L> {
    /// Returns a copy of this share in which a Leader share keeps only its first `len` field
    /// elements. A Helper share is a seed with nothing to truncate, so it is simply cloned.
    ///
    /// # Panics
    ///
    /// Panics if this is a Leader share holding fewer than `len` elements.
    #[cfg(feature = "prio2")]
    pub(crate) fn truncated(&self, len: usize) -> Self {
        match self {
            Share::Helper(seed) => Share::Helper(seed.clone()),
            Share::Leader(data) => {
                let prefix = &data[..len];
                Share::Leader(prefix.to_vec())
            }
        }
    }
}

/// Parameters needed to decode a [`Share`]: which variant to expect and, for a Leader share, how
/// many field elements it holds.
///
/// `L` is not used by the variants themselves; it ties the parameter to the seed length of the
/// `Share<F, L>` being decoded.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ShareDecodingParameter<const L: usize> {
    /// Decode a [`Share::Leader`] consisting of exactly this many field elements.
    Leader(usize),
    /// Decode a [`Share::Helper`], i.e. a single `Seed<L>`.
    Helper,
}

impl<F: FieldElement, const L: usize> ParameterizedDecode<ShareDecodingParameter<L>>
    for Share<F, L>
{
    /// Decodes a share whose shape is dictated by `decoding_parameter`.
    ///
    /// A Leader share is read as exactly the requested number of consecutive field elements (no
    /// length prefix is read); a Helper share is read as a single seed. Decoding stops at the
    /// first element that fails to decode, and that error is returned.
    fn decode_with_param(
        decoding_parameter: &ShareDecodingParameter<L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let share = match decoding_parameter {
            ShareDecodingParameter::Helper => Self::Helper(Seed::decode(bytes)?),
            ShareDecodingParameter::Leader(share_length) => {
                let elements = (0..*share_length)
                    .map(|_| F::decode(bytes))
                    .collect::<Result<Vec<_>, _>>()?;
                Self::Leader(elements)
            }
        };
        Ok(share)
    }
}

impl<F: FieldElement, const L: usize> Encode for Share<F, L> {
    /// Appends the encoding of this share to `bytes`.
    ///
    /// A Leader share is written as its field elements back to back, with no length prefix (the
    /// decoder is told the length out of band). A Helper share is written as its seed.
    fn encode(&self, bytes: &mut Vec<u8>) {
        match self {
            Share::Helper(share_seed) => share_seed.encode(bytes),
            Share::Leader(share_data) => share_data
                .iter()
                .for_each(|element| element.encode(bytes)),
        }
    }
}

/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`],
/// and [`Collector`], which define the roles of the various parties involved in the execution of
/// the VDAF.
///
/// The `where` clause requires that an aggregate share can be serialized to bytes by reference;
/// together with the `TryFrom<&[u8]>` bound on [`Vdaf::AggregateShare`], this lets aggregate
/// shares round-trip through a byte encoding.
// TODO(brandon): once GATs are stabilized [https://github.com/rust-lang/rust/issues/44265],
// state the "&AggregateShare must implement Into<Vec<u8>>" constraint in terms of a where clause
// on the associated type instead of a where clause on the trait.
pub trait Vdaf: Clone + Debug
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// Algorithm identifier for this VDAF.
    // NOTE(review): presumably this is the 4-byte algorithm ID counted in `DST_LEN` — confirm.
    const ID: u32;

    /// The type of Client measurement to be aggregated.
    type Measurement: Clone + Debug;

    /// The aggregate result of the VDAF execution.
    type AggregateResult: Clone + Debug;

    /// The aggregation parameter, used by the Aggregators to map their input shares to output
    /// shares.
    type AggregationParam: Clone + Debug + Decode + Encode;

    /// A public share sent by a Client. Decoding is parameterized by the VDAF instance itself.
    type PublicShare: Clone + Debug + for<'a> ParameterizedDecode<&'a Self> + Encode;

    /// An input share sent by a Client. Decoding is parameterized by the VDAF instance and the
    /// index of the Aggregator the share belongs to.
    type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode;

    /// An output share recovered from an input share by an Aggregator.
    type OutputShare: Clone + Debug;

    /// An Aggregator's share of the aggregate result. It can be built from, and accumulate,
    /// [`Vdaf::OutputShare`]s, and can be parsed from a byte slice.
    type AggregateShare: Aggregatable<OutputShare = Self::OutputShare> + for<'a> TryFrom<&'a [u8]>;

    /// The number of Aggregators. The Client generates as many input shares as there are
    /// Aggregators.
    fn num_aggregators(&self) -> usize;
}

/// The Client's role in the execution of a VDAF.
pub trait Client: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// Shards a measurement into a public share and a sequence of input shares, one for each
    /// Aggregator.
    ///
    /// The public share is meant to be given to every Aggregator. The input share at index `i` is
    /// meant for the Aggregator with ID `i`.
    fn shard(
        &self,
        measurement: &Self::Measurement,
    ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>;
}

/// The Aggregator's role in the execution of a VDAF.
///
/// `L` is the length in bytes of the verification key passed to [`Aggregator::prepare_init`].
pub trait Aggregator<const L: usize>: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// State of the Aggregator during the Prepare process.
    type PrepareState: Clone + Debug;

    /// The type of messages broadcast by each aggregator at each round of the Prepare Process.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;

    /// Result of preprocessing a round of preparation shares.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareMessage: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;

    /// Begins the Prepare process with the other Aggregators.
    ///
    /// Returns this Aggregator's initial [`Self::PrepareState`] together with its first-round
    /// [`Self::PrepareShare`]. The first-round prepare shares of all Aggregators are combined by
    /// [`Aggregator::prepare_preprocess`]. The resulting [`Self::PrepareMessage`] is then passed,
    /// along with the returned state, to [`Aggregator::prepare_step`].
    ///
    /// `agg_id` is this Aggregator's index, i.e. the position of `input_share` among the shares
    /// produced by [`Client::shard`]. The nonce must be unique for each measurement.
    fn prepare_init(
        &self,
        verify_key: &[u8; L],
        agg_id: usize,
        agg_param: &Self::AggregationParam,
        nonce: &[u8],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
    ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;

    /// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`].
    fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>(
        &self,
        inputs: M,
    ) -> Result<Self::PrepareMessage, VdafError>;

    /// Compute the next state transition from the current state and the previous round's prepare
    /// message.
    ///
    /// If this returns [`PrepareTransition::Continue`], the returned [`Self::PrepareShare`] should
    /// be combined with the other Aggregators' `PrepareShare`s from this round via
    /// [`Aggregator::prepare_preprocess`]. The resulting message, together with the returned
    /// state, is then passed into another call to this method. This continues until this method
    /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
    /// aggregated. If the method returns an error, the aggregator should consider its input share
    /// invalid and not attempt to process it any further.
    fn prepare_step(
        &self,
        state: Self::PrepareState,
        input: Self::PrepareMessage,
    ) -> Result<PrepareTransition<Self, L>, VdafError>;

    /// Aggregates a sequence of output shares into an aggregate share.
    fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        output_shares: M,
    ) -> Result<Self::AggregateShare, VdafError>;
}

/// The Collector's role in the execution of a VDAF.
pub trait Collector: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// Combines aggregate shares into the aggregate result.
    ///
    /// `agg_shares` holds one aggregate share per Aggregator. `num_measurements` is the number of
    /// measurements that were aggregated into those shares.
    fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        agg_shares: M,
        num_measurements: usize,
    ) -> Result<Self::AggregateResult, VdafError>;
}

/// A state transition of an Aggregator during the Prepare process, as returned by
/// [`Aggregator::prepare_step`].
#[derive(Debug)]
pub enum PrepareTransition<V: Aggregator<L>, const L: usize>
where
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
{
    /// Continue processing: carries the Aggregator's next state and its prepare share for the
    /// next round.
    Continue(V::PrepareState, V::PrepareShare),

    /// Finish processing and return the output share.
    Finish(V::OutputShare),
}

/// An aggregate share resulting from aggregating output shares together that
/// can be merged with aggregate shares of the same type.
///
/// An aggregate share can also be created directly from a single output share (the
/// `From<Self::OutputShare>` supertrait).
pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
    /// Type of output shares that can be accumulated into an aggregate share.
    type OutputShare;

    /// Update an aggregate share by merging it with another (`agg_share`).
    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>;

    /// Update an aggregate share by adding `output_share`.
    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>;
}

/// An output share: the vector of `F` elements an Aggregator recovers at the end of the Prepare
/// process.
///
/// Converts to and from bytes with `FieldElement::slice_into_byte_vec` and
/// `FieldElement::byte_slice_into_vec`, respectively.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OutputShare<F>(Vec<F>);

impl<F> AsRef<[F]> for OutputShare<F> {
    fn as_ref(&self) -> &[F] {
        self.0.as_slice()
    }
}

impl<F> From<Vec<F>> for OutputShare<F> {
    fn from(elements: Vec<F>) -> Self {
        OutputShare(elements)
    }
}

impl<F: FieldElement> TryFrom<&[u8]> for OutputShare<F> {
    type Error = FieldError;

    /// Parses an output share from the byte encoding of its field elements.
    fn try_from(encoded: &[u8]) -> Result<Self, Self::Error> {
        fieldvec_try_from_bytes::<F, Self>(encoded)
    }
}

impl<F: FieldElement> From<&OutputShare<F>> for Vec<u8> {
    /// Serializes the output share's field elements to bytes.
    fn from(output_share: &OutputShare<F>) -> Self {
        fieldvec_to_vec::<F, _>(output_share)
    }
}

/// An aggregate share suitable for VDAFs whose output shares and aggregate
/// shares are vectors of `F` elements, and an output share needs no special
/// transformation to be merged into an aggregate share.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregateShare<F>(Vec<F>);

impl<F> AsRef<[F]> for AggregateShare<F> {
    fn as_ref(&self) -> &[F] {
        &self.0
    }
}

impl<F> From<OutputShare<F>> for AggregateShare<F> {
    fn from(other: OutputShare<F>) -> Self {
        Self(other.0)
    }
}

impl<F> From<Vec<F>> for AggregateShare<F> {
    fn from(other: Vec<F>) -> Self {
        Self(other)
    }
}

impl<F: FieldElement> Aggregatable for AggregateShare<F> {
    type OutputShare = OutputShare<F>;

    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
        self.sum(agg_share.as_ref())
    }

    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> {
        // For prio3 and poplar1, no conversion is needed between output shares and aggregation
        // shares.
        self.sum(output_share.as_ref())
    }
}

impl<F: FieldElement> AggregateShare<F> {
    /// Adds `other` element-wise into this aggregate share.
    ///
    /// # Errors
    ///
    /// Returns [`VdafError::Uncategorized`] if `other` does not have the same number of elements
    /// as `self`. In that case `self` is left unmodified.
    fn sum(&mut self, other: &[F]) -> Result<(), VdafError> {
        if self.0.len() != other.len() {
            // The closing parenthesis was previously missing from this message.
            return Err(VdafError::Uncategorized(format!(
                "cannot sum shares of different lengths (left = {}, right = {})",
                self.0.len(),
                other.len()
            )));
        }

        for (x, y) in self.0.iter_mut().zip(other) {
            *x += *y;
        }

        Ok(())
    }
}

impl<F: FieldElement> TryFrom<&[u8]> for AggregateShare<F> {
    type Error = FieldError;

    /// Parses an aggregate share from the byte encoding of its field elements.
    fn try_from(encoded: &[u8]) -> Result<Self, Self::Error> {
        fieldvec_try_from_bytes::<F, Self>(encoded)
    }
}

impl<F: FieldElement> From<&AggregateShare<F>> for Vec<u8> {
    /// Serializes the aggregate share's field elements to bytes.
    fn from(aggregate_share: &AggregateShare<F>) -> Self {
        fieldvec_to_vec::<F, _>(aggregate_share)
    }
}

/// Decodes `bytes` into field elements with `F::byte_slice_into_vec` and wraps the resulting
/// vector in any type `T` that can be built from a `Vec<F>`.
///
/// # Errors
///
/// Returns the [`FieldError`] reported by `F::byte_slice_into_vec`, unchanged.
#[inline(always)]
fn fieldvec_try_from_bytes<F: FieldElement, T: From<Vec<F>>>(
    bytes: &[u8],
) -> Result<T, FieldError> {
    let elements = F::byte_slice_into_vec(bytes)?;
    Ok(T::from(elements))
}

/// Serializes any value that can be viewed as a slice of field elements into a byte vector with
/// `F::slice_into_byte_vec`.
#[inline(always)]
fn fieldvec_to_vec<F: FieldElement, T: AsRef<[F]>>(val: T) -> Vec<u8> {
    let elements: &[F] = val.as_ref();
    F::slice_into_byte_vec(elements)
}

/// Runs the VDAF end to end over `measurements` (test helper).
///
/// Each measurement is sharded and run through the Prepare process via [`run_vdaf_prepare`].
/// Each Aggregator's output shares are accumulated into its aggregate share. Finally, the
/// aggregate shares are unsharded into the aggregate result. A fresh random verification key is
/// drawn for each call, but the same nonce is reused for every measurement (see the note below).
///
/// # Errors
///
/// Propagates any error from sharding, preparation, accumulation or unsharding. Returns
/// [`VdafError::Uncategorized`] if some Aggregator ends up without an aggregate share, e.g.
/// because `measurements` was empty. This previously panicked on a bare `unwrap()`.
#[cfg(test)]
pub(crate) fn run_vdaf<V, M, const L: usize>(
    vdaf: &V,
    agg_param: &V::AggregationParam,
    measurements: M,
) -> Result<V::AggregateResult, VdafError>
where
    V: Client + Aggregator<L> + Collector,
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
    M: IntoIterator<Item = V::Measurement>,
{
    use rand::prelude::*;
    let mut verify_key = [0; L];
    thread_rng().fill(&mut verify_key[..]);

    // NOTE Here we use the same nonce for each measurement for testing purposes. However, this is
    // not secure. In use, the Aggregators MUST ensure that nonces are unique for each measurement.
    let nonce = b"this is a nonce";

    // One (initially absent) aggregate share per Aggregator, indexed by Aggregator ID.
    let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
    let mut num_measurements: usize = 0;
    for measurement in measurements.into_iter() {
        num_measurements += 1;
        let (public_share, input_shares) = vdaf.shard(&measurement)?;
        let out_shares = run_vdaf_prepare(
            vdaf,
            &verify_key,
            agg_param,
            nonce,
            public_share,
            input_shares,
        )?;
        for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
            if let Some(ref mut inner) = agg_share {
                // Fold the output share in directly instead of converting it to an aggregate
                // share first and merging.
                inner.accumulate(&out_share)?;
            } else {
                *agg_share = Some(out_share.into());
            }
        }
    }

    let agg_shares = agg_shares
        .into_iter()
        .collect::<Option<Vec<_>>>()
        .ok_or_else(|| {
            VdafError::Uncategorized(
                "no aggregate share was computed for some aggregator (no measurements?)"
                    .to_string(),
            )
        })?;
    vdaf.unshard(agg_param, agg_shares, num_measurements)
}

/// Runs the full Prepare process for one report (test helper), playing every Aggregator in turn.
///
/// Input shares, prepare shares and prepare messages are all round-tripped through their byte
/// encodings before use, so the codec implementations are exercised along with the VDAF. Returns
/// the output shares in Aggregator order.
///
/// # Errors
///
/// Propagates any error returned by `prepare_init`, `prepare_preprocess` or `prepare_step`.
///
/// # Panics
///
/// Panics if an encoded message fails to decode, or if some Aggregators finish the Prepare
/// process in a round in which others continue.
#[cfg(test)]
pub(crate) fn run_vdaf_prepare<V, M, const L: usize>(
    vdaf: &V,
    verify_key: &[u8; L],
    agg_param: &V::AggregationParam,
    nonce: &[u8],
    public_share: V::PublicShare,
    input_shares: M,
) -> Result<Vec<V::OutputShare>, VdafError>
where
    V: Client + Aggregator<L> + Collector,
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
    M: IntoIterator<Item = V::InputShare>,
{
    // Decodes one round of encoded prepare shares and combines them into a prepare message,
    // returning its encoding. The first Aggregator's state is the decoding parameter for every
    // share; it is only indexed when there is at least one share to decode.
    let combine_round = |states: &[V::PrepareState],
                         encoded_shares: &[Vec<u8>]|
     -> Result<Vec<u8>, VdafError> {
        let decoded_shares = encoded_shares.iter().map(|encoded| {
            V::PrepareShare::get_decoded_with_param(&states[0], encoded)
                .expect("failed to decode prep share")
        });
        Ok(vdaf.prepare_preprocess(decoded_shares)?.get_encoded())
    };

    let mut states = Vec::new();
    let mut first_round = Vec::new();
    for (agg_id, input_share) in input_shares.into_iter().enumerate() {
        let encoded_input_share = input_share.get_encoded();
        let decoded_input_share =
            V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &encoded_input_share)
                .expect("failed to decode input share");
        let (state, prep_share) = vdaf.prepare_init(
            verify_key,
            agg_id,
            agg_param,
            nonce,
            &public_share,
            &decoded_input_share,
        )?;
        states.push(state);
        first_round.push(prep_share.get_encoded());
    }

    let mut inbound = combine_round(states.as_slice(), first_round.as_slice())?;

    let mut out_shares = Vec::new();
    loop {
        let mut next_round = Vec::new();
        for state in states.iter_mut() {
            let prep_msg = V::PrepareMessage::get_decoded_with_param(state, &inbound)
                .expect("failed to decode prep message");
            match vdaf.prepare_step(state.clone(), prep_msg)? {
                PrepareTransition::Continue(new_state, prep_share) => {
                    next_round.push(prep_share.get_encoded());
                    *state = new_state;
                }
                PrepareTransition::Finish(out_share) => out_shares.push(out_share),
            }
        }

        match next_round.len() {
            // Every Aggregator continued: another round is required before output shares are
            // computed.
            n if n == vdaf.num_aggregators() => {
                inbound = combine_round(states.as_slice(), next_round.as_slice())?;
            }
            // Every Aggregator recovered an output share.
            0 => break,
            _ => panic!("Aggregators did not finish the prepare phase at the same time"),
        }
    }

    Ok(out_shares)
}

#[cfg(test)]
mod tests {
    use super::{AggregateShare, OutputShare};
    use crate::field::{Field128, Field64, FieldElement};
    use std::convert::TryFrom;
    use std::fmt::Debug;

    /// Checks that a field-vector type survives a round trip through its byte encoding. The
    /// value is built from ten field elements: the powers `1, g, g^2, ...` of the generator `g`.
    fn fieldvec_roundtrip_test<F, T>()
    where
        F: FieldElement,
        for<'a> T: Debug + PartialEq + From<Vec<F>> + TryFrom<&'a [u8]>,
        for<'a> <T as TryFrom<&'a [u8]>>::Error: Debug,
        for<'a> Vec<u8>: From<&'a T>,
    {
        let generator = F::generator();
        let powers: Vec<F> =
            std::iter::successors(Some(F::one()), |&previous| Some(generator * previous))
                .take(10)
                .collect();
        let want_value = T::from(powers);

        let encoded: Vec<u8> = (&want_value).into();
        let got_value = T::try_from(encoded.as_slice()).unwrap();

        assert_eq!(want_value, got_value);
    }

    #[test]
    fn roundtrip_output_share() {
        fieldvec_roundtrip_test::<Field64, OutputShare<Field64>>();
        fieldvec_roundtrip_test::<Field128, OutputShare<Field128>>();
    }

    #[test]
    fn roundtrip_aggregate_share() {
        fieldvec_roundtrip_test::<Field64, AggregateShare<Field64>>();
        fieldvec_roundtrip_test::<Field128, AggregateShare<Field128>>();
    }
}

#[cfg(feature = "crypto-dependencies")]
pub mod poplar1;
pub mod prg;
#[cfg(feature = "prio2")]
pub mod prio2;
pub mod prio3;
#[cfg(test)]
mod prio3_test;