summaryrefslogtreecommitdiffstats
path: root/toolkit/components/telemetry/dap/ffi/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/telemetry/dap/ffi/src/lib.rs')
-rw-r--r--toolkit/components/telemetry/dap/ffi/src/lib.rs335
1 files changed, 335 insertions, 0 deletions
diff --git a/toolkit/components/telemetry/dap/ffi/src/lib.rs b/toolkit/components/telemetry/dap/ffi/src/lib.rs
new file mode 100644
index 0000000000..998c8af204
--- /dev/null
+++ b/toolkit/components/telemetry/dap/ffi/src/lib.rs
@@ -0,0 +1,335 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+use std::error::Error;
+use std::io::Cursor;
+
+use prio::vdaf::prio3::Prio3Sum;
+use prio::vdaf::prio3::Prio3SumVec;
+use thin_vec::ThinVec;
+
+pub mod types;
+use types::HpkeConfig;
+use types::PlaintextInputShare;
+use types::Report;
+use types::ReportID;
+use types::ReportMetadata;
+use types::Time;
+
+use prio::codec::Encode;
+use prio::codec::{decode_u16_items, encode_u32_items};
+use prio::flp::types::{Sum, SumVec};
+use prio::vdaf::prio3::Prio3;
+use prio::vdaf::Client;
+use prio::vdaf::VdafError;
+
+use crate::types::HpkeCiphertext;
+
+extern "C" {
+ pub fn dapHpkeEncryptOneshot(
+ aKey: *const u8,
+ aKeyLength: u32,
+ aInfo: *const u8,
+ aInfoLength: u32,
+ aAad: *const u8,
+ aAadLength: u32,
+ aPlaintext: *const u8,
+ aPlaintextLength: u32,
+ aOutputEncapsulatedKey: &mut ThinVec<u8>,
+ aOutputShare: &mut ThinVec<u8>,
+ ) -> bool;
+}
+
+pub fn new_prio_u8(num_aggregators: u8, bits: u32) -> Result<Prio3Sum, VdafError> {
+ if bits > 64 {
+ return Err(VdafError::Uncategorized(format!(
+ "bit length ({}) exceeds limit for aggregate type (64)",
+ bits
+ )));
+ }
+
+ Prio3::new(num_aggregators, Sum::new(bits as usize)?)
+}
+
+pub fn new_prio_vecu8(num_aggregators: u8, len: usize) -> Result<Prio3SumVec, VdafError> {
+ let chunk_length = prio::vdaf::prio3::optimal_chunk_length(8 * len);
+ Prio3::new(num_aggregators, SumVec::new(8, len, chunk_length)?)
+}
+
+pub fn new_prio_vecu16(num_aggregators: u8, len: usize) -> Result<Prio3SumVec, VdafError> {
+ let chunk_length = prio::vdaf::prio3::optimal_chunk_length(16 * len);
+ Prio3::new(num_aggregators, SumVec::new(16, len, chunk_length)?)
+}
+
+enum Role {
+ Leader = 2,
+ Helper = 3,
+}
+
+/// A minimal wrapper around the FFI function which mostly just converts datatypes.
+fn hpke_encrypt_wrapper(
+ plain_share: &Vec<u8>,
+ aad: &Vec<u8>,
+ info: &Vec<u8>,
+ hpke_config: &HpkeConfig,
+) -> Result<HpkeCiphertext, Box<dyn std::error::Error>> {
+ let mut encrypted_share = ThinVec::<u8>::new();
+ let mut encapsulated_key = ThinVec::<u8>::new();
+ unsafe {
+ if !dapHpkeEncryptOneshot(
+ hpke_config.public_key.as_ptr(),
+ hpke_config.public_key.len() as u32,
+ info.as_ptr(),
+ info.len() as u32,
+ aad.as_ptr(),
+ aad.len() as u32,
+ plain_share.as_ptr(),
+ plain_share.len() as u32,
+ &mut encapsulated_key,
+ &mut encrypted_share,
+ ) {
+ return Err(Box::from("Encryption failed."));
+ }
+ }
+
+ Ok(HpkeCiphertext {
+ config_id: hpke_config.id,
+ enc: encapsulated_key.to_vec(),
+ payload: encrypted_share.to_vec(),
+ })
+}
+
+trait Shardable {
+ fn shard(
+ &self,
+ nonce: &[u8; 16],
+ ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>>;
+}
+
+impl Shardable for u8 {
+ fn shard(
+ &self,
+ nonce: &[u8; 16],
+ ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
+ let prio = new_prio_u8(2, 2)?;
+
+ let (public_share, input_shares) = prio.shard(&(*self as u128), nonce)?;
+
+ debug_assert_eq!(input_shares.len(), 2);
+
+ let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
+ let encoded_public_share = public_share.get_encoded();
+ Ok((encoded_public_share, encoded_input_shares))
+ }
+}
+
+impl Shardable for ThinVec<u8> {
+ fn shard(
+ &self,
+ nonce: &[u8; 16],
+ ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
+ let prio = new_prio_vecu8(2, self.len())?;
+
+ let measurement: Vec<u128> = self.iter().map(|e| (*e as u128)).collect();
+ let (public_share, input_shares) = prio.shard(&measurement, nonce)?;
+
+ debug_assert_eq!(input_shares.len(), 2);
+
+ let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
+ let encoded_public_share = public_share.get_encoded();
+ Ok((encoded_public_share, encoded_input_shares))
+ }
+}
+
+impl Shardable for ThinVec<u16> {
+ fn shard(
+ &self,
+ nonce: &[u8; 16],
+ ) -> Result<(Vec<u8>, Vec<Vec<u8>>), Box<dyn std::error::Error>> {
+ let prio = new_prio_vecu16(2, self.len())?;
+
+ let measurement: Vec<u128> = self.iter().map(|e| (*e as u128)).collect();
+ let (public_share, input_shares) = prio.shard(&measurement, nonce)?;
+
+ debug_assert_eq!(input_shares.len(), 2);
+
+ let encoded_input_shares = input_shares.iter().map(|s| s.get_encoded()).collect();
+ let encoded_public_share = public_share.get_encoded();
+ Ok((encoded_public_share, encoded_input_shares))
+ }
+}
+
+/// Pre-fill the info part of the HPKE sealing with the constants from the standard.
+fn make_base_info() -> Vec<u8> {
+ let mut info = Vec::<u8>::new();
+ const START: &[u8] = "dap-07 input share".as_bytes();
+ info.extend(START);
+ const FIXED: u8 = 1;
+ info.push(FIXED);
+
+ info
+}
+
+fn select_hpke_config(configs: Vec<HpkeConfig>) -> Result<HpkeConfig, Box<dyn Error>> {
+ for config in configs {
+ if config.kem_id == 0x20 /* DHKEM(X25519, HKDF-SHA256) */ &&
+ config.kdf_id == 0x01 /* HKDF-SHA256 */ &&
+ config.aead_id == 0x01
+ /* AES-128-GCM */
+ {
+ return Ok(config);
+ }
+ }
+
+ Err("No suitable HPKE config found.".into())
+}
+
+/// This function creates a full report - ready to send - for a measurement.
+///
+/// To do that it also needs the HPKE configurations for the endpoints and some
+/// additional data which is part of the authentication.
+fn get_dap_report_internal<T: Shardable>(
+ leader_hpke_config_encoded: &ThinVec<u8>,
+ helper_hpke_config_encoded: &ThinVec<u8>,
+ measurement: &T,
+ task_id: &[u8; 32],
+ time_precision: u64,
+) -> Result<Report, Box<dyn std::error::Error>> {
+ let leader_hpke_configs: Vec<HpkeConfig> =
+ decode_u16_items(&(), &mut Cursor::new(leader_hpke_config_encoded))?;
+ let leader_hpke_config = select_hpke_config(leader_hpke_configs)?;
+ let helper_hpke_configs: Vec<HpkeConfig> =
+ decode_u16_items(&(), &mut Cursor::new(helper_hpke_config_encoded))?;
+ let helper_hpke_config = select_hpke_config(helper_hpke_configs)?;
+
+ let report_id = ReportID::generate();
+ let (encoded_public_share, encoded_input_shares) = measurement.shard(report_id.as_ref())?;
+
+ let plaintext_input_shares: Vec<Vec<u8>> = encoded_input_shares
+ .into_iter()
+ .map(|encoded_input_share| {
+ PlaintextInputShare {
+ extensions: Vec::new(),
+ payload: encoded_input_share,
+ }
+ .get_encoded()
+ })
+ .collect();
+
+ let metadata = ReportMetadata {
+ report_id,
+ time: Time::generate(time_precision),
+ };
+
+ // This quote from the standard describes which info and aad to use for the encryption:
+ // enc, payload = SealBase(pk,
+ // "dap-02 input share" || 0x01 || server_role,
+ // task_id || metadata || public_share, input_share)
+ // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#name-upload-request
+ let mut info = make_base_info();
+
+ let mut aad = Vec::from(*task_id);
+ metadata.encode(&mut aad);
+ encode_u32_items(&mut aad, &(), &encoded_public_share);
+
+ info.push(Role::Leader as u8);
+
+ let leader_payload =
+ hpke_encrypt_wrapper(&plaintext_input_shares[0], &aad, &info, &leader_hpke_config)?;
+
+ *info.last_mut().unwrap() = Role::Helper as u8;
+
+ let helper_payload =
+ hpke_encrypt_wrapper(&plaintext_input_shares[1], &aad, &info, &helper_hpke_config)?;
+
+ Ok(Report {
+ metadata,
+ public_share: encoded_public_share,
+ leader_encrypted_input_share: leader_payload,
+ helper_encrypted_input_share: helper_payload,
+ })
+}
+
+/// Wraps the function above with minor C interop.
+/// Mostly it turns any error result into a return value of false.
+#[no_mangle]
+pub extern "C" fn dapGetReportU8(
+ leader_hpke_config_encoded: &ThinVec<u8>,
+ helper_hpke_config_encoded: &ThinVec<u8>,
+ measurement: u8,
+ task_id: &ThinVec<u8>,
+ time_precision: u64,
+ out_report: &mut ThinVec<u8>,
+) -> bool {
+ assert_eq!(task_id.len(), 32);
+
+ if let Ok(report) = get_dap_report_internal::<u8>(
+ leader_hpke_config_encoded,
+ helper_hpke_config_encoded,
+ &measurement,
+ &task_id.as_slice().try_into().unwrap(),
+ time_precision,
+ ) {
+ let encoded_report = report.get_encoded();
+ out_report.extend(encoded_report);
+
+ true
+ } else {
+ false
+ }
+}
+
+#[no_mangle]
+pub extern "C" fn dapGetReportVecU8(
+ leader_hpke_config_encoded: &ThinVec<u8>,
+ helper_hpke_config_encoded: &ThinVec<u8>,
+ measurement: &ThinVec<u8>,
+ task_id: &ThinVec<u8>,
+ time_precision: u64,
+ out_report: &mut ThinVec<u8>,
+) -> bool {
+ assert_eq!(task_id.len(), 32);
+
+ if let Ok(report) = get_dap_report_internal::<ThinVec<u8>>(
+ leader_hpke_config_encoded,
+ helper_hpke_config_encoded,
+ measurement,
+ &task_id.as_slice().try_into().unwrap(),
+ time_precision,
+ ) {
+ let encoded_report = report.get_encoded();
+ out_report.extend(encoded_report);
+
+ true
+ } else {
+ false
+ }
+}
+
+#[no_mangle]
+pub extern "C" fn dapGetReportVecU16(
+ leader_hpke_config_encoded: &ThinVec<u8>,
+ helper_hpke_config_encoded: &ThinVec<u8>,
+ measurement: &ThinVec<u16>,
+ task_id: &ThinVec<u8>,
+ time_precision: u64,
+ out_report: &mut ThinVec<u8>,
+) -> bool {
+ assert_eq!(task_id.len(), 32);
+
+ if let Ok(report) = get_dap_report_internal::<ThinVec<u16>>(
+ leader_hpke_config_encoded,
+ helper_hpke_config_encoded,
+ measurement,
+ &task_id.as_slice().try_into().unwrap(),
+ time_precision,
+ ) {
+ let encoded_report = report.get_encoded();
+ out_report.extend(encoded_report);
+
+ true
+ } else {
+ false
+ }
+}