1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
// SPDX-License-Identifier: MPL-2.0
use num_bigint::{BigInt, BigUint};
use num_rational::Ratio;
use num_traits::FromPrimitive;
use prio::dp::distributions::DiscreteGaussian;
use prio::vdaf::xof::SeedStreamTurboShake128;
use rand::distributions::Distribution;
use rand::SeedableRng;
use serde::Deserialize;
/// A test vector of discrete Gaussian samples, produced by the Python reference
/// implementation for [[CKS20]]. The script used to generate the test vector can
/// be found in this gist:
/// <https://gist.github.com/divergentdave/94cab188e84a4764db6cdd1288e6ead3>
///
/// The Python reference implementation is here:
/// <https://github.com/IBM/discrete-gaussian-differential-privacy>
///
/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf
#[derive(Debug, Eq, PartialEq, Deserialize)]
pub struct DiscreteGaussTestVector {
    /// 16-byte seed for the sampler's random stream; stored as a hex string in the JSON file.
    #[serde(with = "hex")]
    seed: [u8; 16],
    /// Numerator of the sampler's standard-deviation ratio `std_num / std_denom`.
    std_num: u128,
    /// Denominator of the sampler's standard-deviation ratio `std_num / std_denom`.
    std_denom: u128,
    /// Samples the reference implementation drew, in draw order.
    samples: Vec<i128>,
}
/// Checks that `DiscreteGaussian` reproduces, draw for draw, the samples recorded
/// from the Python reference implementation when seeded identically.
#[test]
fn discrete_gauss_reference() {
    // Each file pins a seed, a standard deviation (as a ratio), and the samples
    // the reference implementation produced for them.
    let test_vectors: Vec<DiscreteGaussTestVector> = vec![
        serde_json::from_str(include_str!("test_vectors/discrete_gauss_3.json")).unwrap(),
        serde_json::from_str(include_str!("test_vectors/discrete_gauss_9.json")).unwrap(),
        serde_json::from_str(include_str!("test_vectors/discrete_gauss_100.json")).unwrap(),
        serde_json::from_str(include_str!("test_vectors/discrete_gauss_41293847.json")).unwrap(),
        serde_json::from_str(include_str!(
            "test_vectors/discrete_gauss_9999999999999999999999.json"
        ))
        .unwrap(),
        serde_json::from_str(include_str!("test_vectors/discrete_gauss_2.342.json")).unwrap(),
    ];

    for test_vector in test_vectors {
        let sampler = DiscreteGaussian::new(Ratio::<BigUint>::new(
            test_vector.std_num.into(),
            test_vector.std_denom.into(),
        ))
        .unwrap();

        // Replay the reference run: same seed, same number of draws.
        let mut rng = SeedStreamTurboShake128::from_seed(test_vector.seed);
        let samples: Vec<BigInt> = (0..test_vector.samples.len())
            .map(|_| sampler.sample(&mut rng))
            .collect();

        let expected: Vec<BigInt> = test_vector
            .samples
            .iter()
            .map(|&s| BigInt::from_i128(s).expect("every i128 is representable as a BigInt"))
            .collect();

        // All vectors share this assertion, so say which one diverged.
        assert_eq!(
            samples, expected,
            "samples diverge from the reference for std = {}/{} (seed {:02x?})",
            test_vector.std_num, test_vector.std_denom, test_vector.seed
        );
    }
}
|