summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/tests/discrete_gauss.rs
blob: 5b3ef4c5b3766f3e1ddb7a0fc0b4477511e8549a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// SPDX-License-Identifier: MPL-2.0

use num_bigint::{BigInt, BigUint};
use num_rational::Ratio;
use num_traits::FromPrimitive;
use prio::dp::distributions::DiscreteGaussian;
use prio::vdaf::xof::SeedStreamSha3;
use rand::distributions::Distribution;
use rand::SeedableRng;
use serde::Deserialize;

/// A test vector of discrete Gaussian samples, produced by the Python reference
/// implementation for [[CKS20]]. The script used to generate the test vector can
/// be found in this gist:
/// <https://gist.github.com/ooovi/529c00fc8a7eafd068cd076b78fc424e>
///
/// The Python reference implementation is here:
/// <https://github.com/IBM/discrete-gaussian-differential-privacy>
///
/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
#[derive(Debug, Eq, PartialEq, Deserialize)]
pub struct DiscreteGaussTestVector {
    /// 16-byte seed for the `SeedStreamSha3` sampler RNG; hex-encoded in the JSON file.
    #[serde(with = "hex")]
    seed: [u8; 16],
    /// Numerator of the standard deviation passed to `DiscreteGaussian::new`.
    std_num: u128,
    /// Denominator of the standard deviation passed to `DiscreteGaussian::new`.
    std_denom: u128,
    /// Expected samples, in the order the sampler must produce them.
    samples: Vec<i128>,
}

/// Checks that `DiscreteGaussian` reproduces, sample for sample, the output of
/// the reference implementation for each recorded test vector.
#[test]
fn discrete_gauss_reference() {
    // Pairs a vector's file stem (used in failure messages) with the JSON
    // contents of `test_vectors/<stem>.json`, embedded at compile time.
    macro_rules! test_vector_source {
        ($stem:literal) => {
            ($stem, include_str!(concat!("test_vectors/", $stem, ".json")))
        };
    }

    let sources = [
        test_vector_source!("discrete_gauss_3"),
        test_vector_source!("discrete_gauss_9"),
        test_vector_source!("discrete_gauss_100"),
        test_vector_source!("discrete_gauss_41293847"),
        test_vector_source!("discrete_gauss_9999999999999999999999"),
    ];

    for &(stem, json) in sources.iter() {
        let test_vector = serde_json::from_str::<DiscreteGaussTestVector>(json)
            .unwrap_or_else(|e| panic!("failed to parse test vector {}: {}", stem, e));

        let sampler = DiscreteGaussian::new(Ratio::<BigUint>::new(
            test_vector.std_num.into(),
            test_vector.std_denom.into(),
        ))
        .unwrap_or_else(|e| panic!("invalid standard deviation in {}: {:?}", stem, e));

        // Draw exactly as many samples as were recorded, from an RNG seeded with
        // the vector's seed; they must match the recorded samples in order.
        let mut rng = SeedStreamSha3::from_seed(test_vector.seed);
        let samples: Vec<BigInt> = (0..test_vector.samples.len())
            .map(|_| sampler.sample(&mut rng))
            .collect();

        let expected: Vec<BigInt> = test_vector
            .samples
            .iter()
            .map(|&s| BigInt::from_i128(s).expect("every i128 is representable as a BigInt"))
            .collect();

        assert_eq!(
            samples, expected,
            "samples diverge from reference vector {} (std = {}/{})",
            stem, test_vector.std_num, test_vector.std_denom
        );
    }
}