summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/idpf.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/idpf.rs')
-rw-r--r--third_party/rust/prio/src/idpf.rs152
1 files changed, 124 insertions, 28 deletions
diff --git a/third_party/rust/prio/src/idpf.rs b/third_party/rust/prio/src/idpf.rs
index 2bb73f2159..b3da128fa0 100644
--- a/third_party/rust/prio/src/idpf.rs
+++ b/third_party/rust/prio/src/idpf.rs
@@ -1,7 +1,7 @@
//! This module implements the incremental distributed point function (IDPF) described in
-//! [[draft-irtf-cfrg-vdaf-07]].
+//! [[draft-irtf-cfrg-vdaf-08]].
//!
-//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/
+//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/
use crate::{
codec::{CodecError, Decode, Encode, ParameterizedDecode},
@@ -24,12 +24,14 @@ use std::{
collections::{HashMap, VecDeque},
fmt::Debug,
io::{Cursor, Read},
+ iter::zip,
ops::{Add, AddAssign, ControlFlow, Index, Sub},
};
use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
/// IDPF-related errors.
#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
pub enum IdpfError {
/// Error from incompatible shares at different levels.
#[error("tried to merge shares from incompatible levels")]
@@ -107,6 +109,11 @@ impl IdpfInput {
index: self.index[..=level].to_owned().into(),
}
}
+
+ /// Return the bit at the specified level if the level is in bounds.
+ pub fn get(&self, level: usize) -> Option<bool> {
+ self.index.get(level).as_deref().copied()
+ }
}
impl From<BitVec<usize, Lsb0>> for IdpfInput {
@@ -146,7 +153,7 @@ pub trait IdpfValue:
+ Sub<Output = Self>
+ ConditionallyNegatable
+ Encode
- + Decode
+ + ParameterizedDecode<Self::ValueParameter>
+ Sized
{
/// Any run-time parameters needed to produce a value.
@@ -239,11 +246,13 @@ fn extend(seed: &[u8; 16], xof_fixed_key: &XofFixedKeyAes128Key) -> ([[u8; 16];
seed_stream.fill_bytes(&mut seeds[0]);
seed_stream.fill_bytes(&mut seeds[1]);
- let mut byte = [0u8];
- seed_stream.fill_bytes(&mut byte);
- let control_bits = [(byte[0] & 1).into(), ((byte[0] >> 1) & 1).into()];
+ // "Steal" the control bits from the seeds.
+ let control_bits_0 = seeds[0].as_ref()[0] & 1;
+ let control_bits_1 = seeds[1].as_ref()[0] & 1;
+ seeds[0].as_mut()[0] &= 0xfe;
+ seeds[1].as_mut()[0] &= 0xfe;
- (seeds, control_bits)
+ (seeds, [control_bits_0.into(), control_bits_1.into()])
}
fn convert<V>(
@@ -670,7 +679,7 @@ where
VI: Encode,
VL: Encode,
{
- fn encode(&self, bytes: &mut Vec<u8>) {
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
// Control bits need to be written within each byte in LSB-to-MSB order, and assigned into
// bytes in big-endian order. Thus, the first four levels will have their control bits
// encoded in the last byte, and the last levels will have their control bits encoded in the
@@ -691,11 +700,11 @@ where
bytes.append(&mut packed_control);
for correction_words in self.inner_correction_words.iter() {
- Seed(correction_words.seed).encode(bytes);
- correction_words.value.encode(bytes);
+ Seed(correction_words.seed).encode(bytes)?;
+ correction_words.value.encode(bytes)?;
}
- Seed(self.leaf_correction_word.seed).encode(bytes);
- self.leaf_correction_word.value.encode(bytes);
+ Seed(self.leaf_correction_word.seed).encode(bytes)?;
+ self.leaf_correction_word.value.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
@@ -785,7 +794,7 @@ where
impl<V> Eq for IdpfCorrectionWord<V> where V: ConstantTimeEq {}
-fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] {
+pub(crate) fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] {
let mut seed = [0u8; 16];
for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) {
*c = a ^ b;
@@ -819,7 +828,7 @@ fn control_bit_to_seed_mask(control: Choice) -> [u8; 16] {
/// Take two seeds and a control bit, and return the first seed if the control bit is zero, or the
/// XOR of the two seeds if the control bit is one. This does not branch on the control bit.
-fn conditional_xor_seeds(
+pub(crate) fn conditional_xor_seeds(
normal_input: &[u8; 16],
switched_input: &[u8; 16],
control: Choice,
@@ -832,13 +841,18 @@ fn conditional_xor_seeds(
/// Returns one of two seeds, depending on the value of a selector bit. Does not branch on the
/// selector input or make selector-dependent memory accesses.
-fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] {
+pub(crate) fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] {
or_seeds(
&and_seeds(&control_bit_to_seed_mask(!select), &seeds[0]),
&and_seeds(&control_bit_to_seed_mask(select), &seeds[1]),
)
}
+/// Interchange the contents of seeds if the choice is 1, otherwise seeds remain unchanged.
+pub(crate) fn conditional_swap_seed(lhs: &mut [u8; 16], rhs: &mut [u8; 16], choice: Choice) {
+ zip(lhs, rhs).for_each(|(a, b)| u8::conditional_swap(a, b, choice));
+}
+
/// An interface that provides memoization of IDPF computations.
///
/// Each instance of a type implementing `IdpfCache` should only be used with one IDPF key and
@@ -947,11 +961,91 @@ impl IdpfCache for RingBufferCache {
}
}
+/// Utilities for testing IDPFs.
+#[cfg(feature = "test-util")]
+#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+pub mod test_utils {
+ use super::*;
+
+ use rand::prelude::*;
+ use zipf::ZipfDistribution;
+
+ /// Generate a set of IDPF inputs with the given bit length `bits`. They are sampled according
+ /// to the Zipf distribution with parameters `zipf_support` and `zipf_exponent`. Return the
+ /// measurements, along with the prefixes traversed during the heavy hitters computation for
+ /// the given threshold.
+ ///
+ /// The prefix tree consists of a sequence of candidate prefixes for each level. For a given level,
+ /// the candidate prefixes are computed from the hit counts of the prefixes at the previous level:
+ /// For any prefix `p` whose hit count is at least the desired threshold, add `p || 0` and `p || 1`
+ /// to the list.
+ pub fn generate_zipf_distributed_batch(
+ rng: &mut impl Rng,
+ bits: usize,
+ threshold: usize,
+ measurement_count: usize,
+ zipf_support: usize,
+ zipf_exponent: f64,
+ ) -> (Vec<IdpfInput>, Vec<Vec<IdpfInput>>) {
+ // Generate random inputs.
+ let mut inputs = Vec::with_capacity(zipf_support);
+ for _ in 0..zipf_support {
+ let bools: Vec<bool> = (0..bits).map(|_| rng.gen()).collect();
+ inputs.push(IdpfInput::from_bools(&bools));
+ }
+
+ // Sample a number of inputs according to the Zipf distribution.
+ let mut samples = Vec::with_capacity(measurement_count);
+ let zipf = ZipfDistribution::new(zipf_support, zipf_exponent).unwrap();
+ for _ in 0..measurement_count {
+ samples.push(inputs[zipf.sample(rng) - 1].clone());
+ }
+
+ // Compute the prefix tree for the desired threshold.
+ let mut prefix_tree = Vec::with_capacity(bits);
+ prefix_tree.push(vec![
+ IdpfInput::from_bools(&[false]),
+ IdpfInput::from_bools(&[true]),
+ ]);
+
+ for level in 0..bits - 1 {
+ // Compute the hit count of each prefix from the previous level.
+ let mut hit_counts = vec![0; prefix_tree[level].len()];
+ for (hit_count, prefix) in hit_counts.iter_mut().zip(prefix_tree[level].iter()) {
+ for sample in samples.iter() {
+ let mut is_prefix = true;
+ for j in 0..prefix.len() {
+ if prefix[j] != sample[j] {
+ is_prefix = false;
+ break;
+ }
+ }
+ if is_prefix {
+ *hit_count += 1;
+ }
+ }
+ }
+
+ // Compute the next set of candidate prefixes.
+ let mut next_prefixes = Vec::with_capacity(prefix_tree.last().unwrap().len());
+ for (hit_count, prefix) in hit_counts.iter().zip(prefix_tree[level].iter()) {
+ if *hit_count >= threshold {
+ next_prefixes.push(prefix.clone_with_suffix(&[false]));
+ next_prefixes.push(prefix.clone_with_suffix(&[true]));
+ }
+ }
+ prefix_tree.push(next_prefixes);
+ }
+
+ (samples, prefix_tree)
+ }
+}
+
#[cfg(test)]
mod tests {
use std::{
collections::HashMap,
- convert::{TryFrom, TryInto},
+ convert::TryInto,
io::Cursor,
ops::{Add, AddAssign, Sub},
str::FromStr,
@@ -1568,16 +1662,16 @@ mod tests {
seed: [0xab; 16],
control_bits: [Choice::from(1), Choice::from(0)],
value: Poplar1IdpfValue::new([
- Field64::try_from(83261u64).unwrap(),
- Field64::try_from(125159u64).unwrap(),
+ Field64::from(83261u64),
+ Field64::from(125159u64),
]),
},
IdpfCorrectionWord{
seed: [0xcd;16],
control_bits: [Choice::from(0), Choice::from(1)],
value: Poplar1IdpfValue::new([
- Field64::try_from(17614120u64).unwrap(),
- Field64::try_from(20674u64).unwrap(),
+ Field64::from(17614120u64),
+ Field64::from(20674u64),
]),
},
]),
@@ -1605,7 +1699,7 @@ mod tests {
"f0debc9a78563412f0debc9a78563412f0debc9a78563412f0debc9a78563412", // field element correction word, continued
))
.unwrap();
- let encoded = public_share.get_encoded();
+ let encoded = public_share.get_encoded().unwrap();
let decoded = IdpfPublicShare::get_decoded_with_param(&3, &message).unwrap();
assert_eq!(public_share, decoded);
assert_eq!(message, encoded);
@@ -1692,7 +1786,7 @@ mod tests {
"0000000000000000000000000000000000000000000000000000000000000000",
))
.unwrap();
- let encoded = public_share.get_encoded();
+ let encoded = public_share.get_encoded().unwrap();
let decoded = IdpfPublicShare::get_decoded_with_param(&9, &message).unwrap();
assert_eq!(public_share, decoded);
assert_eq!(message, encoded);
@@ -1761,7 +1855,7 @@ mod tests {
0,
);
- assert_eq!(public_share.get_encoded(), serialized_public_share);
+ assert_eq!(public_share.get_encoded().unwrap(), serialized_public_share);
assert_eq!(
IdpfPublicShare::get_decoded_with_param(&idpf_bits, &serialized_public_share)
.unwrap(),
@@ -1821,7 +1915,7 @@ mod tests {
/// Load a test vector for Idpf key generation.
fn load_idpfpoplar_test_vector() -> IdpfTestVector {
let test_vec: serde_json::Value =
- serde_json::from_str(include_str!("vdaf/test_vec/07/IdpfPoplar_0.json")).unwrap();
+ serde_json::from_str(include_str!("vdaf/test_vec/08/IdpfPoplar_0.json")).unwrap();
let test_vec_obj = test_vec.as_object().unwrap();
let bits = test_vec_obj
@@ -1939,7 +2033,7 @@ mod tests {
public_share, expected_public_share,
"public share did not match\n{public_share:#x?}\n{expected_public_share:#x?}"
);
- let encoded_public_share = public_share.get_encoded();
+ let encoded_public_share = public_share.get_encoded().unwrap();
assert_eq!(encoded_public_share, test_vector.public_share);
}
@@ -1988,7 +2082,9 @@ mod tests {
}
impl Encode for MyUnit {
- fn encode(&self, _: &mut Vec<u8>) {}
+ fn encode(&self, _: &mut Vec<u8>) -> Result<(), CodecError> {
+ Ok(())
+ }
}
impl Decode for MyUnit {
@@ -2066,8 +2162,8 @@ mod tests {
}
impl Encode for MyVector {
- fn encode(&self, bytes: &mut Vec<u8>) {
- encode_u32_items(bytes, &(), &self.0);
+ fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
+ encode_u32_items(bytes, &(), &self.0)
}
}