diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/neqo-crypto/src/agent.rs | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-crypto/src/agent.rs')
-rw-r--r-- | third_party/rust/neqo-crypto/src/agent.rs | 1263 |
1 files changed, 1263 insertions, 0 deletions
diff --git a/third_party/rust/neqo-crypto/src/agent.rs b/third_party/rust/neqo-crypto/src/agent.rs new file mode 100644 index 0000000000..cd0bb4cb12 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/agent.rs @@ -0,0 +1,1263 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + ffi::{CStr, CString}, + mem::{self, MaybeUninit}, + ops::{Deref, DerefMut}, + os::raw::{c_uint, c_void}, + pin::Pin, + ptr::{null, null_mut}, + rc::Rc, + time::Instant, +}; + +use neqo_common::{hex_snip_middle, hex_with_len, qdebug, qinfo, qtrace, qwarn}; + +pub use crate::{ + agentio::{as_c_void, Record, RecordList}, + cert::CertificateInfo, +}; +use crate::{ + agentio::{AgentIo, METHODS}, + assert_initialized, + auth::AuthenticationStatus, + constants::{ + Alert, Cipher, Epoch, Extension, Group, SignatureScheme, Version, TLS_VERSION_1_3, + }, + ech, + err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res}, + ext::{ExtensionHandler, ExtensionTracker}, + p11::{self, PrivateKey, PublicKey}, + prio, + replay::AntiReplay, + secrets::SecretHolder, + ssl::{self, PRBool}, + time::{Time, TimeHolder}, +}; + +/// The maximum number of tickets to remember for a given connection. +const MAX_TICKETS: usize = 4; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum HandshakeState { + New, + InProgress, + AuthenticationPending, + /// When encrypted client hello is enabled, the server might engage a fallback. + /// This is the status that is returned. The included value is the public + /// name of the server, which should be used to validated the certificate. + EchFallbackAuthenticationPending(String), + Authenticated(PRErrorCode), + Complete(SecretAgentInfo), + Failed(Error), +} + +impl HandshakeState { + #[must_use] + pub fn is_connected(&self) -> bool { + matches!(self, Self::Complete(_)) + } + + #[must_use] + pub fn is_final(&self) -> bool { + matches!(self, Self::Complete(_) | Self::Failed(_)) + } + + #[must_use] + pub fn authentication_needed(&self) -> bool { + matches!( + self, + Self::AuthenticationPending | Self::EchFallbackAuthenticationPending(_) + ) + } +} + +fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> { + let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT; + let mut chosen = vec![0_u8; 255]; + let mut chosen_len: c_uint = 0; + secstatus_to_res(unsafe { + ssl::SSL_GetNextProto( + fd, + &mut alpn_state, + chosen.as_mut_ptr(), + &mut chosen_len, + c_uint::try_from(chosen.len())?, + ) + })?; + + let alpn = match (pre, alpn_state) { + (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE) + | ( + false, + ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED + | ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED, + ) => { + chosen.truncate(usize::try_from(chosen_len)?); + Some(match String::from_utf8(chosen) { + Ok(a) => a, + Err(_) => return Err(Error::InternalError), + }) + } + _ => None, + }; + qtrace!([format!("{fd:p}")], "got ALPN {:?}", alpn); + Ok(alpn) +} + +pub struct SecretAgentPreInfo { + info: ssl::SSLPreliminaryChannelInfo, + alpn: Option<String>, +} + +macro_rules! preinfo_arg { + ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => { + #[must_use] + pub fn $v(&self) -> Option<$t> { + match self.info.valuesSet & ssl::$m { + 0 => None, + _ => Some($t::try_from(self.info.$f).unwrap()), + } + } + }; +} + +impl SecretAgentPreInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetPreliminaryChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?, + ) + })?; + + Ok(Self { + info: unsafe { info.assume_init() }, + alpn: get_alpn(fd, true)?, + }) + } + + preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version); + preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher); + preinfo_arg!( + early_data_cipher, + ssl_preinfo_0rtt_cipher_suite, + zeroRttCipherSuite: Cipher, + ); + + #[must_use] + pub fn early_data(&self) -> bool { + self.info.canSendEarlyData != 0 + } + + /// # Panics + /// + /// If `usize` is less than 32 bits and the value is too large. + #[must_use] + pub fn max_early_data(&self) -> usize { + usize::try_from(self.info.maxEarlyDataSize).unwrap() + } + + /// Was ECH accepted. + #[must_use] + pub fn ech_accepted(&self) -> Option<bool> { + if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 { + None + } else { + Some(self.info.echAccepted != 0) + } + } + + /// Get the ECH public name that was used. This will only be available + /// (that is, not `None`) if `ech_accepted()` returns `false`. + /// In this case, certificate validation needs to use this name rather + /// than the original name to validate the certificate. If + /// that validation passes (that is, `SecretAgent::authenticated` is called + /// with `AuthenticationStatus::Ok`), then the handshake will still fail. + /// After the failed handshake, the state will be `Error::EchRetry`, + /// which contains a valid ECH configuration. + /// + /// # Errors + /// + /// When the public name is not valid UTF-8. (Note: names should be ASCII.) + pub fn ech_public_name(&self) -> Res<Option<&str>> { + if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 || self.info.echPublicName.is_null() { + Ok(None) + } else { + let n = unsafe { CStr::from_ptr(self.info.echPublicName) }; + Ok(Some(n.to_str()?)) + } + } + + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct SecretAgentInfo { + version: Version, + cipher: Cipher, + group: Group, + resumed: bool, + early_data: bool, + ech_accepted: bool, + alpn: Option<String>, + signature_scheme: SignatureScheme, +} + +impl SecretAgentInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?, + ) + })?; + let info = unsafe { info.assume_init() }; + Ok(Self { + version: info.protocolVersion, + cipher: info.cipherSuite, + group: Group::try_from(info.keaGroup)?, + resumed: info.resumed != 0, + early_data: info.earlyDataAccepted != 0, + ech_accepted: info.echAccepted != 0, + alpn: get_alpn(fd, false)?, + signature_scheme: SignatureScheme::try_from(info.signatureScheme)?, + }) + } + #[must_use] + pub fn version(&self) -> Version { + self.version + } + #[must_use] + pub fn cipher_suite(&self) -> Cipher { + self.cipher + } + #[must_use] + pub fn key_exchange(&self) -> Group { + self.group + } + #[must_use] + pub fn resumed(&self) -> bool { + self.resumed + } + #[must_use] + pub fn early_data_accepted(&self) -> bool { + self.early_data + } + #[must_use] + pub fn ech_accepted(&self) -> bool { + self.ech_accepted + } + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } + #[must_use] + pub fn signature_scheme(&self) -> SignatureScheme { + self.signature_scheme + } +} + +/// `SecretAgent` holds the common parts of client and server. +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SecretAgent { + fd: *mut ssl::PRFileDesc, + secrets: SecretHolder, + raw: Option<bool>, + io: Pin<Box<AgentIo>>, + state: HandshakeState, + + /// Records whether authentication of certificates is required. + auth_required: Pin<Box<bool>>, + /// Records any fatal alert that is sent by the stack. + alert: Pin<Box<Option<Alert>>>, + /// The current time. + now: TimeHolder, + + extension_handlers: Vec<ExtensionTracker>, + + /// The encrypted client hello (ECH) configuration that is in use. + /// Empty if ECH is not enabled. + ech_config: Vec<u8>, +} + +impl SecretAgent { + fn new() -> Res<Self> { + let mut io = Box::pin(AgentIo::new()); + let fd = Self::create_fd(&mut io)?; + Ok(Self { + fd, + secrets: SecretHolder::default(), + raw: None, + io, + state: HandshakeState::New, + + auth_required: Box::pin(false), + alert: Box::pin(None), + now: TimeHolder::default(), + + extension_handlers: Vec::new(), + + ech_config: Vec::new(), + }) + } + + // Create a new SSL file descriptor. + // + // Note that we create separate bindings for PRFileDesc as both + // ssl::PRFileDesc and prio::PRFileDesc. This keeps the bindings + // minimal, but it means that the two forms need casts to translate + // between them. ssl::PRFileDesc is left as an opaque type, as the + // ssl::SSL_* APIs only need an opaque type. + fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> { + assert_initialized(); + let label = CString::new("sslwrapper")?; + let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) }; + + let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) }; + if base_fd.is_null() { + return Err(Error::CreateSslSocket); + } + let fd = unsafe { + (*base_fd).secret = as_c_void(io).cast(); + ssl::SSL_ImportFD(null_mut(), base_fd.cast()) + }; + if fd.is_null() { + unsafe { prio::PR_Close(base_fd) }; + return Err(Error::CreateSslSocket); + } + Ok(fd) + } + + unsafe extern "C" fn auth_complete_hook( + arg: *mut c_void, + _fd: *mut ssl::PRFileDesc, + _check_sig: ssl::PRBool, + _is_server: ssl::PRBool, + ) -> ssl::SECStatus { + let auth_required_ptr = arg.cast::<bool>(); + *auth_required_ptr = true; + // NSS insists on getting SECWouldBlock here rather than accepting + // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure. + ssl::_SECStatus_SECWouldBlock + } + + unsafe extern "C" fn alert_sent_cb( + fd: *const ssl::PRFileDesc, + arg: *mut c_void, + alert: *const ssl::SSLAlert, + ) { + let alert = alert.as_ref().unwrap(); + if alert.level == 2 { + // Fatal alerts demand attention. + let st = arg.cast::<Option<Alert>>().as_mut().unwrap(); + if st.is_none() { + *st = Some(alert.description); + } else { + qwarn!([format!("{fd:p}")], "duplicate alert {}", alert.description); + } + } + } + + // Ready this for connecting. + fn ready(&mut self, is_server: bool, grease: bool) -> Res<()> { + secstatus_to_res(unsafe { + ssl::SSL_AuthCertificateHook( + self.fd, + Some(Self::auth_complete_hook), + as_c_void(&mut self.auth_required), + ) + })?; + + secstatus_to_res(unsafe { + ssl::SSL_AlertSentCallback( + self.fd, + Some(Self::alert_sent_cb), + as_c_void(&mut self.alert), + ) + })?; + + self.now.bind(self.fd)?; + self.configure(grease)?; + secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, ssl::PRBool::from(is_server)) }) + } + + /// Default configuration. + /// + /// # Errors + /// + /// If `set_version_range` fails. + fn configure(&mut self, grease: bool) -> Res<()> { + self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; + self.set_option(ssl::Opt::Locking, false)?; + self.set_option(ssl::Opt::Tickets, false)?; + self.set_option(ssl::Opt::OcspStapling, true)?; + if let Err(e) = self.set_option(ssl::Opt::Grease, grease) { + // Until NSS supports greasing, it's OK to fail here. + qinfo!([self], "Failed to enable greasing {:?}", e); + } + Ok(()) + } + + /// Set the versions that are supported. + /// + /// # Errors + /// + /// If the range of versions isn't supported. + pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> { + let range = ssl::SSLVersionRange { min, max }; + secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) }) + } + + /// Enable a set of ciphers. Note that the order of these is not respected. + /// + /// # Errors + /// + /// If NSS can't enable or disable ciphers. + pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { + if self.state != HandshakeState::New { + qwarn!([self], "Cannot enable ciphers in state {:?}", self.state); + return Err(Error::InternalError); + } + + let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() }; + let cipher_count = usize::from(unsafe { ssl::SSL_GetNumImplementedCiphers() }); + for i in 0..cipher_count { + let p = all_ciphers.wrapping_add(i); + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), ssl::PRBool::from(false)) + })?; + } + + for c in ciphers { + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), ssl::PRBool::from(true)) + })?; + } + Ok(()) + } + + /// Set key exchange groups. + /// + /// # Errors + /// + /// If the underlying API fails (which shouldn't happen). + pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> { + // SSLNamedGroup is a different size to Group, so copy one by one. + let group_vec: Vec<_> = groups + .iter() + .map(|&g| ssl::SSLNamedGroup::Type::from(g)) + .collect(); + + let ptr = group_vec.as_slice().as_ptr(); + secstatus_to_res(unsafe { + ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?) + }) + } + + /// Set the number of additional key shares that will be sent in the client hello + /// + /// # Errors + /// + /// If the underlying API fails (which shouldn't happen). + pub fn send_additional_key_shares(&mut self, count: usize) -> Res<()> { + secstatus_to_res(unsafe { + ssl::SSL_SendAdditionalKeyShares(self.fd, c_uint::try_from(count)?) + }) + } + + /// Set TLS options. + /// + /// # Errors + /// + /// Returns an error if the option or option value is invalid; i.e., never. + pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> { + opt.set(self.fd, value) + } + + /// Enable 0-RTT. + /// + /// # Errors + /// + /// See `set_option`. + pub fn enable_0rtt(&mut self) -> Res<()> { + self.set_option(ssl::Opt::EarlyData, true) + } + + /// Disable the `EndOfEarlyData` message. + /// + /// # Errors + /// + /// See `set_option`. + pub fn disable_end_of_early_data(&mut self) -> Res<()> { + self.set_option(ssl::Opt::SuppressEndOfEarlyData, true) + } + + /// `set_alpn` sets a list of preferred protocols, starting with the most preferred. + /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded + /// strings. + /// + /// This asserts if no items are provided, or if any individual item is longer than + /// 255 octets in length. + /// + /// # Errors + /// + /// This should always panic rather than return an error. + /// + /// # Panics + /// + /// If any of the provided `protocols` are more than 255 bytes long. + /// + /// [RFC7301]: https://datatracker.ietf.org/doc/html/rfc7301 + pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> { + // Validate and set length. + let mut encoded_len = protocols.len(); + for v in protocols { + assert!(v.as_ref().len() < 256); + assert!(!v.as_ref().is_empty()); + encoded_len += v.as_ref().len(); + } + + // Prepare to encode. + let mut encoded = Vec::with_capacity(encoded_len); + let mut add = |v: &str| { + if let Ok(s) = u8::try_from(v.len()) { + encoded.push(s); + encoded.extend_from_slice(v.as_bytes()); + } + }; + + // NSS inherited an idiosyncratic API as a result of having implemented NPN + // before ALPN. For that reason, we need to put the "best" option last. + let (first, rest) = protocols + .split_first() + .expect("at least one ALPN value needed"); + for v in rest { + add(v.as_ref()); + } + add(first.as_ref()); + assert_eq!(encoded_len, encoded.len()); + + // Now give the result to NSS. + secstatus_to_res(unsafe { + ssl::SSL_SetNextProtoNego( + self.fd, + encoded.as_slice().as_ptr(), + c_uint::try_from(encoded.len())?, + ) + }) + } + + /// Install an extension handler. + /// + /// This can be called multiple times with different values for `ext`. The handler is provided + /// as `Rc<RefCell<dyn T>>` so that the caller is able to hold a reference to the handler + /// and later access any state that it accumulates. + /// + /// # Errors + /// + /// When the extension handler can't be successfully installed. + pub fn extension_handler( + &mut self, + ext: Extension, + handler: Rc<RefCell<dyn ExtensionHandler>>, + ) -> Res<()> { + let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?; + self.extension_handlers.push(tracker); + Ok(()) + } + + // This function tracks whether handshake() or handshake_raw() was used + // and prevents the other from being used. + fn set_raw(&mut self, r: bool) -> Res<()> { + if self.raw.is_none() { + self.secrets.register(self.fd)?; + self.raw = Some(r); + Ok(()) + } else if self.raw.unwrap() == r { + Ok(()) + } else { + Err(Error::MixedHandshakeMethod) + } + } + + /// Get information about the connection. + /// This includes the version, ciphersuite, and ALPN. + /// + /// Calling this function returns None until the connection is complete. + #[must_use] + pub fn info(&self) -> Option<&SecretAgentInfo> { + match self.state { + HandshakeState::Complete(ref info) => Some(info), + _ => None, + } + } + + /// Get any preliminary information about the status of the connection. + /// + /// This includes whether 0-RTT was accepted and any information related to that. + /// Calling this function collects all the relevant information. + /// + /// # Errors + /// + /// When the underlying socket functions fail. + pub fn preinfo(&self) -> Res<SecretAgentPreInfo> { + SecretAgentPreInfo::new(self.fd) + } + + /// Get the peer's certificate chain. + #[must_use] + pub fn peer_certificate(&self) -> Option<CertificateInfo> { + CertificateInfo::new(self.fd) + } + + /// Return any fatal alert that the TLS stack might have sent. + #[must_use] + pub fn alert(&self) -> Option<&Alert> { + (*self.alert).as_ref() + } + + /// Call this function to mark the peer as authenticated. + /// + /// # Panics + /// + /// If the handshake doesn't need to be authenticated. + pub fn authenticated(&mut self, status: AuthenticationStatus) { + assert!(self.state.authentication_needed()); + *self.auth_required = false; + self.state = HandshakeState::Authenticated(status.into()); + } + + fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> { + if let Err(e) = res { + let e = ech::convert_ech_error(self.fd, e); + qwarn!([self], "error: {:?}", e); + self.state = HandshakeState::Failed(e.clone()); + Err(e) + } else { + res + } + } + + fn update_state(&mut self, res: Res<()>) -> Res<()> { + self.state = if is_blocked(&res) { + if *self.auth_required { + self.preinfo()?.ech_public_name()?.map_or( + HandshakeState::AuthenticationPending, + |public_name| { + HandshakeState::EchFallbackAuthenticationPending(public_name.to_owned()) + }, + ) + } else { + HandshakeState::InProgress + } + } else { + self.capture_error(res)?; + let info = self.capture_error(SecretAgentInfo::new(self.fd))?; + HandshakeState::Complete(info) + }; + qinfo!([self], "state -> {:?}", self.state); + Ok(()) + } + + /// Drive the TLS handshake, taking bytes from `input` and putting + /// any bytes necessary into `output`. + /// This takes the current time as `now`. + /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake + /// is complete and how many bytes were written to `output`, respectively. + /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this + /// function if you want to proceed, because this will mark the certificate as OK. + /// + /// # Errors + /// + /// When the handshake fails this returns an error. + pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> { + self.now.set(now)?; + self.set_raw(false)?; + + let rv = { + // Within this scope, _h maintains a mutable reference to self.io. + let _h = self.io.wrap(input); + match self.state { + HandshakeState::Authenticated(ref err) => unsafe { + ssl::SSL_AuthCertificateComplete(self.fd, *err) + }, + _ => unsafe { ssl::SSL_ForceHandshake(self.fd) }, + } + }; + // Take before updating state so that we leave the output buffer empty + // even if there is an error. + let output = self.io.take_output(); + self.update_state(secstatus_to_res(rv))?; + Ok(output) + } + + /// Setup to receive records for raw handshake functions. + fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> { + self.set_raw(true)?; + self.capture_error(RecordList::setup(self.fd)) + } + + /// Drive the TLS handshake, but get the raw content of records, not + /// protected records as bytes. This function is incompatible with + /// `handshake()`; use either this or `handshake()` exclusively. + /// + /// Ideally, this only includes records from the current epoch. + /// If you send data from multiple epochs, you might end up being sad. + /// + /// # Errors + /// + /// When the handshake fails this returns an error. + pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> { + self.now.set(now)?; + let records = self.setup_raw()?; + + // Fire off any authentication we might need to complete. + if let HandshakeState::Authenticated(ref err) = self.state { + let result = + secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) }); + qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result); + // This should return SECSuccess, so don't use update_state(). + self.capture_error(result)?; + } + + // Feed in any records. + if let Some(rec) = input { + self.capture_error(rec.write(self.fd))?; + } + + // Drive the handshake once more. + let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) }); + self.update_state(rv)?; + + Ok(*Pin::into_inner(records)) + } + + /// # Panics + /// + /// If setup fails. + #[allow(unknown_lints, clippy::branches_sharing_code)] + pub fn close(&mut self) { + // It should be safe to close multiple times. + if self.fd.is_null() { + return; + } + if let Some(true) = self.raw { + // Need to hold the record list in scope until the close is done. + let _records = self.setup_raw().expect("Can only close"); + unsafe { prio::PR_Close(self.fd.cast()) }; + } else { + // Need to hold the IO wrapper in scope until the close is done. + let _io = self.io.wrap(&[]); + unsafe { prio::PR_Close(self.fd.cast()) }; + }; + let _output = self.io.take_output(); + self.fd = null_mut(); + } + + /// State returns the status of the handshake. + #[must_use] + pub fn state(&self) -> &HandshakeState { + &self.state + } + + /// Take a read secret. This will only return a non-`None` value once. + #[must_use] + pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_read(epoch) + } + + /// Take a write secret. + #[must_use] + pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_write(epoch) + } + + /// Get the active ECH configuration, which is empty if ECH is disabled. + #[must_use] + pub fn ech_config(&self) -> &[u8] { + &self.ech_config + } +} + +impl Drop for SecretAgent { + fn drop(&mut self) { + self.close(); + } +} + +impl ::std::fmt::Display for SecretAgent { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Agent {:p}", self.fd) + } +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone)] +pub struct ResumptionToken { + token: Vec<u8>, + expiration_time: Instant, +} + +impl AsRef<[u8]> for ResumptionToken { + fn as_ref(&self) -> &[u8] { + &self.token + } +} + +impl ResumptionToken { + #[must_use] + pub fn new(token: Vec<u8>, expiration_time: Instant) -> Self { + Self { + token, + expiration_time, + } + } + + #[must_use] + pub fn expiration_time(&self) -> Instant { + self.expiration_time + } +} + +/// A TLS Client. +#[derive(Debug)] +#[allow( + renamed_and_removed_lints, + clippy::box_vec, + unknown_lints, + clippy::box_collection +)] // We need the Box. +pub struct Client { + agent: SecretAgent, + + /// The name of the server we're attempting a connection to. + server_name: String, + /// Records the resumption tokens we've received. + resumption: Pin<Box<Vec<ResumptionToken>>>, +} + +impl Client { + /// Create a new client agent. + /// + /// # Errors + /// + /// Errors returned if the socket can't be created or configured. + pub fn new(server_name: impl Into<String>, grease: bool) -> Res<Self> { + let server_name = server_name.into(); + let mut agent = SecretAgent::new()?; + let url = CString::new(server_name.as_bytes())?; + secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?; + agent.ready(false, grease)?; + let mut client = Self { + agent, + server_name, + resumption: Box::pin(Vec::new()), + }; + client.ready()?; + Ok(client) + } + + unsafe extern "C" fn resumption_token_cb( + fd: *mut ssl::PRFileDesc, + token: *const u8, + len: c_uint, + arg: *mut c_void, + ) -> ssl::SECStatus { + let mut info: MaybeUninit<ssl::SSLResumptionTokenInfo> = MaybeUninit::uninit(); + if ssl::SSL_GetResumptionTokenInfo( + token, + len, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()).unwrap(), + ) + .is_err() + { + // Ignore the token. + return ssl::SECSuccess; + } + let expiration_time = info.assume_init().expirationTime; + if ssl::SSL_DestroyResumptionTokenInfo(info.as_mut_ptr()).is_err() { + // Ignore the token. + return ssl::SECSuccess; + } + let resumption = arg.cast::<Vec<ResumptionToken>>().as_mut().unwrap(); + let len = usize::try_from(len).unwrap(); + let mut v = Vec::with_capacity(len); + v.extend_from_slice(std::slice::from_raw_parts(token, len)); + qinfo!( + [format!("{fd:p}")], + "Got resumption token {}", + hex_snip_middle(&v) + ); + + if resumption.len() >= MAX_TICKETS { + resumption.remove(0); + } + if let Ok(t) = Time::try_from(expiration_time) { + resumption.push(ResumptionToken::new(v, *t)); + } + ssl::SECSuccess + } + + #[must_use] + pub fn server_name(&self) -> &str { + &self.server_name + } + + fn ready(&mut self) -> Res<()> { + let fd = self.fd; + unsafe { + ssl::SSL_SetResumptionTokenCallback( + fd, + Some(Self::resumption_token_cb), + as_c_void(&mut self.resumption), + ) + } + } + + /// Take a resumption token. + #[must_use] + pub fn resumption_token(&mut self) -> Option<ResumptionToken> { + (*self.resumption).pop() + } + + /// Check if there are more resumption tokens. + #[must_use] + pub fn has_resumption_token(&self) -> bool { + !(*self.resumption).is_empty() + } + + /// Enable resumption, using a token previously provided. + /// + /// # Errors + /// + /// Error returned when the resumption token is invalid or + /// the socket is not able to use the value. + pub fn enable_resumption(&mut self, token: impl AsRef<[u8]>) -> Res<()> { + unsafe { + ssl::SSL_SetResumptionToken( + self.agent.fd, + token.as_ref().as_ptr(), + c_uint::try_from(token.as_ref().len())?, + ) + } + } + + /// Enable encrypted client hello (ECH), using the encoded `ECHConfigList`. + /// + /// When ECH is enabled, a client needs to look for `Error::EchRetry` as a + /// failure code. If `Error::EchRetry` is received when connecting, the + /// connection attempt should be retried and the included value provided + /// to this function (instead of what is received from DNS). + /// + /// Calling this function with an empty value for `ech_config_list` enables + /// ECH greasing. When that is done, there is no need to look for `EchRetry` + /// + /// # Errors + /// + /// Error returned when the configuration is invalid. + pub fn enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> { + let config = ech_config_list.as_ref(); + qdebug!([self], "Enable ECH for a server: {}", hex_with_len(config)); + self.ech_config = Vec::from(config); + if config.is_empty() { + unsafe { ech::SSL_EnableTls13GreaseEch(self.agent.fd, PRBool::from(true)) } + } else { + unsafe { + ech::SSL_SetClientEchConfigs( + self.agent.fd, + config.as_ptr(), + c_uint::try_from(config.len())?, + ) + } + } + } +} + +impl Deref for Client { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Client { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +impl ::std::fmt::Display for Client { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Client {:p}", self.agent.fd) + } +} + +/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ZeroRttCheckResult { + /// Accept 0-RTT. + Accept, + /// Reject 0-RTT, but continue the handshake normally. + Reject, + /// Send HelloRetryRequest (probably not needed for QUIC). + HelloRetryRequest(Vec<u8>), + /// Fail the handshake. + Fail, +} + +/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by +/// `send_ticket`) +pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin { + fn check(&self, token: &[u8]) -> ZeroRttCheckResult; +} + +/// Using `AllowZeroRtt` for the implementation of `ZeroRttChecker` means +/// accepting 0-RTT always. This generally isn't a great idea, so this +/// generates a strong warning when it is used. +#[derive(Debug)] +pub struct AllowZeroRtt {} +impl ZeroRttChecker for AllowZeroRtt { + fn check(&self, _token: &[u8]) -> ZeroRttCheckResult { + qwarn!("AllowZeroRtt accepting 0-RTT"); + ZeroRttCheckResult::Accept + } +} + +#[derive(Debug)] +struct ZeroRttCheckState { + checker: Pin<Box<dyn ZeroRttChecker>>, +} + +impl ZeroRttCheckState { + pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self { + Self { + checker: Pin::new(checker), + } + } +} + +#[derive(Debug)] +pub struct Server { + agent: SecretAgent, + /// This holds the HRR callback context. + zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>, +} + +impl Server { + /// Create a new server agent. + /// + /// # Errors + /// + /// Errors returned when NSS fails. + pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> { + let mut agent = SecretAgent::new()?; + + for n in certificates { + let c = CString::new(n.as_ref())?; + let cert_ptr = unsafe { p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut()) }; + let Ok(cert) = p11::Certificate::from_ptr(cert_ptr) else { + return Err(Error::CertificateLoading); + }; + let key_ptr = unsafe { p11::PK11_FindKeyByAnyCert(*cert, null_mut()) }; + let Ok(key) = p11::PrivateKey::from_ptr(key_ptr) else { + return Err(Error::CertificateLoading); + }; + secstatus_to_res(unsafe { + ssl::SSL_ConfigServerCert(agent.fd, *cert, *key, null(), 0) + })?; + } + + agent.ready(true, true)?; + Ok(Self { + agent, + zero_rtt_check: None, + }) + } + + unsafe extern "C" fn hello_retry_cb( + first_hello: PRBool, + client_token: *const u8, + client_token_len: c_uint, + retry_token: *mut u8, + retry_token_len: *mut c_uint, + retry_token_max: c_uint, + arg: *mut c_void, + ) -> ssl::SSLHelloRetryRequestAction::Type { + if first_hello == 0 { + // On the second ClientHello after HelloRetryRequest, skip checks. + return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept; + } + + let check_state = arg.cast::<ZeroRttCheckState>().as_mut().unwrap(); + let token = if client_token.is_null() { + &[] + } else { + std::slice::from_raw_parts(client_token, usize::try_from(client_token_len).unwrap()) + }; + match check_state.checker.check(token) { + ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept, + ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail, + ZeroRttCheckResult::Reject => { + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt + } + ZeroRttCheckResult::HelloRetryRequest(tok) => { + // Don't bother propagating errors from this, because it should be caught in + // testing. + assert!(tok.len() <= usize::try_from(retry_token_max).unwrap()); + let slc = std::slice::from_raw_parts_mut(retry_token, tok.len()); + slc.copy_from_slice(&tok); + *retry_token_len = c_uint::try_from(tok.len()).unwrap(); + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request + } + } + } + + /// Enable 0-RTT. This shadows the function of the same name that can be accessed + /// via the Deref implementation on Server. + /// + /// # Errors + /// + /// Returns an error if the underlying NSS functions fail. + pub fn enable_0rtt( + &mut self, + anti_replay: &AntiReplay, + max_early_data: u32, + checker: Box<dyn ZeroRttChecker>, + ) -> Res<()> { + let mut check_state = Box::pin(ZeroRttCheckState::new(checker)); + unsafe { + ssl::SSL_HelloRetryRequestCallback( + self.agent.fd, + Some(Self::hello_retry_cb), + as_c_void(&mut check_state), + ) + }?; + unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?; + self.zero_rtt_check = Some(check_state); + self.agent.enable_0rtt()?; + anti_replay.config_socket(self.fd)?; + Ok(()) + } + + /// Send a session ticket to the client. + /// This adds |extra| application-specific content into that ticket. + /// The records that are sent are captured and returned. + /// + /// # Errors + /// + /// If NSS is unable to send a ticket, or if this agent is incorrectly configured. + pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> { + self.agent.now.set(now)?; + let records = self.setup_raw()?; + + unsafe { + ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?) + }?; + + Ok(*Pin::into_inner(records)) + } + + /// Enable encrypted client hello (ECH). + /// + /// # Errors + /// + /// Fails when NSS cannot create a key pair. + pub fn enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + let cfg = ech::encode_config(config, public_name, pk)?; + qdebug!([self], "Enable ECH for a server: {}", hex_with_len(&cfg)); + unsafe { + ech::SSL_SetServerEchConfigs( + self.agent.fd, + **pk, + **sk, + cfg.as_ptr(), + c_uint::try_from(cfg.len())?, + )?; + }; + self.ech_config = cfg; + Ok(()) + } +} + +impl Deref for Server { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Server { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +impl ::std::fmt::Display for Server { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Server {:p}", self.agent.fd) + } +} + +/// A generic container for Client or Server. +#[derive(Debug)] +pub enum Agent { + Client(crate::agent::Client), + Server(crate::agent::Server), +} + +impl Deref for Agent { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + match self { + Self::Client(c) => c, + Self::Server(s) => s, + } + } +} + +impl DerefMut for Agent { + fn deref_mut(&mut self) -> &mut SecretAgent { + match self { + Self::Client(c) => c, + Self::Server(s) => s, + } + } +} + +impl From<Client> for Agent { + #[must_use] + fn from(c: Client) -> Self { + Self::Client(c) + } +} + +impl From<Server> for Agent { + #[must_use] + fn from(s: Server) -> Self { + Self::Server(s) + } +} |