diff options
Diffstat (limited to 'third_party/rust/prio/src/idpf.rs')
-rw-r--r-- | third_party/rust/prio/src/idpf.rs | 152 |
1 files changed, 124 insertions, 28 deletions
diff --git a/third_party/rust/prio/src/idpf.rs b/third_party/rust/prio/src/idpf.rs index 2bb73f2159..b3da128fa0 100644 --- a/third_party/rust/prio/src/idpf.rs +++ b/third_party/rust/prio/src/idpf.rs @@ -1,7 +1,7 @@ //! This module implements the incremental distributed point function (IDPF) described in -//! [[draft-irtf-cfrg-vdaf-07]]. +//! [[draft-irtf-cfrg-vdaf-08]]. //! -//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/ use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, @@ -24,12 +24,14 @@ use std::{ collections::{HashMap, VecDeque}, fmt::Debug, io::{Cursor, Read}, + iter::zip, ops::{Add, AddAssign, ControlFlow, Index, Sub}, }; use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; /// IDPF-related errors. #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum IdpfError { /// Error from incompatible shares at different levels. #[error("tried to merge shares from incompatible levels")] @@ -107,6 +109,11 @@ impl IdpfInput { index: self.index[..=level].to_owned().into(), } } + + /// Return the bit at the specified level if the level is in bounds. + pub fn get(&self, level: usize) -> Option<bool> { + self.index.get(level).as_deref().copied() + } } impl From<BitVec<usize, Lsb0>> for IdpfInput { @@ -146,7 +153,7 @@ pub trait IdpfValue: + Sub<Output = Self> + ConditionallyNegatable + Encode - + Decode + + ParameterizedDecode<Self::ValueParameter> + Sized { /// Any run-time parameters needed to produce a value. @@ -239,11 +246,13 @@ fn extend(seed: &[u8; 16], xof_fixed_key: &XofFixedKeyAes128Key) -> ([[u8; 16]; seed_stream.fill_bytes(&mut seeds[0]); seed_stream.fill_bytes(&mut seeds[1]); - let mut byte = [0u8]; - seed_stream.fill_bytes(&mut byte); - let control_bits = [(byte[0] & 1).into(), ((byte[0] >> 1) & 1).into()]; + // "Steal" the control bits from the seeds. + let control_bits_0 = seeds[0].as_ref()[0] & 1; + let control_bits_1 = seeds[1].as_ref()[0] & 1; + seeds[0].as_mut()[0] &= 0xfe; + seeds[1].as_mut()[0] &= 0xfe; - (seeds, control_bits) + (seeds, [control_bits_0.into(), control_bits_1.into()]) } fn convert<V>( @@ -670,7 +679,7 @@ where VI: Encode, VL: Encode, { - fn encode(&self, bytes: &mut Vec<u8>) { + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into // bytes in big-endian order. Thus, the first four levels will have their control bits // encoded in the last byte, and the last levels will have their control bits encoded in the @@ -691,11 +700,11 @@ where bytes.append(&mut packed_control); for correction_words in self.inner_correction_words.iter() { - Seed(correction_words.seed).encode(bytes); - correction_words.value.encode(bytes); + Seed(correction_words.seed).encode(bytes)?; + correction_words.value.encode(bytes)?; } - Seed(self.leaf_correction_word.seed).encode(bytes); - self.leaf_correction_word.value.encode(bytes); + Seed(self.leaf_correction_word.seed).encode(bytes)?; + self.leaf_correction_word.value.encode(bytes) } fn encoded_len(&self) -> Option<usize> { @@ -785,7 +794,7 @@ where impl<V> Eq for IdpfCorrectionWord<V> where V: ConstantTimeEq {} -fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] { +pub(crate) fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] { let mut seed = [0u8; 16]; for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) { *c = a ^ b; @@ -819,7 +828,7 @@ fn control_bit_to_seed_mask(control: Choice) -> [u8; 16] { /// Take two seeds and a control bit, and return the first seed if the control bit is zero, or the /// XOR of the two seeds if the control bit is one. This does not branch on the control bit. -fn conditional_xor_seeds( +pub(crate) fn conditional_xor_seeds( normal_input: &[u8; 16], switched_input: &[u8; 16], control: Choice, @@ -832,13 +841,18 @@ fn conditional_xor_seeds( /// Returns one of two seeds, depending on the value of a selector bit. Does not branch on the /// selector input or make selector-dependent memory accesses. -fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] { +pub(crate) fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] { or_seeds( &and_seeds(&control_bit_to_seed_mask(!select), &seeds[0]), &and_seeds(&control_bit_to_seed_mask(select), &seeds[1]), ) } +/// Interchange the contents of seeds if the choice is 1, otherwise seeds remain unchanged. +pub(crate) fn conditional_swap_seed(lhs: &mut [u8; 16], rhs: &mut [u8; 16], choice: Choice) { + zip(lhs, rhs).for_each(|(a, b)| u8::conditional_swap(a, b, choice)); +} + /// An interface that provides memoization of IDPF computations. /// /// Each instance of a type implementing `IdpfCache` should only be used with one IDPF key and @@ -947,11 +961,91 @@ impl IdpfCache for RingBufferCache { } } +/// Utilities for testing IDPFs. +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub mod test_utils { + use super::*; + + use rand::prelude::*; + use zipf::ZipfDistribution; + + /// Generate a set of IDPF inputs with the given bit length `bits`. They are sampled according + /// to the Zipf distribution with parameters `zipf_support` and `zipf_exponent`. Return the + /// measurements, along with the prefixes traversed during the heavy hitters computation for + /// the given threshold. + /// + /// The prefix tree consists of a sequence of candidate prefixes for each level. For a given level, + /// the candidate prefixes are computed from the hit counts of the prefixes at the previous level: + /// For any prefix `p` whose hit count is at least the desired threshold, add `p || 0` and `p || 1` + /// to the list. + pub fn generate_zipf_distributed_batch( + rng: &mut impl Rng, + bits: usize, + threshold: usize, + measurement_count: usize, + zipf_support: usize, + zipf_exponent: f64, + ) -> (Vec<IdpfInput>, Vec<Vec<IdpfInput>>) { + // Generate random inputs. + let mut inputs = Vec::with_capacity(zipf_support); + for _ in 0..zipf_support { + let bools: Vec<bool> = (0..bits).map(|_| rng.gen()).collect(); + inputs.push(IdpfInput::from_bools(&bools)); + } + + // Sample a number of inputs according to the Zipf distribution. + let mut samples = Vec::with_capacity(measurement_count); + let zipf = ZipfDistribution::new(zipf_support, zipf_exponent).unwrap(); + for _ in 0..measurement_count { + samples.push(inputs[zipf.sample(rng) - 1].clone()); + } + + // Compute the prefix tree for the desired threshold. + let mut prefix_tree = Vec::with_capacity(bits); + prefix_tree.push(vec![ + IdpfInput::from_bools(&[false]), + IdpfInput::from_bools(&[true]), + ]); + + for level in 0..bits - 1 { + // Compute the hit count of each prefix from the previous level. + let mut hit_counts = vec![0; prefix_tree[level].len()]; + for (hit_count, prefix) in hit_counts.iter_mut().zip(prefix_tree[level].iter()) { + for sample in samples.iter() { + let mut is_prefix = true; + for j in 0..prefix.len() { + if prefix[j] != sample[j] { + is_prefix = false; + break; + } + } + if is_prefix { + *hit_count += 1; + } + } + } + + // Compute the next set of candidate prefixes. + let mut next_prefixes = Vec::with_capacity(prefix_tree.last().unwrap().len()); + for (hit_count, prefix) in hit_counts.iter().zip(prefix_tree[level].iter()) { + if *hit_count >= threshold { + next_prefixes.push(prefix.clone_with_suffix(&[false])); + next_prefixes.push(prefix.clone_with_suffix(&[true])); + } + } + prefix_tree.push(next_prefixes); + } + + (samples, prefix_tree) + } +} + #[cfg(test)] mod tests { use std::{ collections::HashMap, - convert::{TryFrom, TryInto}, + convert::TryInto, io::Cursor, ops::{Add, AddAssign, Sub}, str::FromStr, @@ -1568,16 +1662,16 @@ mod tests { seed: [0xab; 16], control_bits: [Choice::from(1), Choice::from(0)], value: Poplar1IdpfValue::new([ - Field64::try_from(83261u64).unwrap(), - Field64::try_from(125159u64).unwrap(), + Field64::from(83261u64), + Field64::from(125159u64), ]), }, IdpfCorrectionWord{ seed: [0xcd;16], control_bits: [Choice::from(0), Choice::from(1)], value: Poplar1IdpfValue::new([ - Field64::try_from(17614120u64).unwrap(), - Field64::try_from(20674u64).unwrap(), + Field64::from(17614120u64), + Field64::from(20674u64), ]), }, ]), @@ -1605,7 +1699,7 @@ mod tests { "f0debc9a78563412f0debc9a78563412f0debc9a78563412f0debc9a78563412", // field element correction word, continued )) .unwrap(); - let encoded = public_share.get_encoded(); + let encoded = public_share.get_encoded().unwrap(); let decoded = IdpfPublicShare::get_decoded_with_param(&3, &message).unwrap(); assert_eq!(public_share, decoded); assert_eq!(message, encoded); @@ -1692,7 +1786,7 @@ mod tests { "0000000000000000000000000000000000000000000000000000000000000000", )) .unwrap(); - let encoded = public_share.get_encoded(); + let encoded = public_share.get_encoded().unwrap(); let decoded = IdpfPublicShare::get_decoded_with_param(&9, &message).unwrap(); assert_eq!(public_share, decoded); assert_eq!(message, encoded); @@ -1761,7 +1855,7 @@ mod tests { 0, ); - assert_eq!(public_share.get_encoded(), serialized_public_share); + assert_eq!(public_share.get_encoded().unwrap(), serialized_public_share); assert_eq!( IdpfPublicShare::get_decoded_with_param(&idpf_bits, &serialized_public_share) .unwrap(), @@ -1821,7 +1915,7 @@ mod tests { /// Load a test vector for Idpf key generation. fn load_idpfpoplar_test_vector() -> IdpfTestVector { let test_vec: serde_json::Value = - serde_json::from_str(include_str!("vdaf/test_vec/07/IdpfPoplar_0.json")).unwrap(); + serde_json::from_str(include_str!("vdaf/test_vec/08/IdpfPoplar_0.json")).unwrap(); let test_vec_obj = test_vec.as_object().unwrap(); let bits = test_vec_obj @@ -1939,7 +2033,7 @@ mod tests { public_share, expected_public_share, "public share did not match\n{public_share:#x?}\n{expected_public_share:#x?}" ); - let encoded_public_share = public_share.get_encoded(); + let encoded_public_share = public_share.get_encoded().unwrap(); assert_eq!(encoded_public_share, test_vector.public_share); } @@ -1988,7 +2082,9 @@ mod tests { } impl Encode for MyUnit { - fn encode(&self, _: &mut Vec<u8>) {} + fn encode(&self, _: &mut Vec<u8>) -> Result<(), CodecError> { + Ok(()) + } } impl Decode for MyUnit { @@ -2066,8 +2162,8 @@ mod tests { } impl Encode for MyVector { - fn encode(&self, bytes: &mut Vec<u8>) { - encode_u32_items(bytes, &(), &self.0); + fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> { + encode_u32_items(bytes, &(), &self.0) } } |