diff options
Diffstat (limited to 'third_party/rust/neqo-crypto/src')
23 files changed, 4779 insertions, 0 deletions
diff --git a/third_party/rust/neqo-crypto/src/aead.rs b/third_party/rust/neqo-crypto/src/aead.rs new file mode 100644 index 0000000000..a2f009a403 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/aead.rs @@ -0,0 +1,175 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::{TryFrom, TryInto}, + fmt, + ops::{Deref, DerefMut}, + os::raw::{c_char, c_uint}, + ptr::null_mut, +}; + +use crate::{ + constants::{Cipher, Version}, + err::Res, + experimental_api, + p11::{PK11SymKey, SymKey}, + scoped_ptr, + ssl::{self, PRUint16, PRUint64, PRUint8, SSLAeadContext}, +}; + +experimental_api!(SSL_MakeAead( + version: PRUint16, + cipher: PRUint16, + secret: *mut PK11SymKey, + label_prefix: *const c_char, + label_prefix_len: c_uint, + ctx: *mut *mut SSLAeadContext, +)); +experimental_api!(SSL_AeadEncrypt( + ctx: *const SSLAeadContext, + counter: PRUint64, + aad: *const PRUint8, + aad_len: c_uint, + input: *const PRUint8, + input_len: c_uint, + output: *const PRUint8, + output_len: *mut c_uint, + max_output: c_uint +)); +experimental_api!(SSL_AeadDecrypt( + ctx: *const SSLAeadContext, + counter: PRUint64, + aad: *const PRUint8, + aad_len: c_uint, + input: *const PRUint8, + input_len: c_uint, + output: *const PRUint8, + output_len: *mut c_uint, + max_output: c_uint +)); +experimental_api!(SSL_DestroyAead(ctx: *mut SSLAeadContext)); +scoped_ptr!(AeadContext, SSLAeadContext, SSL_DestroyAead); + +pub struct RealAead { + ctx: AeadContext, +} + +impl RealAead { + /// Create a new AEAD based on the indicated TLS version and cipher suite. + /// + /// # Errors + /// + /// Returns `Error` when the supporting NSS functions fail. + pub fn new( + _fuzzing: bool, + version: Version, + cipher: Cipher, + secret: &SymKey, + prefix: &str, + ) -> Res<Self> { + let s: *mut PK11SymKey = **secret; + unsafe { Self::from_raw(version, cipher, s, prefix) } + } + + #[must_use] + #[allow(clippy::unused_self)] + pub fn expansion(&self) -> usize { + 16 + } + + unsafe fn from_raw( + version: Version, + cipher: Cipher, + secret: *mut PK11SymKey, + prefix: &str, + ) -> Res<Self> { + let p = prefix.as_bytes(); + let mut ctx: *mut ssl::SSLAeadContext = null_mut(); + SSL_MakeAead( + version, + cipher, + secret, + p.as_ptr().cast(), + c_uint::try_from(p.len())?, + &mut ctx, + )?; + Ok(Self { + ctx: AeadContext::from_ptr(ctx)?, + }) + } + + /// Encrypt a plaintext. + /// + /// The space provided in `output` needs to be larger than `input` by + /// the value provided in `Aead::expansion`. + /// + /// # Errors + /// + /// If the input can't be protected or any input is too large for NSS. + pub fn encrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + let mut l: c_uint = 0; + unsafe { + SSL_AeadEncrypt( + *self.ctx, + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + output.as_mut_ptr(), + &mut l, + c_uint::try_from(output.len())?, + ) + }?; + Ok(&output[0..(l.try_into()?)]) + } + + /// Decrypt a ciphertext. + /// + /// Note that NSS insists upon having extra space available for decryption, so + /// the buffer for `output` should be the same length as `input`, even though + /// the final result will be shorter. + /// + /// # Errors + /// + /// If the input isn't authenticated or any input is too large for NSS. + pub fn decrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + let mut l: c_uint = 0; + unsafe { + SSL_AeadDecrypt( + *self.ctx, + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + output.as_mut_ptr(), + &mut l, + c_uint::try_from(output.len())?, + ) + }?; + Ok(&output[0..(l.try_into()?)]) + } +} + +impl fmt::Debug for RealAead { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[AEAD Context]") + } +} diff --git a/third_party/rust/neqo-crypto/src/aead_fuzzing.rs b/third_party/rust/neqo-crypto/src/aead_fuzzing.rs new file mode 100644 index 0000000000..4e5a6de07f --- /dev/null +++ b/third_party/rust/neqo-crypto/src/aead_fuzzing.rs @@ -0,0 +1,103 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::fmt; + +use crate::{ + constants::{Cipher, Version}, + err::{sec::SEC_ERROR_BAD_DATA, Error, Res}, + p11::SymKey, + RealAead, +}; + +pub const FIXED_TAG_FUZZING: &[u8] = &[0x0a; 16]; + +pub struct FuzzingAead { + real: Option<RealAead>, +} + +impl FuzzingAead { + pub fn new( + fuzzing: bool, + version: Version, + cipher: Cipher, + secret: &SymKey, + prefix: &str, + ) -> Res<Self> { + let real = if fuzzing { + None + } else { + Some(RealAead::new(false, version, cipher, secret, prefix)?) + }; + Ok(Self { real }) + } + + #[must_use] + pub fn expansion(&self) -> usize { + if let Some(aead) = &self.real { + aead.expansion() + } else { + FIXED_TAG_FUZZING.len() + } + } + + pub fn encrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + if let Some(aead) = &self.real { + return aead.encrypt(count, aad, input, output); + } + + let l = input.len(); + output[..l].copy_from_slice(input); + output[l..l + 16].copy_from_slice(FIXED_TAG_FUZZING); + Ok(&output[..l + 16]) + } + + pub fn decrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + if let Some(aead) = &self.real { + return aead.decrypt(count, aad, input, output); + } + + if input.len() < FIXED_TAG_FUZZING.len() { + return Err(Error::from(SEC_ERROR_BAD_DATA)); + } + + let len_encrypted = input.len() - FIXED_TAG_FUZZING.len(); + // Check that: + // 1) expansion is all zeros and + // 2) if the encrypted data is also supplied that at least some values are no zero + // (otherwise padding will be interpreted as a valid packet) + if &input[len_encrypted..] == FIXED_TAG_FUZZING + && (len_encrypted == 0 || input[..len_encrypted].iter().any(|x| *x != 0x0)) + { + output[..len_encrypted].copy_from_slice(&input[..len_encrypted]); + Ok(&output[..len_encrypted]) + } else { + Err(Error::from(SEC_ERROR_BAD_DATA)) + } + } +} + +impl fmt::Debug for FuzzingAead { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(a) = &self.real { + a.fmt(f) + } else { + write!(f, "[FUZZING AEAD]") + } + } +} diff --git a/third_party/rust/neqo-crypto/src/agent.rs b/third_party/rust/neqo-crypto/src/agent.rs new file mode 100644 index 0000000000..cd0bb4cb12 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/agent.rs @@ -0,0 +1,1263 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + ffi::{CStr, CString}, + mem::{self, MaybeUninit}, + ops::{Deref, DerefMut}, + os::raw::{c_uint, c_void}, + pin::Pin, + ptr::{null, null_mut}, + rc::Rc, + time::Instant, +}; + +use neqo_common::{hex_snip_middle, hex_with_len, qdebug, qinfo, qtrace, qwarn}; + +pub use crate::{ + agentio::{as_c_void, Record, RecordList}, + cert::CertificateInfo, +}; +use crate::{ + agentio::{AgentIo, METHODS}, + assert_initialized, + auth::AuthenticationStatus, + constants::{ + Alert, Cipher, Epoch, Extension, Group, SignatureScheme, Version, TLS_VERSION_1_3, + }, + ech, + err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res}, + ext::{ExtensionHandler, ExtensionTracker}, + p11::{self, PrivateKey, PublicKey}, + prio, + replay::AntiReplay, + secrets::SecretHolder, + ssl::{self, PRBool}, + time::{Time, TimeHolder}, +}; + +/// The maximum number of tickets to remember for a given connection. +const MAX_TICKETS: usize = 4; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum HandshakeState { + New, + InProgress, + AuthenticationPending, + /// When encrypted client hello is enabled, the server might engage a fallback. + /// This is the status that is returned. The included value is the public + /// name of the server, which should be used to validated the certificate. + EchFallbackAuthenticationPending(String), + Authenticated(PRErrorCode), + Complete(SecretAgentInfo), + Failed(Error), +} + +impl HandshakeState { + #[must_use] + pub fn is_connected(&self) -> bool { + matches!(self, Self::Complete(_)) + } + + #[must_use] + pub fn is_final(&self) -> bool { + matches!(self, Self::Complete(_) | Self::Failed(_)) + } + + #[must_use] + pub fn authentication_needed(&self) -> bool { + matches!( + self, + Self::AuthenticationPending | Self::EchFallbackAuthenticationPending(_) + ) + } +} + +fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> { + let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT; + let mut chosen = vec![0_u8; 255]; + let mut chosen_len: c_uint = 0; + secstatus_to_res(unsafe { + ssl::SSL_GetNextProto( + fd, + &mut alpn_state, + chosen.as_mut_ptr(), + &mut chosen_len, + c_uint::try_from(chosen.len())?, + ) + })?; + + let alpn = match (pre, alpn_state) { + (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE) + | ( + false, + ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED + | ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED, + ) => { + chosen.truncate(usize::try_from(chosen_len)?); + Some(match String::from_utf8(chosen) { + Ok(a) => a, + Err(_) => return Err(Error::InternalError), + }) + } + _ => None, + }; + qtrace!([format!("{fd:p}")], "got ALPN {:?}", alpn); + Ok(alpn) +} + +pub struct SecretAgentPreInfo { + info: ssl::SSLPreliminaryChannelInfo, + alpn: Option<String>, +} + +macro_rules! preinfo_arg { + ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => { + #[must_use] + pub fn $v(&self) -> Option<$t> { + match self.info.valuesSet & ssl::$m { + 0 => None, + _ => Some($t::try_from(self.info.$f).unwrap()), + } + } + }; +} + +impl SecretAgentPreInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetPreliminaryChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?, + ) + })?; + + Ok(Self { + info: unsafe { info.assume_init() }, + alpn: get_alpn(fd, true)?, + }) + } + + preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version); + preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher); + preinfo_arg!( + early_data_cipher, + ssl_preinfo_0rtt_cipher_suite, + zeroRttCipherSuite: Cipher, + ); + + #[must_use] + pub fn early_data(&self) -> bool { + self.info.canSendEarlyData != 0 + } + + /// # Panics + /// + /// If `usize` is less than 32 bits and the value is too large. + #[must_use] + pub fn max_early_data(&self) -> usize { + usize::try_from(self.info.maxEarlyDataSize).unwrap() + } + + /// Was ECH accepted. + #[must_use] + pub fn ech_accepted(&self) -> Option<bool> { + if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 { + None + } else { + Some(self.info.echAccepted != 0) + } + } + + /// Get the ECH public name that was used. This will only be available + /// (that is, not `None`) if `ech_accepted()` returns `false`. + /// In this case, certificate validation needs to use this name rather + /// than the original name to validate the certificate. If + /// that validation passes (that is, `SecretAgent::authenticated` is called + /// with `AuthenticationStatus::Ok`), then the handshake will still fail. + /// After the failed handshake, the state will be `Error::EchRetry`, + /// which contains a valid ECH configuration. + /// + /// # Errors + /// + /// When the public name is not valid UTF-8. (Note: names should be ASCII.) + pub fn ech_public_name(&self) -> Res<Option<&str>> { + if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 || self.info.echPublicName.is_null() { + Ok(None) + } else { + let n = unsafe { CStr::from_ptr(self.info.echPublicName) }; + Ok(Some(n.to_str()?)) + } + } + + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct SecretAgentInfo { + version: Version, + cipher: Cipher, + group: Group, + resumed: bool, + early_data: bool, + ech_accepted: bool, + alpn: Option<String>, + signature_scheme: SignatureScheme, +} + +impl SecretAgentInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?, + ) + })?; + let info = unsafe { info.assume_init() }; + Ok(Self { + version: info.protocolVersion, + cipher: info.cipherSuite, + group: Group::try_from(info.keaGroup)?, + resumed: info.resumed != 0, + early_data: info.earlyDataAccepted != 0, + ech_accepted: info.echAccepted != 0, + alpn: get_alpn(fd, false)?, + signature_scheme: SignatureScheme::try_from(info.signatureScheme)?, + }) + } + #[must_use] + pub fn version(&self) -> Version { + self.version + } + #[must_use] + pub fn cipher_suite(&self) -> Cipher { + self.cipher + } + #[must_use] + pub fn key_exchange(&self) -> Group { + self.group + } + #[must_use] + pub fn resumed(&self) -> bool { + self.resumed + } + #[must_use] + pub fn early_data_accepted(&self) -> bool { + self.early_data + } + #[must_use] + pub fn ech_accepted(&self) -> bool { + self.ech_accepted + } + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } + #[must_use] + pub fn signature_scheme(&self) -> SignatureScheme { + self.signature_scheme + } +} + +/// `SecretAgent` holds the common parts of client and server. +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SecretAgent { + fd: *mut ssl::PRFileDesc, + secrets: SecretHolder, + raw: Option<bool>, + io: Pin<Box<AgentIo>>, + state: HandshakeState, + + /// Records whether authentication of certificates is required. + auth_required: Pin<Box<bool>>, + /// Records any fatal alert that is sent by the stack. + alert: Pin<Box<Option<Alert>>>, + /// The current time. + now: TimeHolder, + + extension_handlers: Vec<ExtensionTracker>, + + /// The encrypted client hello (ECH) configuration that is in use. + /// Empty if ECH is not enabled. + ech_config: Vec<u8>, +} + +impl SecretAgent { + fn new() -> Res<Self> { + let mut io = Box::pin(AgentIo::new()); + let fd = Self::create_fd(&mut io)?; + Ok(Self { + fd, + secrets: SecretHolder::default(), + raw: None, + io, + state: HandshakeState::New, + + auth_required: Box::pin(false), + alert: Box::pin(None), + now: TimeHolder::default(), + + extension_handlers: Vec::new(), + + ech_config: Vec::new(), + }) + } + + // Create a new SSL file descriptor. + // + // Note that we create separate bindings for PRFileDesc as both + // ssl::PRFileDesc and prio::PRFileDesc. This keeps the bindings + // minimal, but it means that the two forms need casts to translate + // between them. ssl::PRFileDesc is left as an opaque type, as the + // ssl::SSL_* APIs only need an opaque type. + fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> { + assert_initialized(); + let label = CString::new("sslwrapper")?; + let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) }; + + let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) }; + if base_fd.is_null() { + return Err(Error::CreateSslSocket); + } + let fd = unsafe { + (*base_fd).secret = as_c_void(io).cast(); + ssl::SSL_ImportFD(null_mut(), base_fd.cast()) + }; + if fd.is_null() { + unsafe { prio::PR_Close(base_fd) }; + return Err(Error::CreateSslSocket); + } + Ok(fd) + } + + unsafe extern "C" fn auth_complete_hook( + arg: *mut c_void, + _fd: *mut ssl::PRFileDesc, + _check_sig: ssl::PRBool, + _is_server: ssl::PRBool, + ) -> ssl::SECStatus { + let auth_required_ptr = arg.cast::<bool>(); + *auth_required_ptr = true; + // NSS insists on getting SECWouldBlock here rather than accepting + // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure. + ssl::_SECStatus_SECWouldBlock + } + + unsafe extern "C" fn alert_sent_cb( + fd: *const ssl::PRFileDesc, + arg: *mut c_void, + alert: *const ssl::SSLAlert, + ) { + let alert = alert.as_ref().unwrap(); + if alert.level == 2 { + // Fatal alerts demand attention. + let st = arg.cast::<Option<Alert>>().as_mut().unwrap(); + if st.is_none() { + *st = Some(alert.description); + } else { + qwarn!([format!("{fd:p}")], "duplicate alert {}", alert.description); + } + } + } + + // Ready this for connecting. + fn ready(&mut self, is_server: bool, grease: bool) -> Res<()> { + secstatus_to_res(unsafe { + ssl::SSL_AuthCertificateHook( + self.fd, + Some(Self::auth_complete_hook), + as_c_void(&mut self.auth_required), + ) + })?; + + secstatus_to_res(unsafe { + ssl::SSL_AlertSentCallback( + self.fd, + Some(Self::alert_sent_cb), + as_c_void(&mut self.alert), + ) + })?; + + self.now.bind(self.fd)?; + self.configure(grease)?; + secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, ssl::PRBool::from(is_server)) }) + } + + /// Default configuration. + /// + /// # Errors + /// + /// If `set_version_range` fails. + fn configure(&mut self, grease: bool) -> Res<()> { + self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; + self.set_option(ssl::Opt::Locking, false)?; + self.set_option(ssl::Opt::Tickets, false)?; + self.set_option(ssl::Opt::OcspStapling, true)?; + if let Err(e) = self.set_option(ssl::Opt::Grease, grease) { + // Until NSS supports greasing, it's OK to fail here. + qinfo!([self], "Failed to enable greasing {:?}", e); + } + Ok(()) + } + + /// Set the versions that are supported. + /// + /// # Errors + /// + /// If the range of versions isn't supported. + pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> { + let range = ssl::SSLVersionRange { min, max }; + secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) }) + } + + /// Enable a set of ciphers. Note that the order of these is not respected. + /// + /// # Errors + /// + /// If NSS can't enable or disable ciphers. + pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { + if self.state != HandshakeState::New { + qwarn!([self], "Cannot enable ciphers in state {:?}", self.state); + return Err(Error::InternalError); + } + + let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() }; + let cipher_count = usize::from(unsafe { ssl::SSL_GetNumImplementedCiphers() }); + for i in 0..cipher_count { + let p = all_ciphers.wrapping_add(i); + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), ssl::PRBool::from(false)) + })?; + } + + for c in ciphers { + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), ssl::PRBool::from(true)) + })?; + } + Ok(()) + } + + /// Set key exchange groups. + /// + /// # Errors + /// + /// If the underlying API fails (which shouldn't happen). + pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> { + // SSLNamedGroup is a different size to Group, so copy one by one. + let group_vec: Vec<_> = groups + .iter() + .map(|&g| ssl::SSLNamedGroup::Type::from(g)) + .collect(); + + let ptr = group_vec.as_slice().as_ptr(); + secstatus_to_res(unsafe { + ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?) + }) + } + + /// Set the number of additional key shares that will be sent in the client hello + /// + /// # Errors + /// + /// If the underlying API fails (which shouldn't happen). + pub fn send_additional_key_shares(&mut self, count: usize) -> Res<()> { + secstatus_to_res(unsafe { + ssl::SSL_SendAdditionalKeyShares(self.fd, c_uint::try_from(count)?) + }) + } + + /// Set TLS options. + /// + /// # Errors + /// + /// Returns an error if the option or option value is invalid; i.e., never. + pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> { + opt.set(self.fd, value) + } + + /// Enable 0-RTT. + /// + /// # Errors + /// + /// See `set_option`. + pub fn enable_0rtt(&mut self) -> Res<()> { + self.set_option(ssl::Opt::EarlyData, true) + } + + /// Disable the `EndOfEarlyData` message. + /// + /// # Errors + /// + /// See `set_option`. + pub fn disable_end_of_early_data(&mut self) -> Res<()> { + self.set_option(ssl::Opt::SuppressEndOfEarlyData, true) + } + + /// `set_alpn` sets a list of preferred protocols, starting with the most preferred. + /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded + /// strings. + /// + /// This asserts if no items are provided, or if any individual item is longer than + /// 255 octets in length. + /// + /// # Errors + /// + /// This should always panic rather than return an error. + /// + /// # Panics + /// + /// If any of the provided `protocols` are more than 255 bytes long. + /// + /// [RFC7301]: https://datatracker.ietf.org/doc/html/rfc7301 + pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> { + // Validate and set length. + let mut encoded_len = protocols.len(); + for v in protocols { + assert!(v.as_ref().len() < 256); + assert!(!v.as_ref().is_empty()); + encoded_len += v.as_ref().len(); + } + + // Prepare to encode. + let mut encoded = Vec::with_capacity(encoded_len); + let mut add = |v: &str| { + if let Ok(s) = u8::try_from(v.len()) { + encoded.push(s); + encoded.extend_from_slice(v.as_bytes()); + } + }; + + // NSS inherited an idiosyncratic API as a result of having implemented NPN + // before ALPN. For that reason, we need to put the "best" option last. + let (first, rest) = protocols + .split_first() + .expect("at least one ALPN value needed"); + for v in rest { + add(v.as_ref()); + } + add(first.as_ref()); + assert_eq!(encoded_len, encoded.len()); + + // Now give the result to NSS. + secstatus_to_res(unsafe { + ssl::SSL_SetNextProtoNego( + self.fd, + encoded.as_slice().as_ptr(), + c_uint::try_from(encoded.len())?, + ) + }) + } + + /// Install an extension handler. + /// + /// This can be called multiple times with different values for `ext`. The handler is provided + /// as `Rc<RefCell<dyn T>>` so that the caller is able to hold a reference to the handler + /// and later access any state that it accumulates. + /// + /// # Errors + /// + /// When the extension handler can't be successfully installed. + pub fn extension_handler( + &mut self, + ext: Extension, + handler: Rc<RefCell<dyn ExtensionHandler>>, + ) -> Res<()> { + let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?; + self.extension_handlers.push(tracker); + Ok(()) + } + + // This function tracks whether handshake() or handshake_raw() was used + // and prevents the other from being used. + fn set_raw(&mut self, r: bool) -> Res<()> { + if self.raw.is_none() { + self.secrets.register(self.fd)?; + self.raw = Some(r); + Ok(()) + } else if self.raw.unwrap() == r { + Ok(()) + } else { + Err(Error::MixedHandshakeMethod) + } + } + + /// Get information about the connection. + /// This includes the version, ciphersuite, and ALPN. + /// + /// Calling this function returns None until the connection is complete. + #[must_use] + pub fn info(&self) -> Option<&SecretAgentInfo> { + match self.state { + HandshakeState::Complete(ref info) => Some(info), + _ => None, + } + } + + /// Get any preliminary information about the status of the connection. + /// + /// This includes whether 0-RTT was accepted and any information related to that. + /// Calling this function collects all the relevant information. + /// + /// # Errors + /// + /// When the underlying socket functions fail. + pub fn preinfo(&self) -> Res<SecretAgentPreInfo> { + SecretAgentPreInfo::new(self.fd) + } + + /// Get the peer's certificate chain. + #[must_use] + pub fn peer_certificate(&self) -> Option<CertificateInfo> { + CertificateInfo::new(self.fd) + } + + /// Return any fatal alert that the TLS stack might have sent. + #[must_use] + pub fn alert(&self) -> Option<&Alert> { + (*self.alert).as_ref() + } + + /// Call this function to mark the peer as authenticated. + /// + /// # Panics + /// + /// If the handshake doesn't need to be authenticated. + pub fn authenticated(&mut self, status: AuthenticationStatus) { + assert!(self.state.authentication_needed()); + *self.auth_required = false; + self.state = HandshakeState::Authenticated(status.into()); + } + + fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> { + if let Err(e) = res { + let e = ech::convert_ech_error(self.fd, e); + qwarn!([self], "error: {:?}", e); + self.state = HandshakeState::Failed(e.clone()); + Err(e) + } else { + res + } + } + + fn update_state(&mut self, res: Res<()>) -> Res<()> { + self.state = if is_blocked(&res) { + if *self.auth_required { + self.preinfo()?.ech_public_name()?.map_or( + HandshakeState::AuthenticationPending, + |public_name| { + HandshakeState::EchFallbackAuthenticationPending(public_name.to_owned()) + }, + ) + } else { + HandshakeState::InProgress + } + } else { + self.capture_error(res)?; + let info = self.capture_error(SecretAgentInfo::new(self.fd))?; + HandshakeState::Complete(info) + }; + qinfo!([self], "state -> {:?}", self.state); + Ok(()) + } + + /// Drive the TLS handshake, taking bytes from `input` and putting + /// any bytes necessary into `output`. + /// This takes the current time as `now`. + /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake + /// is complete and how many bytes were written to `output`, respectively. + /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this + /// function if you want to proceed, because this will mark the certificate as OK. + /// + /// # Errors + /// + /// When the handshake fails this returns an error. + pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> { + self.now.set(now)?; + self.set_raw(false)?; + + let rv = { + // Within this scope, _h maintains a mutable reference to self.io. + let _h = self.io.wrap(input); + match self.state { + HandshakeState::Authenticated(ref err) => unsafe { + ssl::SSL_AuthCertificateComplete(self.fd, *err) + }, + _ => unsafe { ssl::SSL_ForceHandshake(self.fd) }, + } + }; + // Take before updating state so that we leave the output buffer empty + // even if there is an error. + let output = self.io.take_output(); + self.update_state(secstatus_to_res(rv))?; + Ok(output) + } + + /// Setup to receive records for raw handshake functions. + fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> { + self.set_raw(true)?; + self.capture_error(RecordList::setup(self.fd)) + } + + /// Drive the TLS handshake, but get the raw content of records, not + /// protected records as bytes. This function is incompatible with + /// `handshake()`; use either this or `handshake()` exclusively. + /// + /// Ideally, this only includes records from the current epoch. + /// If you send data from multiple epochs, you might end up being sad. + /// + /// # Errors + /// + /// When the handshake fails this returns an error. + pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> { + self.now.set(now)?; + let records = self.setup_raw()?; + + // Fire off any authentication we might need to complete. + if let HandshakeState::Authenticated(ref err) = self.state { + let result = + secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) }); + qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result); + // This should return SECSuccess, so don't use update_state(). + self.capture_error(result)?; + } + + // Feed in any records. + if let Some(rec) = input { + self.capture_error(rec.write(self.fd))?; + } + + // Drive the handshake once more. + let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) }); + self.update_state(rv)?; + + Ok(*Pin::into_inner(records)) + } + + /// # Panics + /// + /// If setup fails. + #[allow(unknown_lints, clippy::branches_sharing_code)] + pub fn close(&mut self) { + // It should be safe to close multiple times. + if self.fd.is_null() { + return; + } + if let Some(true) = self.raw { + // Need to hold the record list in scope until the close is done. + let _records = self.setup_raw().expect("Can only close"); + unsafe { prio::PR_Close(self.fd.cast()) }; + } else { + // Need to hold the IO wrapper in scope until the close is done. + let _io = self.io.wrap(&[]); + unsafe { prio::PR_Close(self.fd.cast()) }; + }; + let _output = self.io.take_output(); + self.fd = null_mut(); + } + + /// State returns the status of the handshake. + #[must_use] + pub fn state(&self) -> &HandshakeState { + &self.state + } + + /// Take a read secret. This will only return a non-`None` value once. + #[must_use] + pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_read(epoch) + } + + /// Take a write secret. + #[must_use] + pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_write(epoch) + } + + /// Get the active ECH configuration, which is empty if ECH is disabled. + #[must_use] + pub fn ech_config(&self) -> &[u8] { + &self.ech_config + } +} + +impl Drop for SecretAgent { + fn drop(&mut self) { + self.close(); + } +} + +impl ::std::fmt::Display for SecretAgent { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Agent {:p}", self.fd) + } +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone)] +pub struct ResumptionToken { + token: Vec<u8>, + expiration_time: Instant, +} + +impl AsRef<[u8]> for ResumptionToken { + fn as_ref(&self) -> &[u8] { + &self.token + } +} + +impl ResumptionToken { + #[must_use] + pub fn new(token: Vec<u8>, expiration_time: Instant) -> Self { + Self { + token, + expiration_time, + } + } + + #[must_use] + pub fn expiration_time(&self) -> Instant { + self.expiration_time + } +} + +/// A TLS Client. +#[derive(Debug)] +#[allow( + renamed_and_removed_lints, + clippy::box_vec, + unknown_lints, + clippy::box_collection +)] // We need the Box. +pub struct Client { + agent: SecretAgent, + + /// The name of the server we're attempting a connection to. + server_name: String, + /// Records the resumption tokens we've received. + resumption: Pin<Box<Vec<ResumptionToken>>>, +} + +impl Client { + /// Create a new client agent. + /// + /// # Errors + /// + /// Errors returned if the socket can't be created or configured. + pub fn new(server_name: impl Into<String>, grease: bool) -> Res<Self> { + let server_name = server_name.into(); + let mut agent = SecretAgent::new()?; + let url = CString::new(server_name.as_bytes())?; + secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?; + agent.ready(false, grease)?; + let mut client = Self { + agent, + server_name, + resumption: Box::pin(Vec::new()), + }; + client.ready()?; + Ok(client) + } + + unsafe extern "C" fn resumption_token_cb( + fd: *mut ssl::PRFileDesc, + token: *const u8, + len: c_uint, + arg: *mut c_void, + ) -> ssl::SECStatus { + let mut info: MaybeUninit<ssl::SSLResumptionTokenInfo> = MaybeUninit::uninit(); + if ssl::SSL_GetResumptionTokenInfo( + token, + len, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()).unwrap(), + ) + .is_err() + { + // Ignore the token. + return ssl::SECSuccess; + } + let expiration_time = info.assume_init().expirationTime; + if ssl::SSL_DestroyResumptionTokenInfo(info.as_mut_ptr()).is_err() { + // Ignore the token. + return ssl::SECSuccess; + } + let resumption = arg.cast::<Vec<ResumptionToken>>().as_mut().unwrap(); + let len = usize::try_from(len).unwrap(); + let mut v = Vec::with_capacity(len); + v.extend_from_slice(std::slice::from_raw_parts(token, len)); + qinfo!( + [format!("{fd:p}")], + "Got resumption token {}", + hex_snip_middle(&v) + ); + + if resumption.len() >= MAX_TICKETS { + resumption.remove(0); + } + if let Ok(t) = Time::try_from(expiration_time) { + resumption.push(ResumptionToken::new(v, *t)); + } + ssl::SECSuccess + } + + #[must_use] + pub fn server_name(&self) -> &str { + &self.server_name + } + + fn ready(&mut self) -> Res<()> { + let fd = self.fd; + unsafe { + ssl::SSL_SetResumptionTokenCallback( + fd, + Some(Self::resumption_token_cb), + as_c_void(&mut self.resumption), + ) + } + } + + /// Take a resumption token. + #[must_use] + pub fn resumption_token(&mut self) -> Option<ResumptionToken> { + (*self.resumption).pop() + } + + /// Check if there are more resumption tokens. + #[must_use] + pub fn has_resumption_token(&self) -> bool { + !(*self.resumption).is_empty() + } + + /// Enable resumption, using a token previously provided. + /// + /// # Errors + /// + /// Error returned when the resumption token is invalid or + /// the socket is not able to use the value. + pub fn enable_resumption(&mut self, token: impl AsRef<[u8]>) -> Res<()> { + unsafe { + ssl::SSL_SetResumptionToken( + self.agent.fd, + token.as_ref().as_ptr(), + c_uint::try_from(token.as_ref().len())?, + ) + } + } + + /// Enable encrypted client hello (ECH), using the encoded `ECHConfigList`. + /// + /// When ECH is enabled, a client needs to look for `Error::EchRetry` as a + /// failure code. If `Error::EchRetry` is received when connecting, the + /// connection attempt should be retried and the included value provided + /// to this function (instead of what is received from DNS). + /// + /// Calling this function with an empty value for `ech_config_list` enables + /// ECH greasing. When that is done, there is no need to look for `EchRetry` + /// + /// # Errors + /// + /// Error returned when the configuration is invalid. + pub fn enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> { + let config = ech_config_list.as_ref(); + qdebug!([self], "Enable ECH for a server: {}", hex_with_len(config)); + self.ech_config = Vec::from(config); + if config.is_empty() { + unsafe { ech::SSL_EnableTls13GreaseEch(self.agent.fd, PRBool::from(true)) } + } else { + unsafe { + ech::SSL_SetClientEchConfigs( + self.agent.fd, + config.as_ptr(), + c_uint::try_from(config.len())?, + ) + } + } + } +} + +impl Deref for Client { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Client { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +impl ::std::fmt::Display for Client { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Client {:p}", self.agent.fd) + } +} + +/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ZeroRttCheckResult { + /// Accept 0-RTT. + Accept, + /// Reject 0-RTT, but continue the handshake normally. + Reject, + /// Send HelloRetryRequest (probably not needed for QUIC). + HelloRetryRequest(Vec<u8>), + /// Fail the handshake. + Fail, +} + +/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by +/// `send_ticket`) +pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin { + fn check(&self, token: &[u8]) -> ZeroRttCheckResult; +} + +/// Using `AllowZeroRtt` for the implementation of `ZeroRttChecker` means +/// accepting 0-RTT always. This generally isn't a great idea, so this +/// generates a strong warning when it is used. +#[derive(Debug)] +pub struct AllowZeroRtt {} +impl ZeroRttChecker for AllowZeroRtt { + fn check(&self, _token: &[u8]) -> ZeroRttCheckResult { + qwarn!("AllowZeroRtt accepting 0-RTT"); + ZeroRttCheckResult::Accept + } +} + +#[derive(Debug)] +struct ZeroRttCheckState { + checker: Pin<Box<dyn ZeroRttChecker>>, +} + +impl ZeroRttCheckState { + pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self { + Self { + checker: Pin::new(checker), + } + } +} + +#[derive(Debug)] +pub struct Server { + agent: SecretAgent, + /// This holds the HRR callback context. + zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>, +} + +impl Server { + /// Create a new server agent. + /// + /// # Errors + /// + /// Errors returned when NSS fails. + pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> { + let mut agent = SecretAgent::new()?; + + for n in certificates { + let c = CString::new(n.as_ref())?; + let cert_ptr = unsafe { p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut()) }; + let Ok(cert) = p11::Certificate::from_ptr(cert_ptr) else { + return Err(Error::CertificateLoading); + }; + let key_ptr = unsafe { p11::PK11_FindKeyByAnyCert(*cert, null_mut()) }; + let Ok(key) = p11::PrivateKey::from_ptr(key_ptr) else { + return Err(Error::CertificateLoading); + }; + secstatus_to_res(unsafe { + ssl::SSL_ConfigServerCert(agent.fd, *cert, *key, null(), 0) + })?; + } + + agent.ready(true, true)?; + Ok(Self { + agent, + zero_rtt_check: None, + }) + } + + unsafe extern "C" fn hello_retry_cb( + first_hello: PRBool, + client_token: *const u8, + client_token_len: c_uint, + retry_token: *mut u8, + retry_token_len: *mut c_uint, + retry_token_max: c_uint, + arg: *mut c_void, + ) -> ssl::SSLHelloRetryRequestAction::Type { + if first_hello == 0 { + // On the second ClientHello after HelloRetryRequest, skip checks. + return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept; + } + + let check_state = arg.cast::<ZeroRttCheckState>().as_mut().unwrap(); + let token = if client_token.is_null() { + &[] + } else { + std::slice::from_raw_parts(client_token, usize::try_from(client_token_len).unwrap()) + }; + match check_state.checker.check(token) { + ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept, + ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail, + ZeroRttCheckResult::Reject => { + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt + } + ZeroRttCheckResult::HelloRetryRequest(tok) => { + // Don't bother propagating errors from this, because it should be caught in + // testing. + assert!(tok.len() <= usize::try_from(retry_token_max).unwrap()); + let slc = std::slice::from_raw_parts_mut(retry_token, tok.len()); + slc.copy_from_slice(&tok); + *retry_token_len = c_uint::try_from(tok.len()).unwrap(); + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request + } + } + } + + /// Enable 0-RTT. This shadows the function of the same name that can be accessed + /// via the Deref implementation on Server. + /// + /// # Errors + /// + /// Returns an error if the underlying NSS functions fail. + pub fn enable_0rtt( + &mut self, + anti_replay: &AntiReplay, + max_early_data: u32, + checker: Box<dyn ZeroRttChecker>, + ) -> Res<()> { + let mut check_state = Box::pin(ZeroRttCheckState::new(checker)); + unsafe { + ssl::SSL_HelloRetryRequestCallback( + self.agent.fd, + Some(Self::hello_retry_cb), + as_c_void(&mut check_state), + ) + }?; + unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?; + self.zero_rtt_check = Some(check_state); + self.agent.enable_0rtt()?; + anti_replay.config_socket(self.fd)?; + Ok(()) + } + + /// Send a session ticket to the client. + /// This adds |extra| application-specific content into that ticket. + /// The records that are sent are captured and returned. + /// + /// # Errors + /// + /// If NSS is unable to send a ticket, or if this agent is incorrectly configured. + pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> { + self.agent.now.set(now)?; + let records = self.setup_raw()?; + + unsafe { + ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?) + }?; + + Ok(*Pin::into_inner(records)) + } + + /// Enable encrypted client hello (ECH). + /// + /// # Errors + /// + /// Fails when NSS cannot create a key pair. + pub fn enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + let cfg = ech::encode_config(config, public_name, pk)?; + qdebug!([self], "Enable ECH for a server: {}", hex_with_len(&cfg)); + unsafe { + ech::SSL_SetServerEchConfigs( + self.agent.fd, + **pk, + **sk, + cfg.as_ptr(), + c_uint::try_from(cfg.len())?, + )?; + }; + self.ech_config = cfg; + Ok(()) + } +} + +impl Deref for Server { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Server { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +impl ::std::fmt::Display for Server { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Server {:p}", self.agent.fd) + } +} + +/// A generic container for Client or Server. +#[derive(Debug)] +pub enum Agent { + Client(crate::agent::Client), + Server(crate::agent::Server), +} + +impl Deref for Agent { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + match self { + Self::Client(c) => c, + Self::Server(s) => s, + } + } +} + +impl DerefMut for Agent { + fn deref_mut(&mut self) -> &mut SecretAgent { + match self { + Self::Client(c) => c, + Self::Server(s) => s, + } + } +} + +impl From<Client> for Agent { + #[must_use] + fn from(c: Client) -> Self { + Self::Client(c) + } +} + +impl From<Server> for Agent { + #[must_use] + fn from(s: Server) -> Self { + Self::Server(s) + } +} diff --git a/third_party/rust/neqo-crypto/src/agentio.rs b/third_party/rust/neqo-crypto/src/agentio.rs new file mode 100644 index 0000000000..2bcc540530 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/agentio.rs @@ -0,0 +1,396 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cmp::min, + convert::{TryFrom, TryInto}, + fmt, mem, + ops::Deref, + os::raw::{c_uint, c_void}, + pin::Pin, + ptr::{null, null_mut}, + vec::Vec, +}; + +use neqo_common::{hex, hex_with_len, qtrace}; + +use crate::{ + constants::{ContentType, Epoch}, + err::{nspr, Error, PR_SetError, Res}, + prio, ssl, +}; + +// Alias common types. +type PrFd = *mut prio::PRFileDesc; +type PrStatus = prio::PRStatus::Type; +const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS; +const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE; + +/// Convert a pinned, boxed object into a void pointer. +pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void { + (Pin::into_inner(pin.as_mut()) as *mut T).cast() +} + +/// A slice of the output. +#[derive(Default)] +pub struct Record { + pub epoch: Epoch, + pub ct: ContentType, + pub data: Vec<u8>, +} + +impl Record { + #[must_use] + pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self { + Self { + epoch, + ct, + data: data.to_vec(), + } + } + + // Shoves this record into the socket, returns true if blocked. + pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> { + qtrace!("write {:?}", self); + unsafe { + ssl::SSL_RecordLayerData( + fd, + self.epoch, + ssl::SSLContentType::Type::from(self.ct), + self.data.as_ptr(), + c_uint::try_from(self.data.len())?, + ) + } + } +} + +impl fmt::Debug for Record { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Record {:?}:{:?} {}", + self.epoch, + self.ct, + hex_with_len(&self.data[..]) + ) + } +} + +#[derive(Debug, Default)] +pub struct RecordList { + records: Vec<Record>, +} + +impl RecordList { + fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) { + self.records.push(Record::new(epoch, ct, data)); + } + + #[allow(clippy::unused_self)] + unsafe extern "C" fn ingest( + _fd: *mut ssl::PRFileDesc, + epoch: ssl::PRUint16, + ct: ssl::SSLContentType::Type, + data: *const ssl::PRUint8, + len: c_uint, + arg: *mut c_void, + ) -> ssl::SECStatus { + let records = arg.cast::<Self>().as_mut().unwrap(); + + let slice = std::slice::from_raw_parts(data, len as usize); + records.append(epoch, ContentType::try_from(ct).unwrap(), slice); + ssl::SECSuccess + } + + /// Create a new record list. + pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> { + let mut records = Box::pin(Self::default()); + unsafe { + ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records)) + }?; + Ok(records) + } +} + +impl Deref for RecordList { + type Target = Vec<Record>; + #[must_use] + fn deref(&self) -> &Vec<Record> { + &self.records + } +} + +pub struct RecordListIter(std::vec::IntoIter<Record>); + +impl Iterator for RecordListIter { + type Item = Record; + fn next(&mut self) -> Option<Self::Item> { + self.0.next() + } +} + +impl IntoIterator for RecordList { + type Item = Record; + type IntoIter = RecordListIter; + #[must_use] + fn into_iter(self) -> Self::IntoIter { + RecordListIter(self.records.into_iter()) + } +} + +pub struct AgentIoInputContext<'a> { + input: &'a mut AgentIoInput, +} + +impl<'a> Drop for AgentIoInputContext<'a> { + fn drop(&mut self) { + self.input.reset(); + } +} + +#[derive(Debug)] +struct AgentIoInput { + // input is data that is read by TLS. + input: *const u8, + // input_available is how much data is left for reading. + available: usize, +} + +impl AgentIoInput { + fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { + assert!(self.input.is_null()); + self.input = input.as_ptr(); + self.available = input.len(); + qtrace!("AgentIoInput wrap {:p}", self.input); + AgentIoInputContext { input: self } + } + + // Take the data provided as input and provide it to the TLS stack. + fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> { + let amount = min(self.available, count); + if amount == 0 { + unsafe { + PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0); + } + return Err(Error::NoDataAvailable); + } + + let src = unsafe { std::slice::from_raw_parts(self.input, amount) }; + qtrace!([self], "read {}", hex(src)); + let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) }; + dst.copy_from_slice(src); + self.input = self.input.wrapping_add(amount); + self.available -= amount; + Ok(amount) + } + + fn reset(&mut self) { + qtrace!([self], "reset"); + self.input = null(); + self.available = 0; + } +} + +impl ::std::fmt::Display for AgentIoInput { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "AgentIoInput {:p}", self.input) + } +} + +#[derive(Debug)] +pub struct AgentIo { + // input collects the input we might provide to TLS. + input: AgentIoInput, + + // output contains data that is written by TLS. + output: Vec<u8>, +} + +impl AgentIo { + pub fn new() -> Self { + Self { + input: AgentIoInput { + input: null(), + available: 0, + }, + output: Vec::new(), + } + } + + unsafe fn borrow(fd: &mut PrFd) -> &mut Self { + #[allow(clippy::cast_ptr_alignment)] + (**fd).secret.cast::<Self>().as_mut().unwrap() + } + + pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { + assert_eq!(self.output.len(), 0); + self.input.wrap(input) + } + + // Stage output from TLS into the output buffer. + fn save_output(&mut self, buf: *const u8, count: usize) { + let slice = unsafe { std::slice::from_raw_parts(buf, count) }; + qtrace!([self], "save output {}", hex(slice)); + self.output.extend_from_slice(slice); + } + + pub fn take_output(&mut self) -> Vec<u8> { + qtrace!([self], "take output"); + mem::take(&mut self.output) + } +} + +impl ::std::fmt::Display for AgentIo { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "AgentIo") + } +} + +unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus { + (*fd).secret = null_mut(); + if let Some(dtor) = (*fd).dtor { + dtor(fd); + } + PR_SUCCESS +} + +unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus { + let io = AgentIo::borrow(&mut fd); + if let Ok(a) = usize::try_from(amount) { + match io.input.read_input(buf.cast(), a) { + Ok(_) => PR_SUCCESS, + Err(_) => PR_FAILURE, + } + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_recv( + mut fd: PrFd, + buf: *mut c_void, + amount: prio::PRInt32, + flags: prio::PRIntn, + _timeout: prio::PRIntervalTime, +) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + if flags != 0 { + return PR_FAILURE; + } + if let Ok(a) = usize::try_from(amount) { + match io.input.read_input(buf.cast(), a) { + Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE), + Err(_) => PR_FAILURE, + } + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_write( + mut fd: PrFd, + buf: *const c_void, + amount: prio::PRInt32, +) -> PrStatus { + let io = AgentIo::borrow(&mut fd); + if let Ok(a) = usize::try_from(amount) { + io.save_output(buf.cast(), a); + amount + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_send( + mut fd: PrFd, + buf: *const c_void, + amount: prio::PRInt32, + flags: prio::PRIntn, + _timeout: prio::PRIntervalTime, +) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + + if flags != 0 { + return PR_FAILURE; + } + if let Ok(a) = usize::try_from(amount) { + io.save_output(buf.cast(), a); + amount + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + io.input.available.try_into().unwrap_or(PR_FAILURE) +} + +unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 { + let io = AgentIo::borrow(&mut fd); + io.input + .available + .try_into() + .unwrap_or_else(|_| PR_FAILURE.into()) +} + +#[allow(clippy::cast_possible_truncation)] +unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus { + let a = addr.as_mut().unwrap(); + // Cast is safe because prio::PR_AF_INET is 2 + a.inet.family = prio::PR_AF_INET as prio::PRUint16; + a.inet.port = 0; + a.inet.ip = 0; + PR_SUCCESS +} + +unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus { + let o = opt.as_mut().unwrap(); + if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking { + o.value.non_blocking = 1; + return PR_SUCCESS; + } + PR_FAILURE +} + +pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods { + file_type: prio::PRDescType::PR_DESC_LAYERED, + close: Some(agent_close), + read: Some(agent_read), + write: Some(agent_write), + available: Some(agent_available), + available64: Some(agent_available64), + fsync: None, + seek: None, + seek64: None, + fileInfo: None, + fileInfo64: None, + writev: None, + connect: None, + accept: None, + bind: None, + listen: None, + shutdown: None, + recv: Some(agent_recv), + send: Some(agent_send), + recvfrom: None, + sendto: None, + poll: None, + acceptread: None, + transmitfile: None, + getsockname: Some(agent_getname), + getpeername: Some(agent_getname), + reserved_fn_6: None, + reserved_fn_5: None, + getsocketoption: Some(agent_getsockopt), + setsocketoption: None, + sendfile: None, + connectcontinue: None, + reserved_fn_3: None, + reserved_fn_2: None, + reserved_fn_1: None, + reserved_fn_0: None, +}; diff --git a/third_party/rust/neqo-crypto/src/auth.rs b/third_party/rust/neqo-crypto/src/auth.rs new file mode 100644 index 0000000000..2932cdf2eb --- /dev/null +++ b/third_party/rust/neqo-crypto/src/auth.rs @@ -0,0 +1,108 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::err::{mozpkix, sec, ssl, PRErrorCode}; + +/// The outcome of authentication. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AuthenticationStatus { + Ok, + CaInvalid, + CaNotV3, + CertAlgorithmDisabled, + CertExpired, + CertInvalidTime, + CertIsCa, + CertKeyUsage, + CertMitm, + CertNotYetValid, + CertRevoked, + CertSelfSigned, + CertSubjectInvalid, + CertUntrusted, + CertWeakKey, + IssuerEmptyName, + IssuerExpired, + IssuerNotYetValid, + IssuerUnknown, + IssuerUntrusted, + PolicyRejection, + Unknown, +} + +impl From<AuthenticationStatus> for PRErrorCode { + #[must_use] + fn from(v: AuthenticationStatus) -> Self { + match v { + AuthenticationStatus::Ok => 0, + AuthenticationStatus::CaInvalid => sec::SEC_ERROR_CA_CERT_INVALID, + AuthenticationStatus::CaNotV3 => mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA, + AuthenticationStatus::CertAlgorithmDisabled => { + sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED + } + AuthenticationStatus::CertExpired => sec::SEC_ERROR_EXPIRED_CERTIFICATE, + AuthenticationStatus::CertInvalidTime => sec::SEC_ERROR_INVALID_TIME, + AuthenticationStatus::CertIsCa => { + mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY + } + AuthenticationStatus::CertKeyUsage => sec::SEC_ERROR_INADEQUATE_KEY_USAGE, + AuthenticationStatus::CertMitm => mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED, + AuthenticationStatus::CertNotYetValid => { + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE + } + AuthenticationStatus::CertRevoked => sec::SEC_ERROR_REVOKED_CERTIFICATE, + AuthenticationStatus::CertSelfSigned => mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT, + AuthenticationStatus::CertSubjectInvalid => ssl::SSL_ERROR_BAD_CERT_DOMAIN, + AuthenticationStatus::CertUntrusted => sec::SEC_ERROR_UNTRUSTED_CERT, + AuthenticationStatus::CertWeakKey => mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE, + AuthenticationStatus::IssuerEmptyName => mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME, + AuthenticationStatus::IssuerExpired => sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE, + AuthenticationStatus::IssuerNotYetValid => { + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE + } + AuthenticationStatus::IssuerUnknown => sec::SEC_ERROR_UNKNOWN_ISSUER, + AuthenticationStatus::IssuerUntrusted => sec::SEC_ERROR_UNTRUSTED_ISSUER, + AuthenticationStatus::PolicyRejection => { + mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED + } + AuthenticationStatus::Unknown => sec::SEC_ERROR_LIBRARY_FAILURE, + } + } +} + +// Note that this mapping should be removed after gecko eventually learns how to +// map into the enumerated type. +impl From<PRErrorCode> for AuthenticationStatus { + #[must_use] + fn from(v: PRErrorCode) -> Self { + match v { + 0 => Self::Ok, + sec::SEC_ERROR_CA_CERT_INVALID => Self::CaInvalid, + mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA => Self::CaNotV3, + sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED => Self::CertAlgorithmDisabled, + sec::SEC_ERROR_EXPIRED_CERTIFICATE => Self::CertExpired, + sec::SEC_ERROR_INVALID_TIME => Self::CertInvalidTime, + mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY => Self::CertIsCa, + sec::SEC_ERROR_INADEQUATE_KEY_USAGE => Self::CertKeyUsage, + mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED => Self::CertMitm, + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE => Self::CertNotYetValid, + sec::SEC_ERROR_REVOKED_CERTIFICATE => Self::CertRevoked, + mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT => Self::CertSelfSigned, + ssl::SSL_ERROR_BAD_CERT_DOMAIN => Self::CertSubjectInvalid, + sec::SEC_ERROR_UNTRUSTED_CERT => Self::CertUntrusted, + mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE => Self::CertWeakKey, + mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME => Self::IssuerEmptyName, + sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE => Self::IssuerExpired, + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE => Self::IssuerNotYetValid, + sec::SEC_ERROR_UNKNOWN_ISSUER => Self::IssuerUnknown, + sec::SEC_ERROR_UNTRUSTED_ISSUER => Self::IssuerUntrusted, + mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED => { + Self::PolicyRejection + } + _ => Self::Unknown, + } + } +} diff --git a/third_party/rust/neqo-crypto/src/cert.rs b/third_party/rust/neqo-crypto/src/cert.rs new file mode 100644 index 0000000000..64e63ec71a --- /dev/null +++ b/third_party/rust/neqo-crypto/src/cert.rs @@ -0,0 +1,120 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::TryFrom, + ptr::{addr_of, NonNull}, + slice, +}; + +use neqo_common::qerror; + +use crate::{ + err::secstatus_to_res, + p11::{CERTCertListNode, CERT_GetCertificateDer, CertList, Item, SECItem, SECItemArray}, + ssl::{ + PRFileDesc, SSL_PeerCertificateChain, SSL_PeerSignedCertTimestamps, + SSL_PeerStapledOCSPResponses, + }, +}; + +pub struct CertificateInfo { + certs: CertList, + cursor: *const CERTCertListNode, + /// stapled_ocsp_responses and signed_cert_timestamp are properties + /// associated with each of the certificates. Right now, NSS only + /// reports the value for the end-entity certificate (the first). + stapled_ocsp_responses: Option<Vec<Vec<u8>>>, + signed_cert_timestamp: Option<Vec<u8>>, +} + +fn peer_certificate_chain(fd: *mut PRFileDesc) -> Option<(CertList, *const CERTCertListNode)> { + let chain = unsafe { SSL_PeerCertificateChain(fd) }; + CertList::from_ptr(chain.cast()).ok().map(|certs| { + let cursor = CertificateInfo::head(&certs); + (certs, cursor) + }) +} + +// As explained in rfc6961, an OCSPResponseList can have at most +// 2^24 items. Casting its length is therefore safe even on 32 bits targets. +fn stapled_ocsp_responses(fd: *mut PRFileDesc) -> Option<Vec<Vec<u8>>> { + let ocsp_nss = unsafe { SSL_PeerStapledOCSPResponses(fd) }; + match NonNull::new(ocsp_nss as *mut SECItemArray) { + Some(ocsp_ptr) => { + let mut ocsp_helper: Vec<Vec<u8>> = Vec::new(); + let Ok(len) = isize::try_from(unsafe { ocsp_ptr.as_ref().len }) else { + qerror!([format!("{fd:p}")], "Received illegal OSCP length"); + return None; + }; + for idx in 0..len { + let itemp: *const SECItem = unsafe { ocsp_ptr.as_ref().items.offset(idx).cast() }; + let item = unsafe { slice::from_raw_parts((*itemp).data, (*itemp).len as usize) }; + ocsp_helper.push(item.to_owned()); + } + Some(ocsp_helper) + } + None => None, + } +} + +fn signed_cert_timestamp(fd: *mut PRFileDesc) -> Option<Vec<u8>> { + let sct_nss = unsafe { SSL_PeerSignedCertTimestamps(fd) }; + match NonNull::new(sct_nss as *mut SECItem) { + Some(sct_ptr) => { + if unsafe { sct_ptr.as_ref().len == 0 || sct_ptr.as_ref().data.is_null() } { + Some(Vec::new()) + } else { + let sct_slice = unsafe { + slice::from_raw_parts(sct_ptr.as_ref().data, sct_ptr.as_ref().len as usize) + }; + Some(sct_slice.to_owned()) + } + } + None => None, + } +} + +impl CertificateInfo { + pub(crate) fn new(fd: *mut PRFileDesc) -> Option<Self> { + peer_certificate_chain(fd).map(|(certs, cursor)| Self { + certs, + cursor, + stapled_ocsp_responses: stapled_ocsp_responses(fd), + signed_cert_timestamp: signed_cert_timestamp(fd), + }) + } + + fn head(certs: &CertList) -> *const CERTCertListNode { + // Three stars: one for the reference, one for the wrapper, one to deference the pointer. + unsafe { addr_of!((***certs).list).cast() } + } +} + +impl<'a> Iterator for &'a mut CertificateInfo { + type Item = &'a [u8]; + fn next(&mut self) -> Option<&'a [u8]> { + self.cursor = unsafe { *self.cursor }.links.next.cast(); + if self.cursor == CertificateInfo::head(&self.certs) { + return None; + } + let mut item = Item::make_empty(); + let cert = unsafe { *self.cursor }.cert; + secstatus_to_res(unsafe { CERT_GetCertificateDer(cert, &mut item) }) + .expect("getting DER from certificate should work"); + Some(unsafe { std::slice::from_raw_parts(item.data, item.len as usize) }) + } +} + +impl CertificateInfo { + pub fn stapled_ocsp_responses(&mut self) -> &Option<Vec<Vec<u8>>> { + &self.stapled_ocsp_responses + } + + pub fn signed_cert_timestamp(&mut self) -> &Option<Vec<u8>> { + &self.signed_cert_timestamp + } +} diff --git a/third_party/rust/neqo-crypto/src/constants.rs b/third_party/rust/neqo-crypto/src/constants.rs new file mode 100644 index 0000000000..76db972290 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/constants.rs @@ -0,0 +1,146 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] + +use crate::ssl; + +// Ideally all of these would be enums, but size matters and we need to allow +// for values outside of those that are defined here. + +pub type Alert = u8; + +pub type Epoch = u16; +// TLS doesn't really have an "initial" concept that maps to QUIC so directly, +// but this should be clear enough. +pub const TLS_EPOCH_INITIAL: Epoch = 0_u16; +pub const TLS_EPOCH_ZERO_RTT: Epoch = 1_u16; +pub const TLS_EPOCH_HANDSHAKE: Epoch = 2_u16; +// Also, we don't use TLS epochs > 3. +pub const TLS_EPOCH_APPLICATION_DATA: Epoch = 3_u16; + +/// Rather than defining a type alias and a bunch of constants, which leads to a ton of repetition, +/// use this macro. +macro_rules! remap_enum { + { $t:ident: $s:ty { $( $n:ident = $v:path ),+ $(,)? } } => { + pub type $t = $s; + $( pub const $n: $t = $v as $t; )+ + }; + { $t:ident: $s:ty => $e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => { + remap_enum!{ $t: $s { $( $n = $e::$v ),+ } } + }; + { $t:ident: $s:ty => $p:ident::$e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => { + remap_enum!{ $t: $s { $( $n = $p::$e::$v ),+ } } + }; +} + +remap_enum! { + Version: u16 => ssl { + TLS_VERSION_1_2 = SSL_LIBRARY_VERSION_TLS_1_2, + TLS_VERSION_1_3 = SSL_LIBRARY_VERSION_TLS_1_3, + } +} + +mod ciphers { + include!(concat!(env!("OUT_DIR"), "/nss_ciphers.rs")); +} + +remap_enum! { + Cipher: u16 => ciphers { + TLS_AES_128_GCM_SHA256 = TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384 = TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256 = TLS_CHACHA20_POLY1305_SHA256, + } +} + +remap_enum! { + Group: u16 => ssl::SSLNamedGroup { + TLS_GRP_EC_SECP256R1 = ssl_grp_ec_secp256r1, + TLS_GRP_EC_SECP384R1 = ssl_grp_ec_secp384r1, + TLS_GRP_EC_SECP521R1 = ssl_grp_ec_secp521r1, + TLS_GRP_EC_X25519 = ssl_grp_ec_curve25519, + TLS_GRP_KEM_XYBER768D00 = ssl_grp_kem_xyber768d00, + } +} + +remap_enum! { + HandshakeMessage: u8 => ssl::SSLHandshakeType { + TLS_HS_HELLO_REQUEST = ssl_hs_hello_request, + TLS_HS_CLIENT_HELLO = ssl_hs_client_hello, + TLS_HS_SERVER_HELLO = ssl_hs_server_hello, + TLS_HS_HELLO_VERIFY_REQUEST = ssl_hs_hello_verify_request, + TLS_HS_NEW_SESSION_TICKET = ssl_hs_new_session_ticket, + TLS_HS_END_OF_EARLY_DATA = ssl_hs_end_of_early_data, + TLS_HS_HELLO_RETRY_REQUEST = ssl_hs_hello_retry_request, + TLS_HS_ENCRYPTED_EXTENSIONS = ssl_hs_encrypted_extensions, + TLS_HS_CERTIFICATE = ssl_hs_certificate, + TLS_HS_SERVER_KEY_EXCHANGE = ssl_hs_server_key_exchange, + TLS_HS_CERTIFICATE_REQUEST = ssl_hs_certificate_request, + TLS_HS_SERVER_HELLO_DONE = ssl_hs_server_hello_done, + TLS_HS_CERTIFICATE_VERIFY = ssl_hs_certificate_verify, + TLS_HS_CLIENT_KEY_EXCHANGE = ssl_hs_client_key_exchange, + TLS_HS_FINISHED = ssl_hs_finished, + TLS_HS_CERT_STATUS = ssl_hs_certificate_status, + TLS_HS_KEY_UDPATE = ssl_hs_key_update, + } +} + +remap_enum! { + ContentType: u8 => ssl::SSLContentType { + TLS_CT_CHANGE_CIPHER_SPEC = ssl_ct_change_cipher_spec, + TLS_CT_ALERT = ssl_ct_alert, + TLS_CT_HANDSHAKE = ssl_ct_handshake, + TLS_CT_APPLICATION_DATA = ssl_ct_application_data, + TLS_CT_ACK = ssl_ct_ack, + } +} + +remap_enum! { + Extension: u16 => ssl::SSLExtensionType { + TLS_EXT_SERVER_NAME = ssl_server_name_xtn, + TLS_EXT_CERT_STATUS = ssl_cert_status_xtn, + TLS_EXT_GROUPS = ssl_supported_groups_xtn, + TLS_EXT_EC_POINT_FORMATS = ssl_ec_point_formats_xtn, + TLS_EXT_SIG_SCHEMES = ssl_signature_algorithms_xtn, + TLS_EXT_USE_SRTP = ssl_use_srtp_xtn, + TLS_EXT_ALPN = ssl_app_layer_protocol_xtn, + TLS_EXT_SCT = ssl_signed_cert_timestamp_xtn, + TLS_EXT_PADDING = ssl_padding_xtn, + TLS_EXT_EMS = ssl_extended_master_secret_xtn, + TLS_EXT_RECORD_SIZE = ssl_record_size_limit_xtn, + TLS_EXT_SESSION_TICKET = ssl_session_ticket_xtn, + TLS_EXT_PSK = ssl_tls13_pre_shared_key_xtn, + TLS_EXT_EARLY_DATA = ssl_tls13_early_data_xtn, + TLS_EXT_VERSIONS = ssl_tls13_supported_versions_xtn, + TLS_EXT_COOKIE = ssl_tls13_cookie_xtn, + TLS_EXT_PSK_MODES = ssl_tls13_psk_key_exchange_modes_xtn, + TLS_EXT_CA = ssl_tls13_certificate_authorities_xtn, + TLS_EXT_POST_HS_AUTH = ssl_tls13_post_handshake_auth_xtn, + TLS_EXT_CERT_SIG_SCHEMES = ssl_signature_algorithms_cert_xtn, + TLS_EXT_KEY_SHARE = ssl_tls13_key_share_xtn, + TLS_EXT_RENEGOTIATION_INFO = ssl_renegotiation_info_xtn, + } +} + +remap_enum! { + SignatureScheme: u16 => ssl::SSLSignatureScheme { + TLS_SIG_NONE = ssl_sig_none, + TLS_SIG_RSA_PKCS1_SHA256 = ssl_sig_rsa_pkcs1_sha256, + TLS_SIG_RSA_PKCS1_SHA384 = ssl_sig_rsa_pkcs1_sha384, + TLS_SIG_RSA_PKCS1_SHA512 = ssl_sig_rsa_pkcs1_sha512, + TLS_SIG_ECDSA_SECP256R1_SHA256 = ssl_sig_ecdsa_secp256r1_sha256, + TLS_SIG_ECDSA_SECP384R1_SHA384 = ssl_sig_ecdsa_secp384r1_sha384, + TLS_SIG_ECDSA_SECP512R1_SHA512 = ssl_sig_ecdsa_secp521r1_sha512, + TLS_SIG_RSA_PSS_RSAE_SHA256 = ssl_sig_rsa_pss_rsae_sha256, + TLS_SIG_RSA_PSS_RSAE_SHA384 = ssl_sig_rsa_pss_rsae_sha384, + TLS_SIG_RSA_PSS_RSAE_SHA512 = ssl_sig_rsa_pss_rsae_sha512, + TLS_SIG_ED25519 = ssl_sig_ed25519, + TLS_SIG_ED448 = ssl_sig_ed448, + TLS_SIG_RSA_PSS_PSS_SHA256 = ssl_sig_rsa_pss_pss_sha256, + TLS_SIG_RSA_PSS_PSS_SHA384 = ssl_sig_rsa_pss_pss_sha384, + TLS_SIG_RSA_PSS_PSS_SHA512 = ssl_sig_rsa_pss_pss_sha512, + } +} diff --git a/third_party/rust/neqo-crypto/src/ech.rs b/third_party/rust/neqo-crypto/src/ech.rs new file mode 100644 index 0000000000..1f54c4592e --- /dev/null +++ b/third_party/rust/neqo-crypto/src/ech.rs @@ -0,0 +1,204 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::TryFrom, + ffi::CString, + os::raw::{c_char, c_uint}, + ptr::{addr_of_mut, null_mut}, +}; + +use neqo_common::qtrace; + +use crate::{ + err::{ssl::SSL_ERROR_ECH_RETRY_WITH_ECH, Error, Res}, + experimental_api, + p11::{ + self, Item, PrivateKey, PublicKey, SECITEM_FreeItem, SECItem, SECKEYPrivateKey, + SECKEYPublicKey, Slot, + }, + ssl::{PRBool, PRFileDesc}, +}; +pub use crate::{ + p11::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId}, + ssl::HpkeSymmetricSuite as SymmetricSuite, +}; + +experimental_api!(SSL_EnableTls13GreaseEch( + fd: *mut PRFileDesc, + enabled: PRBool, +)); + +experimental_api!(SSL_GetEchRetryConfigs( + fd: *mut PRFileDesc, + config: *mut SECItem, +)); + +experimental_api!(SSL_SetClientEchConfigs( + fd: *mut PRFileDesc, + config_list: *const u8, + config_list_len: c_uint, +)); + +experimental_api!(SSL_SetServerEchConfigs( + fd: *mut PRFileDesc, + pk: *const SECKEYPublicKey, + sk: *const SECKEYPrivateKey, + record: *const u8, + record_len: c_uint, +)); + +experimental_api!(SSL_EncodeEchConfigId( + config_id: u8, + public_name: *const c_char, + max_name_len: c_uint, + kem_id: KemId::Type, + pk: *const SECKEYPublicKey, + hpke_suites: *const SymmetricSuite, + hpke_suite_count: c_uint, + out: *mut u8, + out_len: *mut c_uint, + max_len: c_uint, +)); + +/// Convert any result that contains an ECH error into a result with an `EchRetry`. +pub fn convert_ech_error(fd: *mut PRFileDesc, err: Error) -> Error { + if let Error::NssError { + code: SSL_ERROR_ECH_RETRY_WITH_ECH, + .. + } = &err + { + let mut item = Item::make_empty(); + if unsafe { SSL_GetEchRetryConfigs(fd, &mut item).is_err() } { + return Error::InternalError; + } + let buf = unsafe { + let slc = std::slice::from_raw_parts(item.data, usize::try_from(item.len).unwrap()); + let buf = Vec::from(slc); + SECITEM_FreeItem(&mut item, PRBool::from(false)); + buf + }; + Error::EchRetry(buf) + } else { + err + } +} + +/// Generate a key pair for encrypted client hello (ECH). +/// +/// # Errors +/// +/// When NSS fails to generate a key pair or when the KEM is not supported. +/// +/// # Panics +/// +/// When underlying types aren't large enough to hold keys. So never. +pub fn generate_keys() -> Res<(PrivateKey, PublicKey)> { + let slot = Slot::internal()?; + + let oid_data = unsafe { p11::SECOID_FindOIDByTag(p11::SECOidTag::SEC_OID_CURVE25519) }; + let oid = unsafe { oid_data.as_ref() }.ok_or(Error::InternalError)?; + let oid_slc = + unsafe { std::slice::from_raw_parts(oid.oid.data, usize::try_from(oid.oid.len).unwrap()) }; + let mut params: Vec<u8> = Vec::with_capacity(oid_slc.len() + 2); + params.push(u8::try_from(p11::SEC_ASN1_OBJECT_ID).unwrap()); + params.push(u8::try_from(oid.oid.len).unwrap()); + params.extend_from_slice(oid_slc); + + let mut public_ptr: *mut SECKEYPublicKey = null_mut(); + let mut param_item = Item::wrap(¶ms); + + // If we have tracing on, try to ensure that key data can be read. + let insensitive_secret_ptr = if log::log_enabled!(log::Level::Trace) { + #[allow(clippy::useless_conversion)] // TODO: Remove when we bump the MSRV to 1.74.0. + unsafe { + p11::PK11_GenerateKeyPairWithOpFlags( + *slot, + p11::CK_MECHANISM_TYPE::from(p11::CKM_EC_KEY_PAIR_GEN), + addr_of_mut!(param_item).cast(), + &mut public_ptr, + p11::PK11_ATTR_SESSION | p11::PK11_ATTR_INSENSITIVE | p11::PK11_ATTR_PUBLIC, + p11::CK_FLAGS::from(p11::CKF_DERIVE), + p11::CK_FLAGS::from(p11::CKF_DERIVE), + null_mut(), + ) + } + } else { + null_mut() + }; + assert_eq!(insensitive_secret_ptr.is_null(), public_ptr.is_null()); + let secret_ptr = if insensitive_secret_ptr.is_null() { + #[allow(clippy::useless_conversion)] // TODO: Remove when we bump the MSRV to 1.74.0. + unsafe { + p11::PK11_GenerateKeyPairWithOpFlags( + *slot, + p11::CK_MECHANISM_TYPE::from(p11::CKM_EC_KEY_PAIR_GEN), + addr_of_mut!(param_item).cast(), + &mut public_ptr, + p11::PK11_ATTR_SESSION | p11::PK11_ATTR_SENSITIVE | p11::PK11_ATTR_PRIVATE, + p11::CK_FLAGS::from(p11::CKF_DERIVE), + p11::CK_FLAGS::from(p11::CKF_DERIVE), + null_mut(), + ) + } + } else { + insensitive_secret_ptr + }; + assert_eq!(secret_ptr.is_null(), public_ptr.is_null()); + let sk = PrivateKey::from_ptr(secret_ptr)?; + let pk = PublicKey::from_ptr(public_ptr)?; + qtrace!("Generated key pair: sk={:?} pk={:?}", sk, pk); + Ok((sk, pk)) +} + +/// Encode a configuration for encrypted client hello (ECH). +/// +/// # Errors +/// +/// When NSS fails to generate a valid configuration encoding (i.e., unlikely). +pub fn encode_config(config: u8, public_name: &str, pk: &PublicKey) -> Res<Vec<u8>> { + // A sensible fixed value for the maximum length of a name. + const MAX_NAME_LEN: c_uint = 64; + // Enable a selection of suites. + // NSS supports SHA-512 as well, which could be added here. + const SUITES: &[SymmetricSuite] = &[ + SymmetricSuite { + kdfId: KdfId::HpkeKdfHkdfSha256, + aeadId: AeadId::HpkeAeadAes128Gcm, + }, + SymmetricSuite { + kdfId: KdfId::HpkeKdfHkdfSha256, + aeadId: AeadId::HpkeAeadChaCha20Poly1305, + }, + SymmetricSuite { + kdfId: KdfId::HpkeKdfHkdfSha384, + aeadId: AeadId::HpkeAeadAes128Gcm, + }, + SymmetricSuite { + kdfId: KdfId::HpkeKdfHkdfSha384, + aeadId: AeadId::HpkeAeadChaCha20Poly1305, + }, + ]; + + let name = CString::new(public_name)?; + let mut encoded = [0; 1024]; + let mut encoded_len = 0; + unsafe { + SSL_EncodeEchConfigId( + config, + name.as_ptr(), + MAX_NAME_LEN, + KemId::HpkeDhKemX25519Sha256, + **pk, + SUITES.as_ptr(), + c_uint::try_from(SUITES.len())?, + encoded.as_mut_ptr(), + &mut encoded_len, + c_uint::try_from(encoded.len())?, + )?; + } + Ok(Vec::from(&encoded[..usize::try_from(encoded_len)?])) +} diff --git a/third_party/rust/neqo-crypto/src/err.rs b/third_party/rust/neqo-crypto/src/err.rs new file mode 100644 index 0000000000..187303d2a9 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/err.rs @@ -0,0 +1,214 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] +#![allow(clippy::upper_case_acronyms)] + +use std::{os::raw::c_char, str::Utf8Error}; + +use crate::ssl::{SECStatus, SECSuccess}; + +include!(concat!(env!("OUT_DIR"), "/nspr_error.rs")); +mod codes { + #![allow(non_snake_case)] + include!(concat!(env!("OUT_DIR"), "/nss_secerr.rs")); + include!(concat!(env!("OUT_DIR"), "/nss_sslerr.rs")); + include!(concat!(env!("OUT_DIR"), "/mozpkix.rs")); +} +pub use codes::{mozilla_pkix_ErrorCode as mozpkix, SECErrorCodes as sec, SSLErrorCodes as ssl}; +pub mod nspr { + include!(concat!(env!("OUT_DIR"), "/nspr_err.rs")); +} + +pub type Res<T> = Result<T, Error>; + +#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +pub enum Error { + AeadError, + CertificateLoading, + CipherInitFailure, + CreateSslSocket, + EchRetry(Vec<u8>), + HkdfError, + InternalError, + IntegerOverflow, + InvalidEpoch, + MixedHandshakeMethod, + NoDataAvailable, + NssError { + name: String, + code: PRErrorCode, + desc: String, + }, + OverrunError, + SelfEncryptFailure, + StringError, + TimeTravelError, + UnsupportedCipher, + UnsupportedVersion, +} + +impl Error { + pub(crate) fn last_nss_error() -> Self { + Self::from(unsafe { PR_GetError() }) + } +} + +impl std::error::Error for Error { + #[must_use] + fn cause(&self) -> Option<&dyn std::error::Error> { + None + } + #[must_use] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Error: {self:?}") + } +} + +impl From<std::num::TryFromIntError> for Error { + #[must_use] + fn from(_: std::num::TryFromIntError) -> Self { + Self::IntegerOverflow + } +} +impl From<std::ffi::NulError> for Error { + #[must_use] + fn from(_: std::ffi::NulError) -> Self { + Self::InternalError + } +} +impl From<Utf8Error> for Error { + fn from(_: Utf8Error) -> Self { + Self::StringError + } +} +impl From<PRErrorCode> for Error { + fn from(code: PRErrorCode) -> Self { + let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR"); + let desc = wrap_str_fn( + || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) }, + "...", + ); + Self::NssError { name, code, desc } + } +} + +use std::ffi::CStr; + +fn wrap_str_fn<F>(f: F, dflt: &str) -> String +where + F: FnOnce() -> *const c_char, +{ + unsafe { + let p = f(); + if p.is_null() { + return dflt.to_string(); + } + CStr::from_ptr(p).to_string_lossy().into_owned() + } +} + +pub fn secstatus_to_res(rv: SECStatus) -> Res<()> { + if rv == SECSuccess { + Ok(()) + } else { + Err(Error::last_nss_error()) + } +} + +pub fn is_blocked(result: &Res<()>) -> bool { + match result { + Err(Error::NssError { code, .. }) => *code == nspr::PR_WOULD_BLOCK_ERROR, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use test_fixture::fixture_init; + + use crate::{ + err::{self, is_blocked, secstatus_to_res, Error, PRErrorCode, PR_SetError}, + ssl::{SECFailure, SECSuccess}, + }; + + fn set_error_code(code: PRErrorCode) { + // This code doesn't work without initializing NSS first. + fixture_init(); + unsafe { + PR_SetError(code, 0); + } + } + + #[test] + fn error_code() { + fixture_init(); + assert_eq!(15 - 0x3000, err::ssl::SSL_ERROR_BAD_MAC_READ); + assert_eq!(166 - 0x2000, err::sec::SEC_ERROR_LIBPKIX_INTERNAL); + assert_eq!(-5998, err::nspr::PR_WOULD_BLOCK_ERROR); + } + + #[test] + fn is_ok() { + assert!(secstatus_to_res(SECSuccess).is_ok()); + } + + #[test] + fn is_err() { + set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "SSL_ERROR_BAD_MAC_READ"); + assert_eq!(code, -12273); + assert_eq!( + desc, + "SSL received a record with an incorrect Message Authentication Code." + ); + } + _ => unreachable!(), + } + } + + #[test] + fn is_err_zero_code() { + set_error_code(0); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, .. } => { + assert_eq!(name, "UNKNOWN_ERROR"); + assert_eq!(code, 0); + // Note that we don't test |desc| here because that comes from + // strerror(0), which is platform-dependent. + } + _ => unreachable!(), + } + } + + #[test] + fn blocked() { + set_error_code(err::nspr::PR_WOULD_BLOCK_ERROR); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + assert!(is_blocked(&r)); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "PR_WOULD_BLOCK_ERROR"); + assert_eq!(code, -5998); + assert_eq!(desc, "The operation would have blocked"); + } + _ => panic!("bad error type"), + } + } +} diff --git a/third_party/rust/neqo-crypto/src/exp.rs b/third_party/rust/neqo-crypto/src/exp.rs new file mode 100644 index 0000000000..75867d80bb --- /dev/null +++ b/third_party/rust/neqo-crypto/src/exp.rs @@ -0,0 +1,24 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[macro_export] +macro_rules! experimental_api { + ( $n:ident ( $( $a:ident : $t:ty ),* $(,)? ) ) => { + #[allow(non_snake_case)] + #[allow(clippy::too_many_arguments)] + pub(crate) unsafe fn $n ( $( $a : $t ),* ) -> Result<(), $crate::err::Error> { + const EXP_FUNCTION: &str = stringify!($n); + let n = ::std::ffi::CString::new(EXP_FUNCTION)?; + let f = $crate::ssl::SSL_GetExperimentalAPI(n.as_ptr()); + if f.is_null() { + return Err($crate::err::Error::InternalError); + } + let f: unsafe extern "C" fn( $( $t ),* ) -> $crate::ssl::SECStatus = ::std::mem::transmute(f); + let rv = f( $( $a ),* ); + $crate::err::secstatus_to_res(rv) + } + }; +} diff --git a/third_party/rust/neqo-crypto/src/ext.rs b/third_party/rust/neqo-crypto/src/ext.rs new file mode 100644 index 0000000000..310e87a1b7 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/ext.rs @@ -0,0 +1,169 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + os::raw::{c_uint, c_void}, + pin::Pin, + rc::Rc, +}; + +use crate::{ + agentio::as_c_void, + constants::{Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS}, + err::Res, + ssl::{ + PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription, + SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType, + }, +}; + +experimental_api!(SSL_InstallExtensionHooks( + fd: *mut PRFileDesc, + extension: u16, + writer: SSLExtensionWriter, + writer_arg: *mut c_void, + handler: SSLExtensionHandler, + handler_arg: *mut c_void, +)); + +pub enum ExtensionWriterResult { + Write(usize), + Skip, +} + +pub enum ExtensionHandlerResult { + Ok, + Alert(crate::constants::Alert), +} + +pub trait ExtensionHandler { + fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult { + match msg { + TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0), + _ => ExtensionWriterResult::Skip, + } + } + + fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult { + match msg { + TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok, + _ => ExtensionHandlerResult::Alert(110), // unsupported_extension + } + } +} + +type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>; + +pub struct ExtensionTracker { + extension: Extension, + handler: Pin<Box<BoxedExtensionHandler>>, +} + +impl ExtensionTracker { + // Technically the as_mut() call here is the only unsafe bit, + // but don't call this function lightly. + unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T + where + F: FnOnce(&mut dyn ExtensionHandler) -> T, + { + let rc = arg.cast::<BoxedExtensionHandler>().as_mut().unwrap(); + f(&mut *rc.borrow_mut()) + } + + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + unsafe extern "C" fn extension_writer( + _fd: *mut PRFileDesc, + message: SSLHandshakeType::Type, + data: *mut u8, + len: *mut c_uint, + max_len: c_uint, + arg: *mut c_void, + ) -> PRBool { + let d = std::slice::from_raw_parts_mut(data, max_len as usize); + Self::wrap_handler_call(arg, |handler| { + // Cast is safe here because the message type is always part of the enum + match handler.write(message as HandshakeMessage, d) { + ExtensionWriterResult::Write(sz) => { + *len = c_uint::try_from(sz).expect("integer overflow from extension writer"); + 1 + } + ExtensionWriterResult::Skip => 0, + } + }) + } + + unsafe extern "C" fn extension_handler( + _fd: *mut PRFileDesc, + message: SSLHandshakeType::Type, + data: *const u8, + len: c_uint, + alert: *mut SSLAlertDescription, + arg: *mut c_void, + ) -> SECStatus { + let d = std::slice::from_raw_parts(data, len as usize); + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + Self::wrap_handler_call(arg, |handler| { + // Cast is safe here because the message type is always part of the enum + match handler.handle(message as HandshakeMessage, d) { + ExtensionHandlerResult::Ok => SECSuccess, + ExtensionHandlerResult::Alert(a) => { + *alert = a; + SECFailure + } + } + }) + } + + /// Use the provided handler to manage an extension. This is quite unsafe. + /// + /// # Safety + /// + /// The holder of this `ExtensionTracker` needs to ensure that it lives at + /// least as long as the file descriptor, as NSS provides no way to remove + /// an extension handler once it is configured. + /// + /// # Errors + /// + /// If the underlying NSS API fails to register a handler. + pub unsafe fn new( + fd: *mut PRFileDesc, + extension: Extension, + handler: Rc<RefCell<dyn ExtensionHandler>>, + ) -> Res<Self> { + // The ergonomics here aren't great for users of this API, but it's + // horrific here. The pinned outer box gives us a stable pointer to the inner + // box. This is the pointer that is passed to NSS. + // + // The inner box points to the reference-counted object. This inner box is + // what we end up with a reference to in callbacks. That extra wrapper around + // the Rc avoid any touching of reference counts in callbacks, which would + // inevitably lead to leaks as we don't control how many times the callback + // is invoked. + // + // This way, only this "outer" code deals with the reference count. + let mut tracker = Self { + extension, + handler: Box::pin(Box::new(handler)), + }; + SSL_InstallExtensionHooks( + fd, + extension, + Some(Self::extension_writer), + as_c_void(&mut tracker.handler), + Some(Self::extension_handler), + as_c_void(&mut tracker.handler), + )?; + Ok(tracker) + } +} + +impl std::fmt::Debug for ExtensionTracker { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "ExtensionTracker: {:?}", self.extension) + } +} diff --git a/third_party/rust/neqo-crypto/src/hkdf.rs b/third_party/rust/neqo-crypto/src/hkdf.rs new file mode 100644 index 0000000000..e3cf77418c --- /dev/null +++ b/third_party/rust/neqo-crypto/src/hkdf.rs @@ -0,0 +1,137 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::TryFrom, + os::raw::{c_char, c_uint}, + ptr::null_mut, +}; + +use crate::{ + constants::{ + Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, TLS_VERSION_1_3, + }, + err::{Error, Res}, + p11::{ + random, Item, PK11Origin, PK11SymKey, PK11_ImportDataKey, Slot, SymKey, CKA_DERIVE, + CKM_HKDF_DERIVE, CK_ATTRIBUTE_TYPE, CK_MECHANISM_TYPE, + }, +}; + +experimental_api!(SSL_HkdfExtract( + version: Version, + cipher: Cipher, + salt: *mut PK11SymKey, + ikm: *mut PK11SymKey, + prk: *mut *mut PK11SymKey, +)); +experimental_api!(SSL_HkdfExpandLabel( + version: Version, + cipher: Cipher, + prk: *mut PK11SymKey, + handshake_hash: *const u8, + handshake_hash_len: c_uint, + label: *const c_char, + label_len: c_uint, + secret: *mut *mut PK11SymKey, +)); + +fn key_size(version: Version, cipher: Cipher) -> Res<usize> { + if version != TLS_VERSION_1_3 { + return Err(Error::UnsupportedVersion); + } + Ok(match cipher { + TLS_AES_128_GCM_SHA256 | TLS_CHACHA20_POLY1305_SHA256 => 32, + TLS_AES_256_GCM_SHA384 => 48, + _ => return Err(Error::UnsupportedCipher), + }) +} + +/// Generate a random key of the right size for the given suite. +/// +/// # Errors +/// +/// Only if NSS fails. +pub fn generate_key(version: Version, cipher: Cipher) -> Res<SymKey> { + import_key(version, &random(key_size(version, cipher)?)) +} + +/// Import a symmetric key for use with HKDF. +/// +/// # Errors +/// +/// Errors returned if the key buffer is an incompatible size or the NSS functions fail. +pub fn import_key(version: Version, buf: &[u8]) -> Res<SymKey> { + if version != TLS_VERSION_1_3 { + return Err(Error::UnsupportedVersion); + } + let slot = Slot::internal()?; + #[allow(clippy::useless_conversion)] // TODO: Remove when we bump the MSRV to 1.74.0. + let key_ptr = unsafe { + PK11_ImportDataKey( + *slot, + CK_MECHANISM_TYPE::from(CKM_HKDF_DERIVE), + PK11Origin::PK11_OriginUnwrap, + CK_ATTRIBUTE_TYPE::from(CKA_DERIVE), + &mut Item::wrap(buf), + null_mut(), + ) + }; + SymKey::from_ptr(key_ptr) +} + +/// Extract a PRK from the given salt and IKM using the algorithm defined in RFC 5869. +/// +/// # Errors +/// +/// Errors returned if inputs are too large or the NSS functions fail. +pub fn extract( + version: Version, + cipher: Cipher, + salt: Option<&SymKey>, + ikm: &SymKey, +) -> Res<SymKey> { + let mut prk: *mut PK11SymKey = null_mut(); + let salt_ptr: *mut PK11SymKey = match salt { + Some(s) => **s, + None => null_mut(), + }; + unsafe { SSL_HkdfExtract(version, cipher, salt_ptr, **ikm, &mut prk) }?; + SymKey::from_ptr(prk) +} + +/// Expand a PRK using the HKDF-Expand-Label function defined in RFC 8446. +/// +/// # Errors +/// +/// Errors returned if inputs are too large or the NSS functions fail. +pub fn expand_label( + version: Version, + cipher: Cipher, + prk: &SymKey, + handshake_hash: &[u8], + label: &str, +) -> Res<SymKey> { + let l = label.as_bytes(); + let mut secret: *mut PK11SymKey = null_mut(); + + // Note that this doesn't allow for passing null() for the handshake hash. + // A zero-length slice produces an identical result. + unsafe { + SSL_HkdfExpandLabel( + version, + cipher, + **prk, + handshake_hash.as_ptr(), + c_uint::try_from(handshake_hash.len())?, + l.as_ptr().cast(), + c_uint::try_from(l.len())?, + &mut secret, + ) + }?; + SymKey::from_ptr(secret) +} diff --git a/third_party/rust/neqo-crypto/src/hp.rs b/third_party/rust/neqo-crypto/src/hp.rs new file mode 100644 index 0000000000..2479eff8f5 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/hp.rs @@ -0,0 +1,203 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + cell::RefCell, + convert::TryFrom, + fmt::{self, Debug}, + os::raw::{c_char, c_int, c_uint}, + ptr::{addr_of_mut, null, null_mut}, + rc::Rc, +}; + +use crate::{ + constants::{ + Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + }, + err::{secstatus_to_res, Error, Res}, + p11::{ + Context, Item, PK11SymKey, PK11_CipherOp, PK11_CreateContextBySymKey, PK11_Encrypt, + PK11_GetBlockSize, SymKey, CKA_ENCRYPT, CKM_AES_ECB, CKM_CHACHA20, CK_ATTRIBUTE_TYPE, + CK_CHACHA20_PARAMS, CK_MECHANISM_TYPE, + }, +}; + +experimental_api!(SSL_HkdfExpandLabelWithMech( + version: Version, + cipher: Cipher, + prk: *mut PK11SymKey, + handshake_hash: *const u8, + handshake_hash_len: c_uint, + label: *const c_char, + label_len: c_uint, + mech: CK_MECHANISM_TYPE, + key_size: c_uint, + secret: *mut *mut PK11SymKey, +)); + +#[derive(Clone)] +pub enum HpKey { + /// An AES encryption context. + /// Note: as we need to clone this object, we clone the pointer and + /// track references using `Rc`. `PK11Context` can't be used with `PK11_CloneContext` + /// as that is not supported for these contexts. + Aes(Rc<RefCell<Context>>), + /// The ChaCha20 mask has to invoke a new PK11_Encrypt every time as it needs to + /// change the counter and nonce on each invocation. + Chacha(SymKey), +} + +impl Debug for HpKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "HpKey") + } +} + +impl HpKey { + const SAMPLE_SIZE: usize = 16; + + /// QUIC-specific API for extracting a header-protection key. + /// + /// # Errors + /// + /// Errors if HKDF fails or if the label is too long to fit in a `c_uint`. + /// + /// # Panics + /// + /// When `cipher` is not known to this code. + #[allow(clippy::cast_sign_loss)] // Cast for PK11_GetBlockSize is safe. + pub fn extract(version: Version, cipher: Cipher, prk: &SymKey, label: &str) -> Res<Self> { + const ZERO: &[u8] = &[0; 12]; + + let l = label.as_bytes(); + let mut secret: *mut PK11SymKey = null_mut(); + + #[allow(clippy::useless_conversion)] // TODO: Remove when we bump the MSRV to 1.74.0. + let (mech, key_size) = match cipher { + TLS_AES_128_GCM_SHA256 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 16), + TLS_AES_256_GCM_SHA384 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 32), + TLS_CHACHA20_POLY1305_SHA256 => (CK_MECHANISM_TYPE::from(CKM_CHACHA20), 32), + _ => unreachable!(), + }; + + // Note that this doesn't allow for passing null() for the handshake hash. + // A zero-length slice produces an identical result. + unsafe { + SSL_HkdfExpandLabelWithMech( + version, + cipher, + **prk, + null(), + 0, + l.as_ptr().cast(), + c_uint::try_from(l.len())?, + mech, + key_size, + &mut secret, + ) + }?; + let key = SymKey::from_ptr(secret).or(Err(Error::HkdfError))?; + + let res = match cipher { + TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => { + // TODO: Remove when we bump the MSRV to 1.74.0. + #[allow(clippy::useless_conversion)] + let context_ptr = unsafe { + PK11_CreateContextBySymKey( + mech, + CK_ATTRIBUTE_TYPE::from(CKA_ENCRYPT), + *key, + &Item::wrap(&ZERO[..0]), // Borrow a zero-length slice of ZERO. + ) + }; + let context = Context::from_ptr(context_ptr).or(Err(Error::CipherInitFailure))?; + Self::Aes(Rc::new(RefCell::new(context))) + } + TLS_CHACHA20_POLY1305_SHA256 => Self::Chacha(key), + _ => unreachable!(), + }; + + debug_assert_eq!( + res.block_size(), + usize::try_from(unsafe { PK11_GetBlockSize(mech, null_mut()) }).unwrap() + ); + Ok(res) + } + + /// Get the sample size, which is also the output size. + #[must_use] + #[allow(clippy::unused_self)] // To maintain an API contract. + pub fn sample_size(&self) -> usize { + Self::SAMPLE_SIZE + } + + fn block_size(&self) -> usize { + match self { + Self::Aes(_) => 16, + Self::Chacha(_) => 64, + } + } + + /// Generate a header protection mask for QUIC. + /// + /// # Errors + /// + /// An error is returned if the NSS functions fail; a sample of the + /// wrong size is the obvious cause. + /// + /// # Panics + /// + /// When the mechanism for our key is not supported. + pub fn mask(&self, sample: &[u8]) -> Res<Vec<u8>> { + let mut output = vec![0_u8; self.block_size()]; + + match self { + Self::Aes(context) => { + let mut output_len: c_int = 0; + secstatus_to_res(unsafe { + PK11_CipherOp( + **context.borrow_mut(), + output.as_mut_ptr(), + &mut output_len, + c_int::try_from(output.len())?, + sample[..Self::SAMPLE_SIZE].as_ptr().cast(), + c_int::try_from(Self::SAMPLE_SIZE).unwrap(), + ) + })?; + assert_eq!(usize::try_from(output_len).unwrap(), output.len()); + Ok(output) + } + + Self::Chacha(key) => { + let params: CK_CHACHA20_PARAMS = CK_CHACHA20_PARAMS { + pBlockCounter: sample.as_ptr().cast_mut(), + blockCounterBits: 32, + pNonce: sample[4..Self::SAMPLE_SIZE].as_ptr().cast_mut(), + ulNonceBits: 96, + }; + let mut output_len: c_uint = 0; + let mut param_item = Item::wrap_struct(¶ms); + // TODO: Remove when we bump the MSRV to 1.74.0. + #[allow(clippy::useless_conversion)] + secstatus_to_res(unsafe { + PK11_Encrypt( + **key, + CK_MECHANISM_TYPE::from(CKM_CHACHA20), + addr_of_mut!(param_item), + output[..].as_mut_ptr(), + &mut output_len, + c_uint::try_from(output.len())?, + output[..].as_ptr(), + c_uint::try_from(output.len())?, + ) + })?; + assert_eq!(usize::try_from(output_len).unwrap(), output.len()); + Ok(output) + } + } + } +} diff --git a/third_party/rust/neqo-crypto/src/lib.rs b/third_party/rust/neqo-crypto/src/lib.rs new file mode 100644 index 0000000000..05424ee1f3 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/lib.rs @@ -0,0 +1,208 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![cfg_attr(feature = "deny-warnings", deny(warnings))] +#![warn(clippy::pedantic)] +// Bindgen auto generated code +// won't adhere to the clippy rules below +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::unseparated_literal_suffix)] +#![allow(clippy::used_underscore_binding)] + +mod aead; +#[cfg(feature = "fuzzing")] +mod aead_fuzzing; +pub mod agent; +mod agentio; +mod auth; +mod cert; +pub mod constants; +mod ech; +mod err; +#[macro_use] +mod exp; +pub mod ext; +pub mod hkdf; +pub mod hp; +mod once; +#[macro_use] +mod p11; +mod prio; +mod replay; +mod secrets; +pub mod selfencrypt; +mod ssl; +mod time; + +use std::{ + ffi::CString, + path::{Path, PathBuf}, + ptr::null, +}; + +#[cfg(not(feature = "fuzzing"))] +pub use self::aead::RealAead as Aead; +#[cfg(feature = "fuzzing")] +pub use self::aead::RealAead; +#[cfg(feature = "fuzzing")] +pub use self::aead_fuzzing::FuzzingAead as Aead; +use self::once::OnceResult; +pub use self::{ + agent::{ + Agent, AllowZeroRtt, Client, HandshakeState, Record, RecordList, ResumptionToken, + SecretAgent, SecretAgentInfo, SecretAgentPreInfo, Server, ZeroRttCheckResult, + ZeroRttChecker, + }, + auth::AuthenticationStatus, + constants::*, + ech::{ + encode_config as encode_ech_config, generate_keys as generate_ech_keys, AeadId, KdfId, + KemId, SymmetricSuite, + }, + err::{Error, PRErrorCode, Res}, + ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult}, + p11::{random, PrivateKey, PublicKey, SymKey}, + replay::AntiReplay, + secrets::SecretDirection, + ssl::Opt, +}; + +const MINIMUM_NSS_VERSION: &str = "3.97"; + +#[allow(non_upper_case_globals, clippy::redundant_static_lifetimes)] +#[allow(clippy::upper_case_acronyms)] +#[allow(unknown_lints, clippy::borrow_as_ptr)] +mod nss { + include!(concat!(env!("OUT_DIR"), "/nss_init.rs")); +} + +// Need to map the types through. +fn secstatus_to_res(code: nss::SECStatus) -> Res<()> { + crate::err::secstatus_to_res(code as crate::ssl::SECStatus) +} + +enum NssLoaded { + External, + NoDb, + Db(Box<Path>), +} + +impl Drop for NssLoaded { + fn drop(&mut self) { + if !matches!(self, Self::External) { + unsafe { + secstatus_to_res(nss::NSS_Shutdown()).expect("NSS Shutdown failed"); + } + } + } +} + +static mut INITIALIZED: OnceResult<NssLoaded> = OnceResult::new(); + +fn already_initialized() -> bool { + unsafe { nss::NSS_IsInitialized() != 0 } +} + +fn version_check() { + let min_ver = CString::new(MINIMUM_NSS_VERSION).unwrap(); + assert_ne!( + unsafe { nss::NSS_VersionCheck(min_ver.as_ptr()) }, + 0, + "Minimum NSS version of {MINIMUM_NSS_VERSION} not supported", + ); +} + +/// Initialize NSS. This only executes the initialization routines once, so if there is any chance +/// that +/// +/// # Panics +/// +/// When NSS initialization fails. +pub fn init() { + // Set time zero. + time::init(); + unsafe { + INITIALIZED.call_once(|| { + version_check(); + if already_initialized() { + return NssLoaded::External; + } + + secstatus_to_res(nss::NSS_NoDB_Init(null())).expect("NSS_NoDB_Init failed"); + secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed"); + + NssLoaded::NoDb + }); + } +} + +/// This enables SSLTRACE by calling a simple, harmless function to trigger its +/// side effects. SSLTRACE is not enabled in NSS until a socket is made or +/// global options are accessed. Reading an option is the least impact approach. +/// This allows us to use SSLTRACE in all of our unit tests and programs. +#[cfg(debug_assertions)] +fn enable_ssl_trace() { + let opt = ssl::Opt::Locking.as_int(); + let mut v: ::std::os::raw::c_int = 0; + secstatus_to_res(unsafe { ssl::SSL_OptionGetDefault(opt, &mut v) }) + .expect("SSL_OptionGetDefault failed"); +} + +/// Initialize with a database. +/// +/// # Panics +/// +/// If NSS cannot be initialized. +pub fn init_db<P: Into<PathBuf>>(dir: P) { + time::init(); + unsafe { + INITIALIZED.call_once(|| { + version_check(); + if already_initialized() { + return NssLoaded::External; + } + + let path = dir.into(); + assert!(path.is_dir()); + let pathstr = path.to_str().expect("path converts to string").to_string(); + let dircstr = CString::new(pathstr).unwrap(); + let empty = CString::new("").unwrap(); + secstatus_to_res(nss::NSS_Initialize( + dircstr.as_ptr(), + empty.as_ptr(), + empty.as_ptr(), + nss::SECMOD_DB.as_ptr().cast(), + nss::NSS_INIT_READONLY, + )) + .expect("NSS_Initialize failed"); + + secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed"); + secstatus_to_res(ssl::SSL_ConfigServerSessionIDCache( + 1024, + 0, + 0, + dircstr.as_ptr(), + )) + .expect("SSL_ConfigServerSessionIDCache failed"); + + #[cfg(debug_assertions)] + enable_ssl_trace(); + + NssLoaded::Db(path.into_boxed_path()) + }); + } +} + +/// # Panics +/// +/// If NSS isn't initialized. +pub fn assert_initialized() { + unsafe { + INITIALIZED.call_once(|| { + panic!("NSS not initialized with init or init_db"); + }); + } +} diff --git a/third_party/rust/neqo-crypto/src/once.rs b/third_party/rust/neqo-crypto/src/once.rs new file mode 100644 index 0000000000..80657cfe26 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/once.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::sync::Once; + +#[allow(clippy::module_name_repetitions)] +pub struct OnceResult<T> { + once: Once, + v: Option<T>, +} + +impl<T> OnceResult<T> { + #[must_use] + pub const fn new() -> Self { + Self { + once: Once::new(), + v: None, + } + } + + pub fn call_once<F: FnOnce() -> T>(&mut self, f: F) -> &T { + let v = &mut self.v; + self.once.call_once(|| { + *v = Some(f()); + }); + self.v.as_ref().unwrap() + } +} + +#[cfg(test)] +mod test { + use super::OnceResult; + + static mut STATIC_ONCE_RESULT: OnceResult<u64> = OnceResult::new(); + + #[test] + fn static_update() { + assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 23) }, 23); + assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 24) }, 23); + } +} diff --git a/third_party/rust/neqo-crypto/src/p11.rs b/third_party/rust/neqo-crypto/src/p11.rs new file mode 100644 index 0000000000..508d240062 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/p11.rs @@ -0,0 +1,320 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +use std::{ + convert::TryFrom, + mem, + ops::{Deref, DerefMut}, + os::raw::{c_int, c_uint}, + ptr::null_mut, +}; + +use neqo_common::hex_with_len; + +use crate::err::{secstatus_to_res, Error, Res}; + +#[allow(clippy::upper_case_acronyms)] +#[allow(clippy::unreadable_literal)] +#[allow(unknown_lints, clippy::borrow_as_ptr)] +mod nss_p11 { + include!(concat!(env!("OUT_DIR"), "/nss_p11.rs")); +} + +pub use nss_p11::*; + +#[macro_export] +macro_rules! scoped_ptr { + ($scoped:ident, $target:ty, $dtor:path) => { + pub struct $scoped { + ptr: *mut $target, + } + + impl $scoped { + /// Create a new instance of `$scoped` from a pointer. + /// + /// # Errors + /// + /// When passed a null pointer generates an error. + pub fn from_ptr(ptr: *mut $target) -> Result<Self, $crate::err::Error> { + if ptr.is_null() { + Err($crate::err::Error::last_nss_error()) + } else { + Ok(Self { ptr }) + } + } + } + + impl Deref for $scoped { + type Target = *mut $target; + #[must_use] + fn deref(&self) -> &*mut $target { + &self.ptr + } + } + + impl DerefMut for $scoped { + fn deref_mut(&mut self) -> &mut *mut $target { + &mut self.ptr + } + } + + impl Drop for $scoped { + #[allow(unused_must_use)] + fn drop(&mut self) { + unsafe { $dtor(self.ptr) }; + } + } + }; +} + +scoped_ptr!(Certificate, CERTCertificate, CERT_DestroyCertificate); +scoped_ptr!(CertList, CERTCertList, CERT_DestroyCertList); +scoped_ptr!(PublicKey, SECKEYPublicKey, SECKEY_DestroyPublicKey); + +impl PublicKey { + /// Get the HPKE serialization of the public key. + /// + /// # Errors + /// + /// When the key cannot be exported, which can be because the type is not supported. + /// + /// # Panics + /// + /// When keys are too large to fit in `c_uint/usize`. So only on programming error. + pub fn key_data(&self) -> Res<Vec<u8>> { + let mut buf = vec![0; 100]; + let mut len: c_uint = 0; + secstatus_to_res(unsafe { + PK11_HPKE_Serialize( + **self, + buf.as_mut_ptr(), + &mut len, + c_uint::try_from(buf.len()).unwrap(), + ) + })?; + buf.truncate(usize::try_from(len).unwrap()); + Ok(buf) + } +} + +impl Clone for PublicKey { + #[must_use] + fn clone(&self) -> Self { + let ptr = unsafe { SECKEY_CopyPublicKey(self.ptr) }; + assert!(!ptr.is_null()); + Self { ptr } + } +} + +impl std::fmt::Debug for PublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Ok(b) = self.key_data() { + write!(f, "PublicKey {}", hex_with_len(b)) + } else { + write!(f, "Opaque PublicKey") + } + } +} + +scoped_ptr!(PrivateKey, SECKEYPrivateKey, SECKEY_DestroyPrivateKey); + +impl PrivateKey { + /// Get the bits of the private key. + /// + /// # Errors + /// + /// When the key cannot be exported, which can be because the type is not supported + /// or because the key data cannot be extracted from the PKCS#11 module. + /// + /// # Panics + /// + /// When the values are too large to fit. So never. + pub fn key_data(&self) -> Res<Vec<u8>> { + let mut key_item = Item::make_empty(); + #[allow(clippy::useless_conversion)] // TODO: Remove when we bump the MSRV to 1.74.0. + secstatus_to_res(unsafe { + PK11_ReadRawAttribute( + PK11ObjectType::PK11_TypePrivKey, + (**self).cast(), + CK_ATTRIBUTE_TYPE::from(CKA_VALUE), + &mut key_item, + ) + })?; + let slc = unsafe { + std::slice::from_raw_parts(key_item.data, usize::try_from(key_item.len).unwrap()) + }; + let key = Vec::from(slc); + // The data that `key_item` refers to needs to be freed, but we can't + // use the scoped `Item` implementation. This is OK as long as nothing + // panics between `PK11_ReadRawAttribute` succeeding and here. + unsafe { + SECITEM_FreeItem(&mut key_item, PRBool::from(false)); + } + Ok(key) + } +} +unsafe impl Send for PrivateKey {} + +impl Clone for PrivateKey { + #[must_use] + fn clone(&self) -> Self { + let ptr = unsafe { SECKEY_CopyPrivateKey(self.ptr) }; + assert!(!ptr.is_null()); + Self { ptr } + } +} + +impl std::fmt::Debug for PrivateKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Ok(b) = self.key_data() { + write!(f, "PrivateKey {}", hex_with_len(b)) + } else { + write!(f, "Opaque PrivateKey") + } + } +} + +scoped_ptr!(Slot, PK11SlotInfo, PK11_FreeSlot); + +impl Slot { + pub fn internal() -> Res<Self> { + let p = unsafe { PK11_GetInternalSlot() }; + Slot::from_ptr(p) + } +} + +scoped_ptr!(SymKey, PK11SymKey, PK11_FreeSymKey); + +impl SymKey { + /// You really don't want to use this. + /// + /// # Errors + /// + /// Internal errors in case of failures in NSS. + pub fn as_bytes(&self) -> Res<&[u8]> { + secstatus_to_res(unsafe { PK11_ExtractKeyValue(self.ptr) })?; + + let key_item = unsafe { PK11_GetKeyData(self.ptr) }; + // This is accessing a value attached to the key, so we can treat this as a borrow. + match unsafe { key_item.as_mut() } { + None => Err(Error::InternalError), + Some(key) => Ok(unsafe { std::slice::from_raw_parts(key.data, key.len as usize) }), + } + } +} + +impl Clone for SymKey { + #[must_use] + fn clone(&self) -> Self { + let ptr = unsafe { PK11_ReferenceSymKey(self.ptr) }; + assert!(!ptr.is_null()); + Self { ptr } + } +} + +impl std::fmt::Debug for SymKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Ok(b) = self.as_bytes() { + write!(f, "SymKey {}", hex_with_len(b)) + } else { + write!(f, "Opaque SymKey") + } + } +} + +unsafe fn destroy_pk11_context(ctxt: *mut PK11Context) { + PK11_DestroyContext(ctxt, PRBool::from(true)); +} +scoped_ptr!(Context, PK11Context, destroy_pk11_context); + +unsafe fn destroy_secitem(item: *mut SECItem) { + SECITEM_FreeItem(item, PRBool::from(true)); +} +scoped_ptr!(Item, SECItem, destroy_secitem); + +impl Item { + /// Create a wrapper for a slice of this object. + /// Creating this object is technically safe, but using it is extremely dangerous. + /// Minimally, it can only be passed as a `const SECItem*` argument to functions, + /// or those that treat their argument as `const`. + pub fn wrap(buf: &[u8]) -> SECItem { + SECItem { + type_: SECItemType::siBuffer, + data: buf.as_ptr().cast_mut(), + len: c_uint::try_from(buf.len()).unwrap(), + } + } + + /// Create a wrapper for a struct. + /// Creating this object is technically safe, but using it is extremely dangerous. + /// Minimally, it can only be passed as a `const SECItem*` argument to functions, + /// or those that treat their argument as `const`. + pub fn wrap_struct<T>(v: &T) -> SECItem { + let data: *const T = v; + SECItem { + type_: SECItemType::siBuffer, + data: data.cast_mut().cast(), + len: c_uint::try_from(mem::size_of::<T>()).unwrap(), + } + } + + /// Make an empty `SECItem` for passing as a mutable `SECItem*` argument. + pub fn make_empty() -> SECItem { + SECItem { + type_: SECItemType::siBuffer, + data: null_mut(), + len: 0, + } + } + + /// This dereferences the pointer held by the item and makes a copy of the + /// content that is referenced there. + /// + /// # Safety + /// + /// This dereferences two pointers. It doesn't get much less safe. + pub unsafe fn into_vec(self) -> Vec<u8> { + let b = self.ptr.as_ref().unwrap(); + // Sanity check the type, as some types don't count bytes in `Item::len`. + assert_eq!(b.type_, SECItemType::siBuffer); + let slc = std::slice::from_raw_parts(b.data, usize::try_from(b.len).unwrap()); + Vec::from(slc) + } +} + +/// Generate a randomized buffer. +/// +/// # Panics +/// +/// When `size` is too large or NSS fails. +#[must_use] +pub fn random(size: usize) -> Vec<u8> { + let mut buf = vec![0; size]; + secstatus_to_res(unsafe { + PK11_GenerateRandom(buf.as_mut_ptr(), c_int::try_from(buf.len()).unwrap()) + }) + .unwrap(); + buf +} + +#[cfg(test)] +mod test { + use test_fixture::fixture_init; + + use super::random; + + #[test] + fn randomness() { + fixture_init(); + // If this ever fails, there is either a bug, or it's time to buy a lottery ticket. + assert_ne!(random(16), random(16)); + } +} diff --git a/third_party/rust/neqo-crypto/src/prio.rs b/third_party/rust/neqo-crypto/src/prio.rs new file mode 100644 index 0000000000..527d8739c8 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/prio.rs @@ -0,0 +1,25 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::upper_case_acronyms)] +#![allow( + dead_code, + non_upper_case_globals, + non_snake_case, + clippy::cognitive_complexity, + clippy::empty_enum, + clippy::too_many_lines, + unknown_lints, + clippy::borrow_as_ptr +)] + +include!(concat!(env!("OUT_DIR"), "/nspr_io.rs")); + +pub enum PRFileInfo {} +pub enum PRFileInfo64 {} +pub enum PRFilePrivate {} +pub enum PRIOVec {} +pub enum PRSendFileData {} diff --git a/third_party/rust/neqo-crypto/src/replay.rs b/third_party/rust/neqo-crypto/src/replay.rs new file mode 100644 index 0000000000..d4d3677f5c --- /dev/null +++ b/third_party/rust/neqo-crypto/src/replay.rs @@ -0,0 +1,83 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{ + convert::{TryFrom, TryInto}, + ops::{Deref, DerefMut}, + os::raw::c_uint, + ptr::null_mut, + time::{Duration, Instant}, +}; + +use crate::{ + err::Res, + ssl::PRFileDesc, + time::{Interval, PRTime, Time}, +}; + +// This is an opaque struct in NSS. +#[allow(clippy::upper_case_acronyms)] +#[allow(clippy::empty_enum)] +pub enum SSLAntiReplayContext {} + +experimental_api!(SSL_CreateAntiReplayContext( + now: PRTime, + window: PRTime, + k: c_uint, + bits: c_uint, + ctx: *mut *mut SSLAntiReplayContext, +)); +experimental_api!(SSL_ReleaseAntiReplayContext(ctx: *mut SSLAntiReplayContext)); +experimental_api!(SSL_SetAntiReplayContext( + fd: *mut PRFileDesc, + ctx: *mut SSLAntiReplayContext, +)); + +scoped_ptr!( + AntiReplayContext, + SSLAntiReplayContext, + SSL_ReleaseAntiReplayContext +); + +/// `AntiReplay` is used by servers when processing 0-RTT handshakes. +/// It limits the exposure of servers to replay attack by rejecting 0-RTT +/// if it appears to be a replay. There is a false-positive rate that can be +/// managed by tuning the parameters used to create the context. +#[allow(clippy::module_name_repetitions)] +pub struct AntiReplay { + ctx: AntiReplayContext, +} + +impl AntiReplay { + /// Make a new anti-replay context. + /// See the documentation in NSS for advice on how to set these values. + /// + /// # Errors + /// + /// Returns an error if `now` is in the past relative to our baseline or + /// NSS is unable to generate an anti-replay context. + pub fn new(now: Instant, window: Duration, k: usize, bits: usize) -> Res<Self> { + let mut ctx: *mut SSLAntiReplayContext = null_mut(); + unsafe { + SSL_CreateAntiReplayContext( + Time::from(now).try_into()?, + Interval::from(window).try_into()?, + c_uint::try_from(k)?, + c_uint::try_from(bits)?, + &mut ctx, + ) + }?; + + Ok(Self { + ctx: AntiReplayContext::from_ptr(ctx)?, + }) + } + + /// Configure the provided socket with this anti-replay context. + pub(crate) fn config_socket(&self, fd: *mut PRFileDesc) -> Res<()> { + unsafe { SSL_SetAntiReplayContext(fd, *self.ctx) } + } +} diff --git a/third_party/rust/neqo-crypto/src/result.rs b/third_party/rust/neqo-crypto/src/result.rs new file mode 100644 index 0000000000..e304fcea7f --- /dev/null +++ b/third_party/rust/neqo-crypto/src/result.rs @@ -0,0 +1,135 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::{ + err::{nspr, Error, PR_ErrorToName, PR_ErrorToString, PR_GetError, Res, PR_LANGUAGE_I_DEFAULT}, + ssl, +}; + +use std::ffi::CStr; + +pub fn result(rv: ssl::SECStatus) -> Res<()> { + _ = result_helper(rv, false)?; + Ok(()) +} + +pub fn result_or_blocked(rv: ssl::SECStatus) -> Res<bool> { + result_helper(rv, true) +} + +fn wrap_str_fn<F>(f: F, dflt: &str) -> String +where + F: FnOnce() -> *const i8, +{ + unsafe { + let p = f(); + if p.is_null() { + return dflt.to_string(); + } + CStr::from_ptr(p).to_string_lossy().into_owned() + } +} + +fn result_helper(rv: ssl::SECStatus, allow_blocked: bool) -> Res<bool> { + if rv == ssl::_SECStatus_SECSuccess { + return Ok(false); + } + + let code = unsafe { PR_GetError() }; + if allow_blocked && code == nspr::PR_WOULD_BLOCK_ERROR { + return Ok(true); + } + + let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR"); + let desc = wrap_str_fn( + || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) }, + "...", + ); + Err(Error::NssError { name, code, desc }) +} + +#[cfg(test)] +mod tests { + use super::{result, result_or_blocked}; + use crate::{ + err::{self, nspr, Error, PRErrorCode, PR_SetError}, + ssl, + }; + use test_fixture::fixture_init; + + fn set_error_code(code: PRErrorCode) { + unsafe { PR_SetError(code, 0) }; + } + + #[test] + fn is_ok() { + assert!(result(ssl::SECSuccess).is_ok()); + } + + #[test] + fn is_err() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "SSL_ERROR_BAD_MAC_READ"); + assert_eq!(code, -12273); + assert_eq!( + desc, + "SSL received a record with an incorrect Message Authentication Code." + ); + } + _ => unreachable!(), + } + } + + #[test] + fn is_err_zero_code() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(0); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, .. } => { + assert_eq!(name, "UNKNOWN_ERROR"); + assert_eq!(code, 0); + // Note that we don't test |desc| here because that comes from + // strerror(0), which is platform-dependent. + } + _ => unreachable!(), + } + } + + #[test] + fn blocked_as_error() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(nspr::PR_WOULD_BLOCK_ERROR); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "PR_WOULD_BLOCK_ERROR"); + assert_eq!(code, -5998); + assert_eq!(desc, "The operation would have blocked"); + } + _ => panic!("bad error type"), + } + } + + #[test] + fn is_blocked() { + set_error_code(nspr::PR_WOULD_BLOCK_ERROR); + assert!(result_or_blocked(ssl::SECFailure).unwrap()); + } +} diff --git a/third_party/rust/neqo-crypto/src/secrets.rs b/third_party/rust/neqo-crypto/src/secrets.rs new file mode 100644 index 0000000000..75677636b6 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/secrets.rs @@ -0,0 +1,129 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::{os::raw::c_void, pin::Pin}; + +use neqo_common::qdebug; + +use crate::{ + agentio::as_c_void, + constants::Epoch, + err::Res, + p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey}, + ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection}, +}; + +experimental_api!(SSL_SecretCallback( + fd: *mut PRFileDesc, + cb: SSLSecretCallback, + arg: *mut c_void, +)); + +#[derive(Clone, Copy, Debug)] +pub enum SecretDirection { + Read, + Write, +} + +impl From<SSLSecretDirection::Type> for SecretDirection { + #[must_use] + fn from(dir: SSLSecretDirection::Type) -> Self { + match dir { + SSLSecretDirection::ssl_secret_read => Self::Read, + SSLSecretDirection::ssl_secret_write => Self::Write, + _ => unreachable!(), + } + } +} + +#[derive(Debug, Default)] +#[allow(clippy::module_name_repetitions)] +pub struct DirectionalSecrets { + // We only need to maintain 3 secrets for the epochs used during the handshake. + secrets: [Option<SymKey>; 3], +} + +impl DirectionalSecrets { + fn put(&mut self, epoch: Epoch, key: SymKey) { + assert!(epoch > 0); + let i = (epoch - 1) as usize; + assert!(i < self.secrets.len()); + // assert!(self.secrets[i].is_none()); + self.secrets[i] = Some(key); + } + + pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> { + assert!(epoch > 0); + let i = (epoch - 1) as usize; + assert!(i < self.secrets.len()); + self.secrets[i].take() + } +} + +#[derive(Debug, Default)] +pub struct Secrets { + r: DirectionalSecrets, + w: DirectionalSecrets, +} + +impl Secrets { + #[allow(clippy::unused_self)] + unsafe extern "C" fn secret_available( + _fd: *mut PRFileDesc, + epoch: u16, + dir: SSLSecretDirection::Type, + secret: *mut PK11SymKey, + arg: *mut c_void, + ) { + let secrets = arg.cast::<Self>().as_mut().unwrap(); + secrets.put_raw(epoch, dir, secret); + } + + fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) { + let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) }; + let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets"); + self.put(SecretDirection::from(dir), epoch, key); + } + + fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { + qdebug!("{:?} secret available for {:?}: {:?}", dir, epoch, key); + let keys = match dir { + SecretDirection::Read => &mut self.r, + SecretDirection::Write => &mut self.w, + }; + keys.put(epoch, key); + } +} + +#[derive(Debug)] +pub struct SecretHolder { + secrets: Pin<Box<Secrets>>, +} + +impl SecretHolder { + /// This registers with NSS. The lifetime of this object needs to match the lifetime + /// of the connection, or bad things might happen. + pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> { + let p = as_c_void(&mut self.secrets); + unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) } + } + + pub fn take_read(&mut self, epoch: Epoch) -> Option<SymKey> { + self.secrets.r.take(epoch) + } + + pub fn take_write(&mut self, epoch: Epoch) -> Option<SymKey> { + self.secrets.w.take(epoch) + } +} + +impl Default for SecretHolder { + fn default() -> Self { + Self { + secrets: Box::pin(Secrets::default()), + } + } +} diff --git a/third_party/rust/neqo-crypto/src/selfencrypt.rs b/third_party/rust/neqo-crypto/src/selfencrypt.rs new file mode 100644 index 0000000000..b8a63153fd --- /dev/null +++ b/third_party/rust/neqo-crypto/src/selfencrypt.rs @@ -0,0 +1,161 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::mem; + +use neqo_common::{hex, qinfo, qtrace, Encoder}; + +use crate::{ + constants::{Cipher, Version}, + err::{Error, Res}, + hkdf, + p11::{random, SymKey}, + Aead, +}; + +#[derive(Debug)] +pub struct SelfEncrypt { + version: Version, + cipher: Cipher, + key_id: u8, + key: SymKey, + old_key: Option<SymKey>, +} + +impl SelfEncrypt { + const VERSION: u8 = 1; + const SALT_LENGTH: usize = 16; + + /// # Errors + /// + /// Failure to generate a new HKDF key using NSS results in an error. + pub fn new(version: Version, cipher: Cipher) -> Res<Self> { + let key = hkdf::generate_key(version, cipher)?; + Ok(Self { + version, + cipher, + key_id: 0, + key, + old_key: None, + }) + } + + fn make_aead(&self, k: &SymKey, salt: &[u8]) -> Res<Aead> { + debug_assert_eq!(salt.len(), Self::SALT_LENGTH); + let salt = hkdf::import_key(self.version, salt)?; + let secret = hkdf::extract(self.version, self.cipher, Some(&salt), k)?; + Aead::new(false, self.version, self.cipher, &secret, "neqo self") + } + + /// Rotate keys. This causes any previous key that is being held to be replaced by the current + /// key. + /// + /// # Errors + /// + /// Failure to generate a new HKDF key using NSS results in an error. + pub fn rotate(&mut self) -> Res<()> { + let new_key = hkdf::generate_key(self.version, self.cipher)?; + self.old_key = Some(mem::replace(&mut self.key, new_key)); + let (kid, _) = self.key_id.overflowing_add(1); + self.key_id = kid; + qinfo!(["SelfEncrypt"], "Rotated keys to {}", self.key_id); + Ok(()) + } + + /// Seal an item using the underlying key. This produces a single buffer that contains + /// the encrypted `plaintext`, plus a version number and salt. + /// `aad` is only used as input to the AEAD, it is not included in the output; the + /// caller is responsible for carrying the AAD as appropriate. + /// + /// # Errors + /// + /// Failure to protect using NSS AEAD APIs produces an error. + pub fn seal(&self, aad: &[u8], plaintext: &[u8]) -> Res<Vec<u8>> { + // Format is: + // struct { + // uint8 version; + // uint8 key_id; + // uint8 salt[16]; + // opaque aead_encrypted(plaintext)[length as expanded]; + // }; + // AAD covers the entire header, plus the value of the AAD parameter that is provided. + let salt = random(Self::SALT_LENGTH); + let cipher = self.make_aead(&self.key, &salt)?; + let encoded_len = 2 + salt.len() + plaintext.len() + cipher.expansion(); + + let mut enc = Encoder::with_capacity(encoded_len); + enc.encode_byte(Self::VERSION); + enc.encode_byte(self.key_id); + enc.encode(&salt); + + let mut extended_aad = enc.clone(); + extended_aad.encode(aad); + + let offset = enc.len(); + let mut output: Vec<u8> = enc.into(); + output.resize(encoded_len, 0); + cipher.encrypt(0, extended_aad.as_ref(), plaintext, &mut output[offset..])?; + qtrace!( + ["SelfEncrypt"], + "seal {} {} -> {}", + hex(aad), + hex(plaintext), + hex(&output) + ); + Ok(output) + } + + fn select_key(&self, kid: u8) -> Option<&SymKey> { + if kid == self.key_id { + Some(&self.key) + } else { + let (prev_key_id, _) = self.key_id.overflowing_sub(1); + if kid == prev_key_id { + self.old_key.as_ref() + } else { + None + } + } + } + + /// Open the protected `ciphertext`. + /// + /// # Errors + /// + /// Returns an error when the self-encrypted object is invalid; + /// when the keys have been rotated; or when NSS fails. + #[allow(clippy::similar_names)] // aad is similar to aead + pub fn open(&self, aad: &[u8], ciphertext: &[u8]) -> Res<Vec<u8>> { + if ciphertext[0] != Self::VERSION { + return Err(Error::SelfEncryptFailure); + } + let Some(key) = self.select_key(ciphertext[1]) else { + return Err(Error::SelfEncryptFailure); + }; + let offset = 2 + Self::SALT_LENGTH; + + let mut extended_aad = Encoder::with_capacity(offset + aad.len()); + extended_aad.encode(&ciphertext[0..offset]); + extended_aad.encode(aad); + + let aead = self.make_aead(key, &ciphertext[2..offset])?; + // NSS insists on having extra space available for decryption. + let padded_len = ciphertext.len() - offset; + let mut output = vec![0; padded_len]; + let decrypted = + aead.decrypt(0, extended_aad.as_ref(), &ciphertext[offset..], &mut output)?; + let final_len = decrypted.len(); + output.truncate(final_len); + qtrace!( + ["SelfEncrypt"], + "open {} {} -> {}", + hex(aad), + hex(ciphertext), + hex(&output) + ); + Ok(output) + } +} diff --git a/third_party/rust/neqo-crypto/src/ssl.rs b/third_party/rust/neqo-crypto/src/ssl.rs new file mode 100644 index 0000000000..8aaacffae6 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/ssl.rs @@ -0,0 +1,153 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow( + dead_code, + non_upper_case_globals, + non_snake_case, + clippy::cognitive_complexity, + clippy::too_many_lines, + clippy::upper_case_acronyms, + unknown_lints, + clippy::borrow_as_ptr +)] + +use std::os::raw::{c_uint, c_void}; + +use crate::{ + constants::Epoch, + err::{secstatus_to_res, Res}, +}; + +include!(concat!(env!("OUT_DIR"), "/nss_ssl.rs")); +mod SSLOption { + include!(concat!(env!("OUT_DIR"), "/nss_sslopt.rs")); +} + +// I clearly don't understand how bindgen operates. +#[allow(clippy::empty_enum)] +pub enum PLArenaPool {} +#[allow(clippy::empty_enum)] +pub enum PRFileDesc {} + +// Remap some constants. +pub const SECSuccess: SECStatus = _SECStatus_SECSuccess; +pub const SECFailure: SECStatus = _SECStatus_SECFailure; + +#[derive(Debug, Copy, Clone)] +pub enum Opt { + Locking, + Tickets, + OcspStapling, + Alpn, + ExtendedMasterSecret, + SignedCertificateTimestamps, + EarlyData, + RecordSizeLimit, + Tls13CompatMode, + HelloDowngradeCheck, + SuppressEndOfEarlyData, + Grease, +} + +impl Opt { + // Cast is safe here because SSLOptions are within the i32 range + #[allow(clippy::cast_possible_wrap)] + pub(crate) fn as_int(self) -> PRInt32 { + let i = match self { + Self::Locking => SSLOption::SSL_NO_LOCKS, + Self::Tickets => SSLOption::SSL_ENABLE_SESSION_TICKETS, + Self::OcspStapling => SSLOption::SSL_ENABLE_OCSP_STAPLING, + Self::Alpn => SSLOption::SSL_ENABLE_ALPN, + Self::ExtendedMasterSecret => SSLOption::SSL_ENABLE_EXTENDED_MASTER_SECRET, + Self::SignedCertificateTimestamps => SSLOption::SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + Self::EarlyData => SSLOption::SSL_ENABLE_0RTT_DATA, + Self::RecordSizeLimit => SSLOption::SSL_RECORD_SIZE_LIMIT, + Self::Tls13CompatMode => SSLOption::SSL_ENABLE_TLS13_COMPAT_MODE, + Self::HelloDowngradeCheck => SSLOption::SSL_ENABLE_HELLO_DOWNGRADE_CHECK, + Self::SuppressEndOfEarlyData => SSLOption::SSL_SUPPRESS_END_OF_EARLY_DATA, + Self::Grease => SSLOption::SSL_ENABLE_GREASE, + }; + i as PRInt32 + } + + // Some options are backwards, like SSL_NO_LOCKS, so use this to manage that. + fn map_enabled(self, enabled: bool) -> PRIntn { + let v = match self { + Self::Locking => !enabled, + _ => enabled, + }; + PRIntn::from(v) + } + + pub(crate) fn set(self, fd: *mut PRFileDesc, value: bool) -> Res<()> { + secstatus_to_res(unsafe { SSL_OptionSet(fd, self.as_int(), self.map_enabled(value)) }) + } +} + +experimental_api!(SSL_GetCurrentEpoch( + fd: *mut PRFileDesc, + read_epoch: *mut u16, + write_epoch: *mut u16, +)); +experimental_api!(SSL_HelloRetryRequestCallback( + fd: *mut PRFileDesc, + cb: SSLHelloRetryRequestCallback, + arg: *mut c_void, +)); +experimental_api!(SSL_RecordLayerWriteCallback( + fd: *mut PRFileDesc, + cb: SSLRecordWriteCallback, + arg: *mut c_void, +)); +experimental_api!(SSL_RecordLayerData( + fd: *mut PRFileDesc, + epoch: Epoch, + ct: SSLContentType::Type, + data: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SendSessionTicket( + fd: *mut PRFileDesc, + extra: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SetMaxEarlyDataSize(fd: *mut PRFileDesc, size: u32)); +experimental_api!(SSL_SetResumptionToken( + fd: *mut PRFileDesc, + token: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SetResumptionTokenCallback( + fd: *mut PRFileDesc, + cb: SSLResumptionTokenCallback, + arg: *mut c_void, +)); + +experimental_api!(SSL_GetResumptionTokenInfo( + token: *const u8, + token_len: c_uint, + info: *mut SSLResumptionTokenInfo, + len: c_uint, +)); + +experimental_api!(SSL_DestroyResumptionTokenInfo( + info: *mut SSLResumptionTokenInfo, +)); + +#[cfg(test)] +mod tests { + use super::{SSL_GetNumImplementedCiphers, SSL_NumImplementedCiphers}; + + #[test] + fn num_ciphers() { + assert!(unsafe { SSL_NumImplementedCiphers } > 0); + assert!(unsafe { SSL_GetNumImplementedCiphers() } > 0); + assert_eq!(unsafe { SSL_NumImplementedCiphers }, unsafe { + SSL_GetNumImplementedCiphers() + }); + } +} diff --git a/third_party/rust/neqo-crypto/src/time.rs b/third_party/rust/neqo-crypto/src/time.rs new file mode 100644 index 0000000000..84dbfdb4a5 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/time.rs @@ -0,0 +1,259 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(clippy::upper_case_acronyms)] + +use std::{ + boxed::Box, + convert::{TryFrom, TryInto}, + ops::Deref, + os::raw::c_void, + pin::Pin, + time::{Duration, Instant}, +}; + +use crate::{ + agentio::as_c_void, + err::{Error, Res}, + once::OnceResult, + ssl::{PRFileDesc, SSLTimeFunc}, +}; + +include!(concat!(env!("OUT_DIR"), "/nspr_time.rs")); + +experimental_api!(SSL_SetTimeFunc( + fd: *mut PRFileDesc, + cb: SSLTimeFunc, + arg: *mut c_void, +)); + +/// This struct holds the zero time used for converting between `Instant` and `PRTime`. +#[derive(Debug)] +struct TimeZero { + instant: Instant, + prtime: PRTime, +} + +impl TimeZero { + /// This function sets a baseline from an instance of `Instant`. + /// This allows for the possibility that code that uses these APIs will create + /// instances of `Instant` before any of this code is run. If `Instant`s older than + /// `BASE_TIME` are used with these conversion functions, they will fail. + /// To avoid that, we make sure that this sets the base time using the first value + /// it sees if it is in the past. If it is not, then use `Instant::now()` instead. + pub fn baseline(t: Instant) -> Self { + let now = Instant::now(); + let prnow = unsafe { PR_Now() }; + + if now <= t { + // `t` is in the future, just use `now`. + Self { + instant: now, + prtime: prnow, + } + } else { + let elapsed = Interval::from(now.duration_since(now)); + // An error from these unwrap functions would require + // ridiculously long application running time. + let prelapsed: PRTime = elapsed.try_into().unwrap(); + Self { + instant: t, + prtime: prnow.checked_sub(prelapsed).unwrap(), + } + } + } +} + +static mut BASE_TIME: OnceResult<TimeZero> = OnceResult::new(); + +fn get_base() -> &'static TimeZero { + let f = || TimeZero { + instant: Instant::now(), + prtime: unsafe { PR_Now() }, + }; + unsafe { BASE_TIME.call_once(f) } +} + +pub(crate) fn init() { + _ = get_base(); +} + +/// Time wraps Instant and provides conversion functions into `PRTime`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Time { + t: Instant, +} + +impl Deref for Time { + type Target = Instant; + fn deref(&self) -> &Self::Target { + &self.t + } +} + +impl From<Instant> for Time { + /// Convert from an Instant into a Time. + fn from(t: Instant) -> Self { + // Call `TimeZero::baseline(t)` so that time zero can be set. + let f = || TimeZero::baseline(t); + _ = unsafe { BASE_TIME.call_once(f) }; + Self { t } + } +} + +impl TryFrom<PRTime> for Time { + type Error = Error; + fn try_from(prtime: PRTime) -> Res<Self> { + let base = get_base(); + if let Some(delta) = prtime.checked_sub(base.prtime) { + let d = Duration::from_micros(delta.try_into()?); + base.instant + .checked_add(d) + .map_or(Err(Error::TimeTravelError), |t| Ok(Self { t })) + } else { + Err(Error::TimeTravelError) + } + } +} + +impl TryInto<PRTime> for Time { + type Error = Error; + fn try_into(self) -> Res<PRTime> { + let base = get_base(); + let delta = self + .t + .checked_duration_since(base.instant) + .ok_or(Error::TimeTravelError)?; + if let Ok(d) = PRTime::try_from(delta.as_micros()) { + d.checked_add(base.prtime).ok_or(Error::TimeTravelError) + } else { + Err(Error::TimeTravelError) + } + } +} + +impl From<Time> for Instant { + #[must_use] + fn from(t: Time) -> Self { + t.t + } +} + +/// Interval wraps Duration and provides conversion functions into `PRTime`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Interval { + d: Duration, +} + +impl Deref for Interval { + type Target = Duration; + fn deref(&self) -> &Self::Target { + &self.d + } +} + +impl TryFrom<PRTime> for Interval { + type Error = Error; + fn try_from(prtime: PRTime) -> Res<Self> { + Ok(Self { + d: Duration::from_micros(u64::try_from(prtime)?), + }) + } +} + +impl From<Duration> for Interval { + fn from(d: Duration) -> Self { + Self { d } + } +} + +impl TryInto<PRTime> for Interval { + type Error = Error; + fn try_into(self) -> Res<PRTime> { + Ok(PRTime::try_from(self.d.as_micros())?) + } +} + +/// `TimeHolder` maintains a `PRTime` value in a form that is accessible to the TLS stack. +#[derive(Debug)] +pub struct TimeHolder { + t: Pin<Box<PRTime>>, +} + +impl TimeHolder { + unsafe extern "C" fn time_func(arg: *mut c_void) -> PRTime { + let p = arg as *const PRTime; + *p.as_ref().unwrap() + } + + pub fn bind(&mut self, fd: *mut PRFileDesc) -> Res<()> { + unsafe { SSL_SetTimeFunc(fd, Some(Self::time_func), as_c_void(&mut self.t)) } + } + + pub fn set(&mut self, t: Instant) -> Res<()> { + *self.t = Time::from(t).try_into()?; + Ok(()) + } +} + +impl Default for TimeHolder { + fn default() -> Self { + TimeHolder { t: Box::pin(0) } + } +} + +#[cfg(test)] +mod test { + use std::{ + convert::{TryFrom, TryInto}, + time::{Duration, Instant}, + }; + + use super::{get_base, init, Interval, PRTime, Time}; + use crate::err::Res; + + #[test] + fn convert_stable() { + init(); + let now = Time::from(Instant::now()); + let pr: PRTime = now.try_into().expect("convert to PRTime with truncation"); + let t2 = Time::try_from(pr).expect("convert to Instant"); + let pr2: PRTime = t2.try_into().expect("convert to PRTime again"); + assert_eq!(pr, pr2); + let t3 = Time::try_from(pr2).expect("convert to Instant again"); + assert_eq!(t2, t3); + } + + #[test] + fn past_time() { + init(); + let base = get_base(); + assert!(Time::try_from(base.prtime - 1).is_err()); + } + + #[test] + fn negative_time() { + init(); + assert!(Time::try_from(-1).is_err()); + } + + #[test] + fn negative_interval() { + init(); + assert!(Interval::try_from(-1).is_err()); + } + + #[test] + // We allow replace_consts here because + // std::u64::max_value() isn't available + // in all of our targets + fn overflow_interval() { + init(); + let interval = Interval::from(Duration::from_micros(u64::max_value())); + let res: Res<PRTime> = interval.try_into(); + assert!(res.is_err()); + } +} |