diff options
Diffstat (limited to 'third_party/rust/prio/src/flp/types')
-rw-r--r-- | third_party/rust/prio/src/flp/types/fixedpoint_l2.rs | 131 |
1 files changed, 66 insertions, 65 deletions
diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs index b5aa2fd116..8766c035b8 100644 --- a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs +++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs @@ -172,15 +172,17 @@ pub mod compatible_float; use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError}; -use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt}; +use crate::field::{ + Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt, Integer, +}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat; use crate::flp::types::parallel_sum_range_checks; use crate::flp::{FlpError, Gadget, Type, TypeWithNoise}; -use crate::vdaf::xof::SeedStreamSha3; +use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::traits::Fixed; use num_bigint::{BigInt, BigUint, TryFromBigIntError}; -use num_integer::Integer; +use num_integer::Integer as _; use num_rational::Ratio; use rand::{distributions::Distribution, Rng}; use rand_core::SeedableRng; @@ -250,7 +252,7 @@ where /// fixed point vector with `entries` entries. pub fn new(entries: usize) -> Result<Self, FlpError> { // (0) initialize constants - let fi_one = u128::from(Field128::one()); + let fi_one = <Field128 as FieldElementWithInteger>::Integer::one(); // (I) Check that the fixed type is compatible. // @@ -400,7 +402,6 @@ where SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static, SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static, { - const ID: u32 = 0xFFFF0000; type Measurement = Vec<T>; type AggregateResult = Vec<f64>; type Field = Field128; @@ -419,12 +420,9 @@ where // Encode the integer entries bitwise, and write them into the `encoded` // vector. let mut encoded: Vec<Field128> = - vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm]; - for (l, entry) in integer_entries.clone().enumerate() { - Field128::fill_with_bitvector_representation( - &entry, - &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry], - )?; + Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm); + for entry in integer_entries.clone() { + encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?); } // (II) Vector norm. @@ -434,10 +432,7 @@ where let norm_int = u128::from(norm); // Write the norm into the `entries` vector. - Field128::fill_with_bitvector_representation( - &norm_int, - &mut encoded[self.range_norm_begin..self.range_norm_end], - )?; + encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?); Ok(encoded) } @@ -535,7 +530,7 @@ where // decode the bit-encoded entries into elements in the range [0,2^n): let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry] .chunks(self.bits_per_entry) - .map(Field128::decode_from_bitvector_representation) + .map(Field128::decode_bitvector) .collect(); // run parallel sum gadget on the decoded entries @@ -544,7 +539,7 @@ where // Chunks which are too short need to be extended with a share of the // encoded zero value, that is: 1/num_shares * (2^(n-1)) - let fi_one = u128::from(Field128::one()); + let fi_one = <Field128 as FieldElementWithInteger>::Integer::one(); let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1)); let zero_enc_share = zero_enc * num_shares_inverse; @@ -567,7 +562,7 @@ where // The submitted norm is also decoded from its bit-encoding, and // compared with the computed norm. let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end]; - let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?; + let submitted_norm = Field128::decode_bitvector(submitted_norm_enc)?; let norm_check = computed_norm - submitted_norm; @@ -586,7 +581,7 @@ where let start = i_entry * self.bits_per_entry; let end = (i_entry + 1) * self.bits_per_entry; - let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?; + let decoded = Field128::decode_bitvector(&input[start..end])?; decoded_vector.push(decoded); } Ok(decoded_vector) @@ -644,7 +639,11 @@ where agg_result: &mut [Self::Field], _num_measurements: usize, ) -> Result<(), FlpError> { - self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy()) + self.add_noise( + dp_strategy, + agg_result, + &mut SeedStreamTurboShake128::from_entropy(), + ) } } @@ -686,8 +685,8 @@ mod tests { use crate::dp::{Rational, ZCdpBudget}; use crate::field::{random_vector, Field128, FieldElement}; use crate::flp::gadgets::ParallelSum; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; - use crate::vdaf::xof::SeedStreamSha3; + use crate::flp::test_utils::FlpTest; + use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::types::extra::{U127, U14, U63}; use fixed::{FixedI128, FixedI16, FixedI64}; use fixed_macro::fixed; @@ -768,15 +767,23 @@ mod tests { let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new( Rational::from_unsigned(100u8, 3u8).unwrap(), )); - vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16])) - .unwrap(); + vsum.add_noise( + &strategy, + &mut v, + &mut SeedStreamTurboShake128::from_seed([0u8; 16]), + ) + .unwrap(); assert_eq!( vsum.decode_result(&v, 1).unwrap(), match n { // sensitivity depends on encoding so the noise differs - 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625], - 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445], - 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425], + 16 => vec![0.288970947265625, 0.168853759765625, 0.085662841796875], + 32 => vec![0.257810294162482, 0.10634658299386501, 0.10149003705009818], + 64 => vec![ + 0.37697368351762867, + -0.02388947667663828, + 0.19813152630930916 + ], _ => panic!("unsupported bitsize"), } ); @@ -785,52 +792,46 @@ mod tests { let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); assert_eq!(input[0], Field128::zero()); input[0] = one; // it was zero - flp_validity_test( - &vsum, - &input, - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // encoding contains entries that are not zero or one let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); input2[0] = one + one; - flp_validity_test( - &vsum, - &input2, - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input2, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // norm is too big // 2^n - 1, the field element encoded by the all-1 vector let one_enc = Field128::from(((2_u128) << (n - 1)) - 1); - flp_validity_test( - &vsum, - &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors - &ValidityTestCase::<Field128> { - expect_valid: false, - expected_output: Some(vec![one_enc; 3]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors + expected_output: Some(&[one_enc; 3]), + expect_valid: false, + } + .run::<3>(); // invalid submission length, should be 3n + (2*n - 2) for a // 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm. |