summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/flp/types/fixedpoint_l2.rs')
-rw-r--r--third_party/rust/prio/src/flp/types/fixedpoint_l2.rs131
1 files changed, 66 insertions, 65 deletions
diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
index b5aa2fd116..8766c035b8 100644
--- a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
+++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs
@@ -172,15 +172,17 @@
pub mod compatible_float;
use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError};
-use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt};
+use crate::field::{
+ Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt, Integer,
+};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat;
use crate::flp::types::parallel_sum_range_checks;
use crate::flp::{FlpError, Gadget, Type, TypeWithNoise};
-use crate::vdaf::xof::SeedStreamSha3;
+use crate::vdaf::xof::SeedStreamTurboShake128;
use fixed::traits::Fixed;
use num_bigint::{BigInt, BigUint, TryFromBigIntError};
-use num_integer::Integer;
+use num_integer::Integer as _;
use num_rational::Ratio;
use rand::{distributions::Distribution, Rng};
use rand_core::SeedableRng;
@@ -250,7 +252,7 @@ where
/// fixed point vector with `entries` entries.
pub fn new(entries: usize) -> Result<Self, FlpError> {
// (0) initialize constants
- let fi_one = u128::from(Field128::one());
+ let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();
// (I) Check that the fixed type is compatible.
//
@@ -400,7 +402,6 @@ where
SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static,
SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static,
{
- const ID: u32 = 0xFFFF0000;
type Measurement = Vec<T>;
type AggregateResult = Vec<f64>;
type Field = Field128;
@@ -419,12 +420,9 @@ where
// Encode the integer entries bitwise, and write them into the `encoded`
// vector.
let mut encoded: Vec<Field128> =
- vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm];
- for (l, entry) in integer_entries.clone().enumerate() {
- Field128::fill_with_bitvector_representation(
- &entry,
- &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry],
- )?;
+ Vec::with_capacity(self.bits_per_entry * self.entries + self.bits_for_norm);
+ for entry in integer_entries.clone() {
+ encoded.extend(Field128::encode_as_bitvector(entry, self.bits_per_entry)?);
}
// (II) Vector norm.
@@ -434,10 +432,7 @@ where
let norm_int = u128::from(norm);
// Write the norm into the `entries` vector.
- Field128::fill_with_bitvector_representation(
- &norm_int,
- &mut encoded[self.range_norm_begin..self.range_norm_end],
- )?;
+ encoded.extend(Field128::encode_as_bitvector(norm_int, self.bits_for_norm)?);
Ok(encoded)
}
@@ -535,7 +530,7 @@ where
// decode the bit-encoded entries into elements in the range [0,2^n):
let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry]
.chunks(self.bits_per_entry)
- .map(Field128::decode_from_bitvector_representation)
+ .map(Field128::decode_bitvector)
.collect();
// run parallel sum gadget on the decoded entries
@@ -544,7 +539,7 @@ where
// Chunks which are too short need to be extended with a share of the
// encoded zero value, that is: 1/num_shares * (2^(n-1))
- let fi_one = u128::from(Field128::one());
+ let fi_one = <Field128 as FieldElementWithInteger>::Integer::one();
let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1));
let zero_enc_share = zero_enc * num_shares_inverse;
@@ -567,7 +562,7 @@ where
// The submitted norm is also decoded from its bit-encoding, and
// compared with the computed norm.
let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end];
- let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?;
+ let submitted_norm = Field128::decode_bitvector(submitted_norm_enc)?;
let norm_check = computed_norm - submitted_norm;
@@ -586,7 +581,7 @@ where
let start = i_entry * self.bits_per_entry;
let end = (i_entry + 1) * self.bits_per_entry;
- let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?;
+ let decoded = Field128::decode_bitvector(&input[start..end])?;
decoded_vector.push(decoded);
}
Ok(decoded_vector)
@@ -644,7 +639,11 @@ where
agg_result: &mut [Self::Field],
_num_measurements: usize,
) -> Result<(), FlpError> {
- self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy())
+ self.add_noise(
+ dp_strategy,
+ agg_result,
+ &mut SeedStreamTurboShake128::from_entropy(),
+ )
}
}
@@ -686,8 +685,8 @@ mod tests {
use crate::dp::{Rational, ZCdpBudget};
use crate::field::{random_vector, Field128, FieldElement};
use crate::flp::gadgets::ParallelSum;
- use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
- use crate::vdaf::xof::SeedStreamSha3;
+ use crate::flp::test_utils::FlpTest;
+ use crate::vdaf::xof::SeedStreamTurboShake128;
use fixed::types::extra::{U127, U14, U63};
use fixed::{FixedI128, FixedI16, FixedI64};
use fixed_macro::fixed;
@@ -768,15 +767,23 @@ mod tests {
let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new(
Rational::from_unsigned(100u8, 3u8).unwrap(),
));
- vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16]))
- .unwrap();
+ vsum.add_noise(
+ &strategy,
+ &mut v,
+ &mut SeedStreamTurboShake128::from_seed([0u8; 16]),
+ )
+ .unwrap();
assert_eq!(
vsum.decode_result(&v, 1).unwrap(),
match n {
// sensitivity depends on encoding so the noise differs
- 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625],
- 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445],
- 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425],
+ 16 => vec![0.288970947265625, 0.168853759765625, 0.085662841796875],
+ 32 => vec![0.257810294162482, 0.10634658299386501, 0.10149003705009818],
+ 64 => vec![
+ 0.37697368351762867,
+ -0.02388947667663828,
+ 0.19813152630930916
+ ],
_ => panic!("unsupported bitsize"),
}
);
@@ -785,52 +792,46 @@ mod tests {
let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
assert_eq!(input[0], Field128::zero());
input[0] = one; // it was zero
- flp_validity_test(
- &vsum,
- &input,
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![
- Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0
- Field128::from(enc_vec[1]),
- Field128::from(enc_vec[2]),
- ]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &input,
+ expected_output: Some(&[
+ Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ expect_valid: false,
+ }
+ .run::<3>();
// encoding contains entries that are not zero or one
let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap();
input2[0] = one + one;
- flp_validity_test(
- &vsum,
- &input2,
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![
- Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0
- Field128::from(enc_vec[1]),
- Field128::from(enc_vec[2]),
- ]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &input2,
+ expected_output: Some(&[
+ Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0
+ Field128::from(enc_vec[1]),
+ Field128::from(enc_vec[2]),
+ ]),
+ expect_valid: false,
+ }
+ .run::<3>();
// norm is too big
// 2^n - 1, the field element encoded by the all-1 vector
let one_enc = Field128::from(((2_u128) << (n - 1)) - 1);
- flp_validity_test(
- &vsum,
- &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors
- &ValidityTestCase::<Field128> {
- expect_valid: false,
- expected_output: Some(vec![one_enc; 3]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest {
+ name: None,
+ flp: &vsum,
+ input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors
+ expected_output: Some(&[one_enc; 3]),
+ expect_valid: false,
+ }
+ .run::<3>();
// invalid submission length, should be 3n + (2*n - 2) for a
// 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm.