summaryrefslogtreecommitdiffstats
path: root/third_party/rust/prio/src/flp/types.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/prio/src/flp/types.rs')
-rw-r--r--third_party/rust/prio/src/flp/types.rs460
1 files changed, 54 insertions, 406 deletions
diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs
index 18c290355c..bca88a36cd 100644
--- a/third_party/rust/prio/src/flp/types.rs
+++ b/third_party/rust/prio/src/flp/types.rs
@@ -9,6 +9,7 @@ use crate::polynomial::poly_range_check;
use std::convert::TryInto;
use std::fmt::{self, Debug};
use std::marker::PhantomData;
+use subtle::Choice;
/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`).
#[derive(Clone, PartialEq, Eq)]
pub struct Count<F> {
@@ -37,18 +38,16 @@ impl<F: FftFriendlyFieldElement> Default for Count<F> {
}
impl<F: FftFriendlyFieldElement> Type for Count<F> {
- const ID: u32 = 0x00000000;
- type Measurement = F::Integer;
+ type Measurement = bool;
type AggregateResult = F::Integer;
type Field = F;
- fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> {
- let max = F::valid_integer_try_from(1)?;
- if *value > max {
- return Err(FlpError::Encode("Count value must be 0 or 1".to_string()));
- }
-
- Ok(vec![F::from(*value)])
+ fn encode_measurement(&self, value: &bool) -> Result<Vec<F>, FlpError> {
+ Ok(vec![F::conditional_select(
+ &F::zero(),
+ &F::one(),
+ Choice::from(u8::from(*value)),
+ )])
}
fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> {
@@ -140,13 +139,12 @@ impl<F: FftFriendlyFieldElement> Sum<F> {
}
impl<F: FftFriendlyFieldElement> Type for Sum<F> {
- const ID: u32 = 0x00000001;
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
- let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ let v = F::encode_as_bitvector(*summand, self.bits)?.collect();
Ok(v)
}
@@ -174,7 +172,7 @@ impl<F: FftFriendlyFieldElement> Type for Sum<F> {
fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
self.truncate_call_check(&input)?;
- let res = F::decode_from_bitvector_representation(&input)?;
+ let res = F::decode_bitvector(&input)?;
Ok(vec![res])
}
@@ -239,13 +237,12 @@ impl<F: FftFriendlyFieldElement> Average<F> {
}
impl<F: FftFriendlyFieldElement> Type for Average<F> {
- const ID: u32 = 0xFFFF0000;
type Measurement = F::Integer;
type AggregateResult = f64;
type Field = F;
fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> {
- let v = F::encode_into_bitvector_representation(summand, self.bits)?;
+ let v = F::encode_as_bitvector(*summand, self.bits)?.collect();
Ok(v)
}
@@ -279,7 +276,7 @@ impl<F: FftFriendlyFieldElement> Type for Average<F> {
fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
self.truncate_call_check(&input)?;
- let res = F::decode_from_bitvector_representation(&input)?;
+ let res = F::decode_bitvector(&input)?;
Ok(vec![res])
}
@@ -380,7 +377,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
- const ID: u32 = 0x00000003;
type Measurement = usize;
type AggregateResult = Vec<F::Integer>;
type Field = F;
@@ -574,7 +570,6 @@ where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
- const ID: u32 = 0x00000002;
type Measurement = Vec<F::Integer>;
type AggregateResult = Vec<F::Integer>;
type Field = F;
@@ -588,18 +583,15 @@ where
)));
}
- let mut flattened = vec![F::zero(); self.flattened_len];
- for (summand, chunk) in measurement
- .iter()
- .zip(flattened.chunks_exact_mut(self.bits))
- {
+ let mut flattened = Vec::with_capacity(self.flattened_len);
+ for summand in measurement.iter() {
if summand > &self.max {
return Err(FlpError::Encode(format!(
"summand exceeds maximum of 2^{}-1",
self.bits
)));
}
- F::fill_with_bitvector_representation(summand, chunk)?;
+ flattened.extend(F::encode_as_bitvector(*summand, self.bits)?);
}
Ok(flattened)
@@ -642,7 +634,7 @@ where
self.truncate_call_check(&input)?;
let mut unflattened = Vec::with_capacity(self.len);
for chunk in input.chunks(self.bits) {
- unflattened.push(F::decode_from_bitvector_representation(chunk)?);
+ unflattened.push(F::decode_bitvector(chunk)?);
}
Ok(unflattened)
}
@@ -783,7 +775,7 @@ mod tests {
use crate::flp::gadgets::ParallelSum;
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
- use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase};
+ use crate::flp::test_utils::FlpTest;
use std::cmp;
#[test]
@@ -797,7 +789,7 @@ mod tests {
count
.decode_result(
&count
- .truncate(count.encode_measurement(&1).unwrap())
+ .truncate(count.encode_measurement(&true).unwrap())
.unwrap(),
1
)
@@ -806,39 +798,11 @@ mod tests {
);
// Test FLP on valid input.
- flp_validity_test(
- &count,
- &count.encode_measurement(&1).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &count,
- &count.encode_measurement(&0).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]);
+ FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]);
// Test FLP on invalid input.
- flp_validity_test(
- &count,
- &[TestField::from(1337)],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]);
// Try running the validity circuit on an input that's too short.
count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err();
@@ -865,72 +829,22 @@ mod tests {
);
// Test FLP on valid input.
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum,
&sum.encode_measurement(&1337).unwrap(),
- &ValidityTestCase {
- expect_valid: true,
- expected_output: Some(vec![TestField::from(1337)]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(0).unwrap(),
- &[],
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(2).unwrap(),
- &[one, zero],
- &ValidityTestCase {
- expect_valid: true,
- expected_output: Some(vec![one]),
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
+ &[TestField::from(1337)],
+ );
+ FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]);
+ FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]);
+ FlpTest::expect_valid::<3>(
&Sum::new(9).unwrap(),
&[one, zero, one, one, zero, one, one, one, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![TestField::from(237)]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[TestField::from(237)],
+ );
// Test FLP on invalid input.
- flp_validity_test(
- &Sum::new(3).unwrap(),
- &[one, nine, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &Sum::new(5).unwrap(),
- &[zero, zero, zero, zero, nine],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]);
+ FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]);
}
#[test]
@@ -1000,83 +914,29 @@ mod tests {
);
// Test valid inputs.
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&0).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one, zero, zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[one, zero, zero],
+ );
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&1).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero, one, zero]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[zero, one, zero],
+ );
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&hist,
&hist.encode_measurement(&2).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![zero, zero, one]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &[zero, zero, one],
+ );
// Test invalid inputs.
- flp_validity_test(
- &hist,
- &[zero, zero, nine],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[zero, one, one],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[one, one, one],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
-
- flp_validity_test(
- &hist,
- &[zero, zero, zero],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]);
+ FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]);
+ FlpTest::expect_invalid::<3>(&hist, &[one, one, one]);
+ FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]);
}
#[test]
@@ -1104,72 +964,38 @@ mod tests {
for len in 1..10 {
let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
let sum_vec = f(1, len, chunk_length).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid_no_output::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![1; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ );
}
let len = 100;
let sum_vec = f(1, len, 10).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![1; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![one; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &vec![one; len],
+ );
let len = 23;
let sum_vec = f(4, len, 4).unwrap();
- flp_validity_test(
+ FlpTest::expect_valid::<3>(
&sum_vec,
&sum_vec.encode_measurement(&vec![9; len]).unwrap(),
- &ValidityTestCase::<TestField> {
- expect_valid: true,
- expected_output: Some(vec![nine; len]),
- num_shares: 3,
- },
- )
- .unwrap();
+ &vec![nine; len],
+ );
// Test on invalid inputs.
for len in 1..10 {
let chunk_length = cmp::max((len as f64).sqrt() as usize, 1);
let sum_vec = f(1, len, chunk_length).unwrap();
- flp_validity_test(
- &sum_vec,
- &vec![nine; len],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]);
}
let len = 23;
let sum_vec = f(2, len, 4).unwrap();
- flp_validity_test(
- &sum_vec,
- &vec![nine; 2 * len],
- &ValidityTestCase::<TestField> {
- expect_valid: false,
- expected_output: None,
- num_shares: 3,
- },
- )
- .unwrap();
+ FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]);
// Round trip
let want = vec![1; len];
@@ -1232,184 +1058,6 @@ mod tests {
}
}
-#[cfg(test)]
-mod test_utils {
- use super::*;
- use crate::field::{random_vector, split_vector, FieldElement};
-
- pub(crate) struct ValidityTestCase<F> {
- pub(crate) expect_valid: bool,
- pub(crate) expected_output: Option<Vec<F>>,
- // Number of shares to split input and proofs into in `flp_test`.
- pub(crate) num_shares: usize,
- }
-
- pub(crate) fn flp_validity_test<T: Type>(
- typ: &T,
- input: &[T::Field],
- t: &ValidityTestCase<T::Field>,
- ) -> Result<(), FlpError> {
- let mut gadgets = typ.gadget();
-
- if input.len() != typ.input_len() {
- return Err(FlpError::Test(format!(
- "unexpected input length: got {}; want {}",
- input.len(),
- typ.input_len()
- )));
- }
-
- if typ.query_rand_len() != gadgets.len() {
- return Err(FlpError::Test(format!(
- "query rand length: got {}; want {}",
- typ.query_rand_len(),
- gadgets.len()
- )));
- }
-
- let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
- let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
- let query_rand = random_vector(typ.query_rand_len()).unwrap();
-
- // Run the validity circuit.
- let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?;
- if v != T::Field::zero() && t.expect_valid {
- return Err(FlpError::Test(format!(
- "expected valid input: valid() returned {v}"
- )));
- }
- if v == T::Field::zero() && !t.expect_valid {
- return Err(FlpError::Test(format!(
- "expected invalid input: valid() returned {v}"
- )));
- }
-
- // Generate the proof.
- let proof = typ.prove(input, &prove_rand, &joint_rand)?;
- if proof.len() != typ.proof_len() {
- return Err(FlpError::Test(format!(
- "unexpected proof length: got {}; want {}",
- proof.len(),
- typ.proof_len()
- )));
- }
-
- // Query the proof.
- let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?;
- if verifier.len() != typ.verifier_len() {
- return Err(FlpError::Test(format!(
- "unexpected verifier length: got {}; want {}",
- verifier.len(),
- typ.verifier_len()
- )));
- }
-
- // Decide if the input is valid.
- let res = typ.decide(&verifier)?;
- if res != t.expect_valid {
- return Err(FlpError::Test(format!(
- "decision is {}; want {}",
- res, t.expect_valid,
- )));
- }
-
- // Run distributed FLP.
- let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares)
- .unwrap()
- .into_iter()
- .collect();
-
- let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares)
- .unwrap()
- .into_iter()
- .collect();
-
- let verifier: Vec<T::Field> = (0..t.num_shares)
- .map(|i| {
- typ.query(
- &input_shares[i],
- &proof_shares[i],
- &query_rand,
- &joint_rand,
- t.num_shares,
- )
- .unwrap()
- })
- .reduce(|mut left, right| {
- for (x, y) in left.iter_mut().zip(right.iter()) {
- *x += *y;
- }
- left
- })
- .unwrap();
-
- let res = typ.decide(&verifier)?;
- if res != t.expect_valid {
- return Err(FlpError::Test(format!(
- "distributed decision is {}; want {}",
- res, t.expect_valid,
- )));
- }
-
- // Try verifying various proof mutants.
- for i in 0..proof.len() {
- let mut mutated_proof = proof.clone();
- mutated_proof[i] += T::Field::one();
- let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?;
- if typ.decide(&verifier)? {
- return Err(FlpError::Test(format!(
- "decision for proof mutant {} is {}; want {}",
- i, true, false,
- )));
- }
- }
-
- // Try verifying a proof that is too short.
- let mut mutated_proof = proof.clone();
- mutated_proof.truncate(gadgets[0].arity() - 1);
- if typ
- .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
- .is_ok()
- {
- return Err(FlpError::Test(
- "query on short proof succeeded; want failure".to_string(),
- ));
- }
-
- // Try verifying a proof that is too long.
- let mut mutated_proof = proof;
- mutated_proof.extend_from_slice(&[T::Field::one(); 17]);
- if typ
- .query(input, &mutated_proof, &query_rand, &joint_rand, 1)
- .is_ok()
- {
- return Err(FlpError::Test(
- "query on long proof succeeded; want failure".to_string(),
- ));
- }
-
- if let Some(ref want) = t.expected_output {
- let got = typ.truncate(input.to_vec())?;
-
- if got.len() != typ.output_len() {
- return Err(FlpError::Test(format!(
- "unexpected output length: got {}; want {}",
- got.len(),
- typ.output_len()
- )));
- }
-
- if &got != want {
- return Err(FlpError::Test(format!(
- "unexpected output: got {got:?}; want {want:?}"
- )));
- }
- }
-
- Ok(())
- }
-}
-
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub mod fixedpoint_l2;