// Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. use std::{os::raw::c_void, pin::Pin}; use neqo_common::qdebug; use crate::{ agentio::as_c_void, constants::Epoch, err::Res, p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey}, ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection}, }; experimental_api!(SSL_SecretCallback( fd: *mut PRFileDesc, cb: SSLSecretCallback, arg: *mut c_void, )); #[derive(Clone, Copy, Debug)] pub enum SecretDirection { Read, Write, } impl From for SecretDirection { #[must_use] fn from(dir: SSLSecretDirection::Type) -> Self { match dir { SSLSecretDirection::ssl_secret_read => Self::Read, SSLSecretDirection::ssl_secret_write => Self::Write, _ => unreachable!(), } } } #[derive(Debug, Default)] #[allow(clippy::module_name_repetitions)] pub struct DirectionalSecrets { // We only need to maintain 3 secrets for the epochs used during the handshake. secrets: [Option; 3], } impl DirectionalSecrets { fn put(&mut self, epoch: Epoch, key: SymKey) { assert!(epoch > 0); let i = (epoch - 1) as usize; assert!(i < self.secrets.len()); // assert!(self.secrets[i].is_none()); self.secrets[i] = Some(key); } pub fn take(&mut self, epoch: Epoch) -> Option { assert!(epoch > 0); let i = (epoch - 1) as usize; assert!(i < self.secrets.len()); self.secrets[i].take() } } #[derive(Debug, Default)] pub struct Secrets { r: DirectionalSecrets, w: DirectionalSecrets, } impl Secrets { #[allow(clippy::unused_self)] unsafe extern "C" fn secret_available( _fd: *mut PRFileDesc, epoch: u16, dir: SSLSecretDirection::Type, secret: *mut PK11SymKey, arg: *mut c_void, ) { let secrets = arg.cast::().as_mut().unwrap(); secrets.put_raw(epoch, dir, secret); } fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) { let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) }; let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets"); self.put(SecretDirection::from(dir), epoch, key); } fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { qdebug!("{:?} secret available for {:?}: {:?}", dir, epoch, key); let keys = match dir { SecretDirection::Read => &mut self.r, SecretDirection::Write => &mut self.w, }; keys.put(epoch, key); } } #[derive(Debug)] pub struct SecretHolder { secrets: Pin>, } impl SecretHolder { /// This registers with NSS. The lifetime of this object needs to match the lifetime /// of the connection, or bad things might happen. pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> { let p = as_c_void(&mut self.secrets); unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) } } pub fn take_read(&mut self, epoch: Epoch) -> Option { self.secrets.r.take(epoch) } pub fn take_write(&mut self, epoch: Epoch) -> Option { self.secrets.w.take(epoch) } } impl Default for SecretHolder { fn default() -> Self { Self { secrets: Box::pin(Secrets::default()), } } }