summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/vdaf/prg.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/vdaf/prg.rs')
-rw-r--r--third_party/rust/prio/src/vdaf/prg.rs239
1 files changed, 239 insertions, 0 deletions
diff --git a/third_party/rust/prio/src/vdaf/prg.rs b/third_party/rust/prio/src/vdaf/prg.rs
new file mode 100644
index 0000000000..a5930f1283
--- /dev/null
+++ b/third_party/rust/prio/src/vdaf/prg.rs
@@ -0,0 +1,239 @@
+// SPDX-License-Identifier: MPL-2.0
+
+//! Implementations of PRGs specified in [[draft-irtf-cfrg-vdaf-03]].
+//!
+//! [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+
+use crate::vdaf::{CodecError, Decode, Encode};
+#[cfg(feature = "crypto-dependencies")]
+use aes::{
+ cipher::{KeyIvInit, StreamCipher},
+ Aes128,
+};
+#[cfg(feature = "crypto-dependencies")]
+use cmac::{Cmac, Mac};
+#[cfg(feature = "crypto-dependencies")]
+use ctr::Ctr64BE;
+#[cfg(feature = "crypto-dependencies")]
+use std::fmt::Formatter;
+use std::{
+ fmt::Debug,
+ io::{Cursor, Read},
+};
+
+/// Function pointer to fill a buffer with random bytes. Under normal operation,
+/// `getrandom::getrandom()` will be used, but other implementations can be used to control
+/// randomness when generating or verifying test vectors.
+pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>;
+
+/// Input of [`Prg`].
+#[derive(Clone, Debug, Eq)]
+pub struct Seed<const L: usize>(pub(crate) [u8; L]);
+
+impl<const L: usize> Seed<L> {
+ /// Generate a uniform random seed.
+ pub fn generate() -> Result<Self, getrandom::Error> {
+ Self::from_rand_source(getrandom::getrandom)
+ }
+
+ pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> {
+ let mut seed = [0; L];
+ rand_source(&mut seed)?;
+ Ok(Self(seed))
+ }
+}
+
+impl<const L: usize> AsRef<[u8; L]> for Seed<L> {
+ fn as_ref(&self) -> &[u8; L] {
+ &self.0
+ }
+}
+
+impl<const L: usize> PartialEq for Seed<L> {
+ fn eq(&self, other: &Self) -> bool {
+ // Do constant-time compare.
+ let mut r = 0;
+ for (x, y) in self.0[..].iter().zip(&other.0[..]) {
+ r |= x ^ y;
+ }
+ r == 0
+ }
+}
+
+impl<const L: usize> Encode for Seed<L> {
+ fn encode(&self, bytes: &mut Vec<u8>) {
+ bytes.extend_from_slice(&self.0[..]);
+ }
+}
+
+impl<const L: usize> Decode for Seed<L> {
+ fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
+ let mut seed = [0; L];
+ bytes.read_exact(&mut seed)?;
+ Ok(Seed(seed))
+ }
+}
+
+/// A stream of pseudorandom bytes derived from a seed.
+pub trait SeedStream {
+ /// Fill `buf` with the next `buf.len()` bytes of output.
+ fn fill(&mut self, buf: &mut [u8]);
+}
+
+/// A pseudorandom generator (PRG) with the interface specified in [[draft-irtf-cfrg-vdaf-03]].
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+pub trait Prg<const L: usize>: Clone + Debug {
+ /// The type of stream produced by this PRG.
+ type SeedStream: SeedStream;
+
+ /// Construct an instance of [`Prg`] with the given seed.
+ fn init(seed_bytes: &[u8; L]) -> Self;
+
+ /// Update the PRG state by passing in the next fragment of the info string. The final info
+ /// string is assembled from the concatenation of sequence of fragments passed to this method.
+ fn update(&mut self, data: &[u8]);
+
+ /// Finalize the PRG state, producing a seed stream.
+ fn into_seed_stream(self) -> Self::SeedStream;
+
+ /// Finalize the PRG state, producing a seed.
+ fn into_seed(self) -> Seed<L> {
+ let mut new_seed = [0; L];
+ let mut seed_stream = self.into_seed_stream();
+ seed_stream.fill(&mut new_seed);
+ Seed(new_seed)
+ }
+
+ /// Construct a seed stream from the given seed and info string.
+ fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream {
+ let mut prg = Self::init(seed.as_ref());
+ prg.update(info);
+ prg.into_seed_stream()
+ }
+}
+
+/// The PRG based on AES128 as specified in [[draft-irtf-cfrg-vdaf-03]].
+///
+/// [draft-irtf-cfrg-vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/
+#[derive(Clone, Debug)]
+#[cfg(feature = "crypto-dependencies")]
+pub struct PrgAes128(Cmac<Aes128>);
+
+#[cfg(feature = "crypto-dependencies")]
+impl Prg<16> for PrgAes128 {
+ type SeedStream = SeedStreamAes128;
+
+ fn init(seed_bytes: &[u8; 16]) -> Self {
+ Self(Cmac::new_from_slice(seed_bytes).unwrap())
+ }
+
+ fn update(&mut self, data: &[u8]) {
+ self.0.update(data);
+ }
+
+ fn into_seed_stream(self) -> SeedStreamAes128 {
+ let key = self.0.finalize().into_bytes();
+ SeedStreamAes128::new(&key, &[0; 16])
+ }
+}
+
+/// The key stream produced by AES128 in CTR-mode.
+#[cfg(feature = "crypto-dependencies")]
+pub struct SeedStreamAes128(Ctr64BE<Aes128>);
+
+#[cfg(feature = "crypto-dependencies")]
+impl SeedStreamAes128 {
+ pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self {
+ SeedStreamAes128(Ctr64BE::<Aes128>::new(key.into(), iv.into()))
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl SeedStream for SeedStreamAes128 {
+ fn fill(&mut self, buf: &mut [u8]) {
+ buf.fill(0);
+ self.0.apply_keystream(buf);
+ }
+}
+
+#[cfg(feature = "crypto-dependencies")]
+impl Debug for SeedStreamAes128 {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that
+ // with [`cipher::StreamCipherCoreWrapper::get_core`][2].
+ //
+ // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html
+ // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html
+ self.0.get_core().fmt(f)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{field::Field128, prng::Prng};
+ use serde::{Deserialize, Serialize};
+ use std::convert::TryInto;
+
+ #[derive(Deserialize, Serialize)]
+ struct PrgTestVector {
+ #[serde(with = "hex")]
+ seed: Vec<u8>,
+ #[serde(with = "hex")]
+ info: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex")]
+ derived_seed: Vec<u8>,
+ #[serde(with = "hex")]
+ expanded_vec_field128: Vec<u8>,
+ }
+
+ // Test correctness of dervied methods.
+ fn test_prg<P, const L: usize>()
+ where
+ P: Prg<L>,
+ {
+ let seed = Seed::generate().unwrap();
+ let info = b"info string";
+
+ let mut prg = P::init(seed.as_ref());
+ prg.update(info);
+
+ let mut want = Seed([0; L]);
+ prg.clone().into_seed_stream().fill(&mut want.0[..]);
+ let got = prg.clone().into_seed();
+ assert_eq!(got, want);
+
+ let mut want = [0; 45];
+ prg.clone().into_seed_stream().fill(&mut want);
+ let mut got = [0; 45];
+ P::seed_stream(&seed, info).fill(&mut got);
+ assert_eq!(got, want);
+ }
+
+ #[test]
+ fn prg_aes128() {
+ let t: PrgTestVector =
+ serde_json::from_str(include_str!("test_vec/03/PrgAes128.json")).unwrap();
+ let mut prg = PrgAes128::init(&t.seed.try_into().unwrap());
+ prg.update(&t.info);
+
+ assert_eq!(
+ prg.clone().into_seed(),
+ Seed(t.derived_seed.try_into().unwrap())
+ );
+
+ let mut bytes = std::io::Cursor::new(t.expanded_vec_field128.as_slice());
+ let mut want = Vec::with_capacity(t.length);
+ while (bytes.position() as usize) < t.expanded_vec_field128.len() {
+ want.push(Field128::decode(&mut bytes).unwrap())
+ }
+ let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream())
+ .take(t.length)
+ .collect();
+ assert_eq!(got, want);
+
+ test_prg::<PrgAes128, 16>();
+ }
+}