diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/neqo-crypto/src | |
parent | Initial commit. (diff) | |
download | firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.tar.xz firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-crypto/src')
21 files changed, 3911 insertions, 0 deletions
diff --git a/third_party/rust/neqo-crypto/src/aead.rs b/third_party/rust/neqo-crypto/src/aead.rs new file mode 100644 index 0000000000..603e19d4c1 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/aead.rs @@ -0,0 +1,164 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::constants::{Cipher, Version}; +use crate::err::{Error, Res}; +use crate::p11::{PK11SymKey, SymKey}; +use crate::ssl; +use crate::ssl::{PRUint16, PRUint64, PRUint8, SSLAeadContext}; + +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::ops::{Deref, DerefMut}; +use std::os::raw::{c_char, c_uint}; +use std::ptr::{null_mut, NonNull}; + +experimental_api!(SSL_MakeAead( + version: PRUint16, + cipher: PRUint16, + secret: *mut PK11SymKey, + label_prefix: *const c_char, + label_prefix_len: c_uint, + ctx: *mut *mut SSLAeadContext, +)); +experimental_api!(SSL_AeadEncrypt( + ctx: *const SSLAeadContext, + counter: PRUint64, + aad: *const PRUint8, + aad_len: c_uint, + input: *const PRUint8, + input_len: c_uint, + output: *const PRUint8, + output_len: *mut c_uint, + max_output: c_uint +)); +experimental_api!(SSL_AeadDecrypt( + ctx: *const SSLAeadContext, + counter: PRUint64, + aad: *const PRUint8, + aad_len: c_uint, + input: *const PRUint8, + input_len: c_uint, + output: *const PRUint8, + output_len: *mut c_uint, + max_output: c_uint +)); +experimental_api!(SSL_DestroyAead(ctx: *mut SSLAeadContext)); +scoped_ptr!(AeadContext, SSLAeadContext, SSL_DestroyAead); + +pub struct Aead { + ctx: AeadContext, +} + +impl Aead { + /// Create a new AEAD based on the indicated TLS version and cipher suite. + /// + /// # Errors + /// Returns `Error` when the supporting NSS functions fail. + pub fn new(version: Version, cipher: Cipher, secret: &SymKey, prefix: &str) -> Res<Self> { + let s: *mut PK11SymKey = **secret; + unsafe { Self::from_raw(version, cipher, s, prefix) } + } + + #[must_use] + #[allow(clippy::unused_self)] + pub fn expansion(&self) -> usize { + 16 + } + + unsafe fn from_raw( + version: Version, + cipher: Cipher, + secret: *mut PK11SymKey, + prefix: &str, + ) -> Res<Self> { + let p = prefix.as_bytes(); + let mut ctx: *mut ssl::SSLAeadContext = null_mut(); + SSL_MakeAead( + version, + cipher, + secret, + p.as_ptr() as *const c_char, + c_uint::try_from(p.len())?, + &mut ctx, + )?; + match NonNull::new(ctx) { + Some(ctx_ptr) => Ok(Self { + ctx: AeadContext::new(ctx_ptr), + }), + None => Err(Error::InternalError), + } + } + + /// Decrypt a plaintext. + /// + /// The space provided in `output` needs to be larger than `input` by + /// the value provided in `Aead::expansion`. + /// + /// # Errors + /// If the input can't be protected or any input is too large for NSS. + pub fn encrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + let mut l: c_uint = 0; + unsafe { + SSL_AeadEncrypt( + *self.ctx.deref(), + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + output.as_mut_ptr(), + &mut l, + c_uint::try_from(output.len())?, + ) + }?; + Ok(&output[0..(l.try_into()?)]) + } + + /// Decrypt a ciphertext. + /// + /// Note that NSS insists upon having extra space available for decryption, so + /// the buffer for `output` should be the same length as `input`, even though + /// the final result will be shorter. + /// + /// # Errors + /// If the input isn't authenticated or any input is too large for NSS. + pub fn decrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + let mut l: c_uint = 0; + unsafe { + SSL_AeadDecrypt( + *self.ctx.deref(), + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + output.as_mut_ptr(), + &mut l, + c_uint::try_from(output.len())?, + ) + }?; + Ok(&output[0..(l.try_into()?)]) + } +} + +impl fmt::Debug for Aead { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[AEAD Context]") + } +} diff --git a/third_party/rust/neqo-crypto/src/agent.rs b/third_party/rust/neqo-crypto/src/agent.rs new file mode 100644 index 0000000000..4b207a7707 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/agent.rs @@ -0,0 +1,1061 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +pub use crate::agentio::{as_c_void, Record, RecordList}; +use crate::agentio::{AgentIo, METHODS}; +use crate::assert_initialized; +use crate::auth::AuthenticationStatus; +pub use crate::cert::CertificateInfo; +use crate::constants::{ + Alert, Cipher, Epoch, Extension, Group, SignatureScheme, Version, TLS_VERSION_1_3, +}; +use crate::err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res}; +use crate::ext::{ExtensionHandler, ExtensionTracker}; +use crate::p11; +use crate::prio; +use crate::replay::AntiReplay; +use crate::secrets::SecretHolder; +use crate::ssl::{self, PRBool}; +use crate::time::{Time, TimeHolder}; + +use neqo_common::{hex_snip_middle, qdebug, qinfo, qtrace, qwarn}; +use std::cell::RefCell; +use std::convert::TryFrom; +use std::ffi::CString; +use std::mem::{self, MaybeUninit}; +use std::ops::{Deref, DerefMut}; +use std::os::raw::{c_uint, c_void}; +use std::pin::Pin; +use std::ptr::{null, null_mut, NonNull}; +use std::rc::Rc; +use std::time::Instant; + +/// The maximum number of tickets to remember for a given connection. +const MAX_TICKETS: usize = 4; + +#[derive(Clone, Debug, PartialEq)] +pub enum HandshakeState { + New, + InProgress, + AuthenticationPending, + Authenticated(PRErrorCode), + Complete(SecretAgentInfo), + Failed(Error), +} + +impl HandshakeState { + #[must_use] + pub fn is_connected(&self) -> bool { + matches!(self, Self::Complete(_)) + } + + #[must_use] + pub fn is_final(&self) -> bool { + matches!(self, Self::Complete(_) | Self::Failed(_)) + } +} + +fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> { + let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT; + let mut chosen = vec![0_u8; 255]; + let mut chosen_len: c_uint = 0; + secstatus_to_res(unsafe { + ssl::SSL_GetNextProto( + fd, + &mut alpn_state, + chosen.as_mut_ptr(), + &mut chosen_len, + c_uint::try_from(chosen.len())?, + ) + })?; + + let alpn = match (pre, alpn_state) { + (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE) + | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED) + | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED) => { + chosen.truncate(usize::try_from(chosen_len)?); + Some(match String::from_utf8(chosen) { + Ok(a) => a, + Err(_) => return Err(Error::InternalError), + }) + } + _ => None, + }; + qtrace!([format!("{:p}", fd)], "got ALPN {:?}", alpn); + Ok(alpn) +} + +pub struct SecretAgentPreInfo { + info: ssl::SSLPreliminaryChannelInfo, + alpn: Option<String>, +} + +macro_rules! preinfo_arg { + ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => { + #[must_use] + pub fn $v(&self) -> Option<$t> { + match self.info.valuesSet & ssl::$m { + 0 => None, + _ => Some($t::from(self.info.$f)), + } + } + }; +} + +impl SecretAgentPreInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetPreliminaryChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?, + ) + })?; + + Ok(Self { + info: unsafe { info.assume_init() }, + alpn: get_alpn(fd, true)?, + }) + } + + preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version); + preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher); + #[must_use] + pub fn early_data(&self) -> bool { + self.info.canSendEarlyData != 0 + } + #[must_use] + pub fn max_early_data(&self) -> usize { + usize::try_from(self.info.maxEarlyDataSize).unwrap() + } + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } + + preinfo_arg!( + early_data_cipher, + ssl_preinfo_0rtt_cipher_suite, + zeroRttCipherSuite: Cipher, + ); +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct SecretAgentInfo { + version: Version, + cipher: Cipher, + group: Group, + resumed: bool, + early_data: bool, + alpn: Option<String>, + signature_scheme: SignatureScheme, +} + +impl SecretAgentInfo { + fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> { + let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit(); + secstatus_to_res(unsafe { + ssl::SSL_GetChannelInfo( + fd, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?, + ) + })?; + let info = unsafe { info.assume_init() }; + Ok(Self { + version: info.protocolVersion, + cipher: info.cipherSuite, + group: Group::try_from(info.keaGroup)?, + resumed: info.resumed != 0, + early_data: info.earlyDataAccepted != 0, + alpn: get_alpn(fd, false)?, + signature_scheme: SignatureScheme::try_from(info.signatureScheme)?, + }) + } + #[must_use] + pub fn version(&self) -> Version { + self.version + } + #[must_use] + pub fn cipher_suite(&self) -> Cipher { + self.cipher + } + #[must_use] + pub fn key_exchange(&self) -> Group { + self.group + } + #[must_use] + pub fn resumed(&self) -> bool { + self.resumed + } + #[must_use] + pub fn early_data_accepted(&self) -> bool { + self.early_data + } + #[must_use] + pub fn alpn(&self) -> Option<&String> { + self.alpn.as_ref() + } + #[must_use] + pub fn signature_scheme(&self) -> SignatureScheme { + self.signature_scheme + } +} + +/// `SecretAgent` holds the common parts of client and server. +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SecretAgent { + fd: *mut ssl::PRFileDesc, + secrets: SecretHolder, + raw: Option<bool>, + io: Pin<Box<AgentIo>>, + state: HandshakeState, + + /// Records whether authentication of certificates is required. + auth_required: Pin<Box<bool>>, + /// Records any fatal alert that is sent by the stack. + alert: Pin<Box<Option<Alert>>>, + /// The current time. + now: TimeHolder, + + extension_handlers: Vec<ExtensionTracker>, + inf: Option<SecretAgentInfo>, +} + +impl SecretAgent { + fn new() -> Res<Self> { + let mut io = Box::pin(AgentIo::new()); + let fd = Self::create_fd(&mut io)?; + Ok(Self { + fd, + secrets: SecretHolder::default(), + raw: None, + io, + state: HandshakeState::New, + + auth_required: Box::pin(false), + alert: Box::pin(None), + now: TimeHolder::default(), + + extension_handlers: Vec::new(), + inf: None, + }) + } + + // Create a new SSL file descriptor. + // + // Note that we create separate bindings for PRFileDesc as both + // ssl::PRFileDesc and prio::PRFileDesc. This keeps the bindings + // minimal, but it means that the two forms need casts to translate + // between them. ssl::PRFileDesc is left as an opaque type, as the + // ssl::SSL_* APIs only need an opaque type. + fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> { + assert_initialized(); + let label = CString::new("sslwrapper")?; + let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) }; + + let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) }; + if base_fd.is_null() { + return Err(Error::CreateSslSocket); + } + let fd = unsafe { + (*base_fd).secret = as_c_void(io) as *mut _; + ssl::SSL_ImportFD(null_mut(), base_fd as *mut ssl::PRFileDesc) + }; + if fd.is_null() { + unsafe { prio::PR_Close(base_fd) }; + return Err(Error::CreateSslSocket); + } + Ok(fd) + } + + unsafe extern "C" fn auth_complete_hook( + arg: *mut c_void, + _fd: *mut ssl::PRFileDesc, + _check_sig: ssl::PRBool, + _is_server: ssl::PRBool, + ) -> ssl::SECStatus { + let auth_required_ptr = arg as *mut bool; + *auth_required_ptr = true; + // NSS insists on getting SECWouldBlock here rather than accepting + // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure. + ssl::_SECStatus_SECWouldBlock + } + + unsafe extern "C" fn alert_sent_cb( + fd: *const ssl::PRFileDesc, + arg: *mut c_void, + alert: *const ssl::SSLAlert, + ) { + let alert = alert.as_ref().unwrap(); + if alert.level == 2 { + // Fatal alerts demand attention. + let p = arg as *mut Option<Alert>; + let st = p.as_mut().unwrap(); + if st.is_none() { + *st = Some(alert.description); + } else { + qwarn!( + [format!("{:p}", fd)], + "duplicate alert {}", + alert.description + ); + } + } + } + + // Ready this for connecting. + fn ready(&mut self, is_server: bool) -> Res<()> { + secstatus_to_res(unsafe { + ssl::SSL_AuthCertificateHook( + self.fd, + Some(Self::auth_complete_hook), + as_c_void(&mut self.auth_required), + ) + })?; + + secstatus_to_res(unsafe { + ssl::SSL_AlertSentCallback( + self.fd, + Some(Self::alert_sent_cb), + as_c_void(&mut self.alert), + ) + })?; + + self.now.bind(self.fd)?; + self.configure()?; + secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, ssl::PRBool::from(is_server)) }) + } + + /// Default configuration. + /// + /// # Errors + /// If `set_version_range` fails. + fn configure(&mut self) -> Res<()> { + self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; + self.set_option(ssl::Opt::Locking, false)?; + self.set_option(ssl::Opt::Tickets, false)?; + self.set_option(ssl::Opt::OcspStapling, true)?; + Ok(()) + } + + /// Set the versions that are supported. + /// + /// # Errors + /// If the range of versions isn't supported. + pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> { + let range = ssl::SSLVersionRange { min, max }; + secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) }) + } + + /// Enable a set of ciphers. Note that the order of these is not respected. + /// + /// # Errors + /// If NSS can't enable or disable ciphers. + pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { + if self.state != HandshakeState::New { + qwarn!([self], "Cannot enable ciphers in state {:?}", self.state); + return Err(Error::InternalError); + } + + let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() }; + let cipher_count = usize::from(unsafe { ssl::SSL_GetNumImplementedCiphers() }); + for i in 0..cipher_count { + let p = all_ciphers.wrapping_add(i); + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), ssl::PRBool::from(false)) + })?; + } + + for c in ciphers { + secstatus_to_res(unsafe { + ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), ssl::PRBool::from(true)) + })?; + } + Ok(()) + } + + /// Set key exchange groups. + /// + /// # Errors + /// If the underlying API fails (which shouldn't happen). + pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> { + // SSLNamedGroup is a different size to Group, so copy one by one. + let group_vec: Vec<_> = groups + .iter() + .map(|&g| ssl::SSLNamedGroup::Type::from(g)) + .collect(); + + let ptr = group_vec.as_slice().as_ptr(); + secstatus_to_res(unsafe { + ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?) + }) + } + + /// Set TLS options. + /// + /// # Errors + /// Returns an error if the option or option value is invalid; i.e., never. + pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> { + opt.set(self.fd, value) + } + + /// Enable 0-RTT. + /// + /// # Errors + /// See `set_option`. + pub fn enable_0rtt(&mut self) -> Res<()> { + self.set_option(ssl::Opt::EarlyData, true) + } + + /// Disable the `EndOfEarlyData` message. + /// + /// # Errors + /// See `set_option`. + pub fn disable_end_of_early_data(&mut self) -> Res<()> { + self.set_option(ssl::Opt::SuppressEndOfEarlyData, true) + } + + /// `set_alpn` sets a list of preferred protocols, starting with the most preferred. + /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded + /// strings. + /// + /// This asserts if no items are provided, or if any individual item is longer than + /// 255 octets in length. + /// + /// # Errors + /// This should always panic rather than return an error. + pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> { + // Validate and set length. + let mut encoded_len = protocols.len(); + for v in protocols { + assert!(v.as_ref().len() < 256); + encoded_len += v.as_ref().len(); + } + + // Prepare to encode. + let mut encoded = Vec::with_capacity(encoded_len); + let mut add = |v: &str| { + if let Ok(s) = u8::try_from(v.len()) { + encoded.push(s); + encoded.extend_from_slice(v.as_bytes()); + } + }; + + // NSS inherited an idiosyncratic API as a result of having implemented NPN + // before ALPN. For that reason, we need to put the "best" option last. + let (first, rest) = protocols + .split_first() + .expect("at least one ALPN value needed"); + for v in rest { + add(v.as_ref()); + } + add(first.as_ref()); + assert_eq!(encoded_len, encoded.len()); + + // Now give the result to NSS. + secstatus_to_res(unsafe { + ssl::SSL_SetNextProtoNego( + self.fd, + encoded.as_slice().as_ptr(), + c_uint::try_from(encoded.len())?, + ) + }) + } + + /// Install an extension handler. + /// + /// This can be called multiple times with different values for `ext`. The handler is provided as + /// Rc<RefCell<>> so that the caller is able to hold a reference to the handler and later access any + /// state that it accumulates. + /// + /// # Errors + /// When the extension handler can't be successfully installed. + pub fn extension_handler( + &mut self, + ext: Extension, + handler: Rc<RefCell<dyn ExtensionHandler>>, + ) -> Res<()> { + let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?; + self.extension_handlers.push(tracker); + Ok(()) + } + + // This function tracks whether handshake() or handshake_raw() was used + // and prevents the other from being used. + fn set_raw(&mut self, r: bool) -> Res<()> { + if self.raw.is_none() { + self.secrets.register(self.fd)?; + self.raw = Some(r); + Ok(()) + } else if self.raw.unwrap() == r { + Ok(()) + } else { + Err(Error::MixedHandshakeMethod) + } + } + + /// Get information about the connection. + /// This includes the version, ciphersuite, and ALPN. + /// + /// Calling this function returns None until the connection is complete. + #[must_use] + pub fn info(&self) -> Option<&SecretAgentInfo> { + match self.state { + HandshakeState::Complete(ref info) => Some(info), + _ => None, + } + } + + /// Get any preliminary information about the status of the connection. + /// + /// This includes whether 0-RTT was accepted and any information related to that. + /// Calling this function collects all the relevant information. + /// + /// # Errors + /// When the underlying socket functions fail. + pub fn preinfo(&self) -> Res<SecretAgentPreInfo> { + SecretAgentPreInfo::new(self.fd) + } + + /// Get the peer's certificate chain. + #[must_use] + pub fn peer_certificate(&self) -> Option<CertificateInfo> { + CertificateInfo::new(self.fd) + } + + /// Return any fatal alert that the TLS stack might have sent. + #[must_use] + pub fn alert(&self) -> Option<&Alert> { + (&*self.alert).as_ref() + } + + /// Call this function to mark the peer as authenticated. + /// Only call this function if `handshake/handshake_raw` returns + /// `HandshakeState::AuthenticationPending`, or it will panic. + pub fn authenticated(&mut self, status: AuthenticationStatus) { + assert_eq!(self.state, HandshakeState::AuthenticationPending); + *self.auth_required = false; + self.state = HandshakeState::Authenticated(status.into()); + } + + fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> { + if let Err(e) = &res { + qwarn!([self], "error: {:?}", e); + self.state = HandshakeState::Failed(e.clone()); + } + res + } + + fn update_state(&mut self, res: Res<()>) -> Res<()> { + self.state = if is_blocked(&res) { + if *self.auth_required { + HandshakeState::AuthenticationPending + } else { + HandshakeState::InProgress + } + } else { + self.capture_error(res)?; + let info = self.capture_error(SecretAgentInfo::new(self.fd))?; + HandshakeState::Complete(info) + }; + qinfo!([self], "state -> {:?}", self.state); + Ok(()) + } + + /// Drive the TLS handshake, taking bytes from `input` and putting + /// any bytes necessary into `output`. + /// This takes the current time as `now`. + /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake + /// is complete and how many bytes were written to `output`, respectively. + /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this + /// function if you want to proceed, because this will mark the certificate as OK. + /// + /// # Errors + /// When the handshake fails this returns an error. + pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> { + self.now.set(now)?; + self.set_raw(false)?; + + let rv = { + // Within this scope, _h maintains a mutable reference to self.io. + let _h = self.io.wrap(input); + match self.state { + HandshakeState::Authenticated(ref err) => unsafe { + ssl::SSL_AuthCertificateComplete(self.fd, *err) + }, + _ => unsafe { ssl::SSL_ForceHandshake(self.fd) }, + } + }; + // Take before updating state so that we leave the output buffer empty + // even if there is an error. + let output = self.io.take_output(); + self.update_state(secstatus_to_res(rv))?; + Ok(output) + } + + /// Setup to receive records for raw handshake functions. + fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> { + self.set_raw(true)?; + self.capture_error(RecordList::setup(self.fd)) + } + + /// Drive the TLS handshake, but get the raw content of records, not + /// protected records as bytes. This function is incompatible with + /// `handshake()`; use either this or `handshake()` exclusively. + /// + /// Ideally, this only includes records from the current epoch. + /// If you send data from multiple epochs, you might end up being sad. + /// + /// # Errors + /// When the handshake fails this returns an error. + pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> { + self.now.set(now)?; + let records = self.setup_raw()?; + + // Fire off any authentication we might need to complete. + if let HandshakeState::Authenticated(ref err) = self.state { + let result = + secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) }); + qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result); + // This should return SECSuccess, so don't use update_state(). + self.capture_error(result)?; + } + + // Feed in any records. + if let Some(rec) = input { + self.capture_error(rec.write(self.fd))?; + } + + // Drive the handshake once more. + let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) }); + self.update_state(rv)?; + + Ok(*Pin::into_inner(records)) + } + + pub fn close(&mut self) { + // It should be safe to close multiple times. + if self.fd.is_null() { + return; + } + if let Some(true) = self.raw { + // Need to hold the record list in scope until the close is done. + let _records = self.setup_raw().expect("Can only close"); + unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) }; + } else { + // Need to hold the IO wrapper in scope until the close is done. + let _io = self.io.wrap(&[]); + unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) }; + }; + let _output = self.io.take_output(); + self.fd = null_mut(); + } + + /// State returns the status of the handshake. + #[must_use] + pub fn state(&self) -> &HandshakeState { + &self.state + } + /// Take a read secret. This will only return a non-`None` value once. + #[must_use] + pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_read(epoch) + } + /// Take a write secret. + #[must_use] + pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> { + self.secrets.take_write(epoch) + } +} + +impl Drop for SecretAgent { + fn drop(&mut self) { + self.close(); + } +} + +impl ::std::fmt::Display for SecretAgent { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Agent {:p}", self.fd) + } +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone)] +pub struct ResumptionToken { + token: Vec<u8>, + expiration_time: Instant, +} + +impl AsRef<[u8]> for ResumptionToken { + fn as_ref(&self) -> &[u8] { + &self.token + } +} + +impl ResumptionToken { + #[must_use] + pub fn new(token: Vec<u8>, expiration_time: Instant) -> Self { + Self { + token, + expiration_time, + } + } + + #[must_use] + pub fn expiration_time(&self) -> Instant { + self.expiration_time + } +} + +/// A TLS Client. +#[derive(Debug)] +#[allow(clippy::box_vec)] // We need the Box. +pub struct Client { + agent: SecretAgent, + + /// Records the resumption tokens we've received. + resumption: Pin<Box<Vec<ResumptionToken>>>, +} + +impl Client { + /// Create a new client agent. + /// + /// # Errors + /// Errors returned if the socket can't be created or configured. + pub fn new(server_name: &str) -> Res<Self> { + let mut agent = SecretAgent::new()?; + let url = CString::new(server_name)?; + secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?; + agent.ready(false)?; + let mut client = Self { + agent, + resumption: Box::pin(Vec::new()), + }; + client.ready()?; + Ok(client) + } + + unsafe extern "C" fn resumption_token_cb( + fd: *mut ssl::PRFileDesc, + token: *const u8, + len: c_uint, + arg: *mut c_void, + ) -> ssl::SECStatus { + let mut info: MaybeUninit<ssl::SSLResumptionTokenInfo> = MaybeUninit::uninit(); + if ssl::SSL_GetResumptionTokenInfo( + token, + len, + info.as_mut_ptr(), + c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()).unwrap(), + ) + .is_err() + { + // Ignore the token. + return ssl::SECSuccess; + } + let expiration_time = info.assume_init().expirationTime; + if ssl::SSL_DestroyResumptionTokenInfo(info.as_mut_ptr()).is_err() { + // Ignore the token. + return ssl::SECSuccess; + } + let resumption_ptr = arg as *mut Vec<ResumptionToken>; + let resumption = resumption_ptr.as_mut().unwrap(); + let len = usize::try_from(len).unwrap(); + let mut v = Vec::with_capacity(len); + v.extend_from_slice(std::slice::from_raw_parts(token, len)); + qinfo!( + [format!("{:p}", fd)], + "Got resumption token {}", + hex_snip_middle(&v) + ); + + if resumption.len() >= MAX_TICKETS { + resumption.remove(0); + } + if let Ok(t) = Time::try_from(expiration_time) { + resumption.push(ResumptionToken::new(v, *t)); + } + ssl::SECSuccess + } + + fn ready(&mut self) -> Res<()> { + let fd = self.fd; + unsafe { + ssl::SSL_SetResumptionTokenCallback( + fd, + Some(Self::resumption_token_cb), + as_c_void(&mut self.resumption), + ) + } + } + + /// Take a resumption token. + #[must_use] + pub fn resumption_token(&mut self) -> Option<ResumptionToken> { + (*self.resumption).pop() + } + + /// Check if there are more resumption tokens. + #[must_use] + pub fn has_resumption_token(&self) -> bool { + !(*self.resumption).is_empty() + } + + /// Enable resumption, using a token previously provided. + /// + /// # Errors + /// Error returned when the resumption token is invalid or + /// the socket is not able to use the value. + pub fn enable_resumption(&mut self, token: impl AsRef<[u8]>) -> Res<()> { + unsafe { + ssl::SSL_SetResumptionToken( + self.agent.fd, + token.as_ref().as_ptr(), + c_uint::try_from(token.as_ref().len())?, + ) + } + } +} + +impl Deref for Client { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Client { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`. +#[derive(Clone, Debug, PartialEq)] +pub enum ZeroRttCheckResult { + /// Accept 0-RTT. + Accept, + /// Reject 0-RTT, but continue the handshake normally. + Reject, + /// Send HelloRetryRequest (probably not needed for QUIC). + HelloRetryRequest(Vec<u8>), + /// Fail the handshake. + Fail, +} + +/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by `send_ticket`) +pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin { + fn check(&self, token: &[u8]) -> ZeroRttCheckResult; +} + +/// Using `AllowZeroRtt` for the implementation of `ZeroRttChecker` means +/// accepting 0-RTT always. This generally isn't a great idea, so this +/// generates a strong warning when it is used. +#[derive(Debug)] +pub struct AllowZeroRtt {} +impl ZeroRttChecker for AllowZeroRtt { + fn check(&self, _token: &[u8]) -> ZeroRttCheckResult { + qwarn!("AllowZeroRtt accepting 0-RTT"); + ZeroRttCheckResult::Accept + } +} + +#[derive(Debug)] +struct ZeroRttCheckState { + fd: *mut ssl::PRFileDesc, + checker: Pin<Box<dyn ZeroRttChecker>>, +} + +impl ZeroRttCheckState { + pub fn new(fd: *mut ssl::PRFileDesc, checker: Box<dyn ZeroRttChecker>) -> Self { + Self { + fd, + checker: Pin::new(checker), + } + } +} + +#[derive(Debug)] +pub struct Server { + agent: SecretAgent, + /// This holds the HRR callback context. + zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>, +} + +impl Server { + /// Create a new server agent. + /// + /// # Errors + /// Errors returned when NSS fails. + pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> { + let mut agent = SecretAgent::new()?; + + for n in certificates { + let c = CString::new(n.as_ref())?; + let cert = match NonNull::new(unsafe { + p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut()) + }) { + None => return Err(Error::CertificateLoading), + Some(ptr) => p11::Certificate::new(ptr), + }; + let key = match NonNull::new(unsafe { + p11::PK11_FindKeyByAnyCert(*cert.deref(), null_mut()) + }) { + None => return Err(Error::CertificateLoading), + Some(ptr) => p11::PrivateKey::new(ptr), + }; + secstatus_to_res(unsafe { + ssl::SSL_ConfigServerCert(agent.fd, *cert.deref(), *key.deref(), null(), 0) + })?; + } + + agent.ready(true)?; + Ok(Self { + agent, + zero_rtt_check: None, + }) + } + + unsafe extern "C" fn hello_retry_cb( + first_hello: PRBool, + client_token: *const u8, + client_token_len: c_uint, + retry_token: *mut u8, + retry_token_len: *mut c_uint, + retry_token_max: c_uint, + arg: *mut c_void, + ) -> ssl::SSLHelloRetryRequestAction::Type { + if first_hello == 0 { + // On the second ClientHello after HelloRetryRequest, skip checks. + return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept; + } + + let p = arg as *mut ZeroRttCheckState; + let check_state = p.as_mut().unwrap(); + let token = if client_token.is_null() { + &[] + } else { + std::slice::from_raw_parts(client_token, usize::try_from(client_token_len).unwrap()) + }; + match check_state.checker.check(token) { + ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept, + ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail, + ZeroRttCheckResult::Reject => { + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt + } + ZeroRttCheckResult::HelloRetryRequest(tok) => { + // Don't bother propagating errors from this, because it should be caught in testing. + assert!(tok.len() <= usize::try_from(retry_token_max).unwrap()); + let slc = std::slice::from_raw_parts_mut(retry_token, tok.len()); + slc.copy_from_slice(&tok); + *retry_token_len = c_uint::try_from(tok.len()).unwrap(); + ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request + } + } + } + + /// Enable 0-RTT. This shadows the function of the same name that can be accessed + /// via the Deref implementation on Server. + /// + /// # Errors + /// Returns an error if the underlying NSS functions fail. + pub fn enable_0rtt( + &mut self, + anti_replay: &AntiReplay, + max_early_data: u32, + checker: Box<dyn ZeroRttChecker>, + ) -> Res<()> { + let mut check_state = Box::pin(ZeroRttCheckState::new(self.agent.fd, checker)); + unsafe { + ssl::SSL_HelloRetryRequestCallback( + self.agent.fd, + Some(Self::hello_retry_cb), + as_c_void(&mut check_state), + ) + }?; + unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?; + self.zero_rtt_check = Some(check_state); + self.agent.enable_0rtt()?; + anti_replay.config_socket(self.fd)?; + Ok(()) + } + + /// Send a session ticket to the client. + /// This adds |extra| application-specific content into that ticket. + /// The records that are sent are captured and returned. + /// + /// # Errors + /// If NSS is unable to send a ticket, or if this agent is incorrectly configured. + pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> { + self.agent.now.set(now)?; + let records = self.setup_raw()?; + + unsafe { + ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?) + }?; + + Ok(*Pin::into_inner(records)) + } +} + +impl Deref for Server { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + &self.agent + } +} + +impl DerefMut for Server { + fn deref_mut(&mut self) -> &mut SecretAgent { + &mut self.agent + } +} + +/// A generic container for Client or Server. +#[derive(Debug)] +pub enum Agent { + Client(crate::agent::Client), + Server(crate::agent::Server), +} + +impl Deref for Agent { + type Target = SecretAgent; + #[must_use] + fn deref(&self) -> &SecretAgent { + match self { + Self::Client(c) => &*c, + Self::Server(s) => &*s, + } + } +} + +impl DerefMut for Agent { + fn deref_mut(&mut self) -> &mut SecretAgent { + match self { + Self::Client(c) => &mut *c, + Self::Server(s) => &mut *s, + } + } +} + +impl From<Client> for Agent { + #[must_use] + fn from(c: Client) -> Self { + Self::Client(c) + } +} + +impl From<Server> for Agent { + #[must_use] + fn from(s: Server) -> Self { + Self::Server(s) + } +} diff --git a/third_party/rust/neqo-crypto/src/agentio.rs b/third_party/rust/neqo-crypto/src/agentio.rs new file mode 100644 index 0000000000..bc3cf7e948 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/agentio.rs @@ -0,0 +1,401 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::constants::{ContentType, Epoch}; +use crate::err::{nspr, Error, PR_SetError, Res}; +use crate::prio; +use crate::ssl; + +use neqo_common::{hex, hex_with_len, qtrace}; +use std::cmp::min; +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::mem; +use std::ops::Deref; +use std::os::raw::{c_uint, c_void}; +use std::pin::Pin; +use std::ptr::{null, null_mut}; +use std::vec::Vec; + +// Alias common types. +type PrFd = *mut prio::PRFileDesc; +type PrStatus = prio::PRStatus::Type; +const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS; +const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE; + +/// Convert a pinned, boxed object into a void pointer. +pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void { + Pin::into_inner(pin.as_mut()) as *mut T as *mut c_void +} + +// This holds the length of the slice, not the slice itself. +#[derive(Default, Debug)] +struct RecordLength { + epoch: Epoch, + ct: ContentType, + len: usize, +} + +/// A slice of the output. +#[derive(Default)] +pub struct Record { + pub epoch: Epoch, + pub ct: ContentType, + pub data: Vec<u8>, +} + +impl Record { + #[must_use] + pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self { + Self { + epoch, + ct, + data: data.to_vec(), + } + } + + // Shoves this record into the socket, returns true if blocked. + pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> { + qtrace!("write {:?}", self); + unsafe { + ssl::SSL_RecordLayerData( + fd, + self.epoch, + ssl::SSLContentType::Type::from(self.ct), + self.data.as_ptr(), + c_uint::try_from(self.data.len())?, + ) + } + } +} + +impl fmt::Debug for Record { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Record {:?}:{:?} {}", + self.epoch, + self.ct, + hex_with_len(&self.data[..]) + ) + } +} + +#[derive(Debug, Default)] +pub struct RecordList { + records: Vec<Record>, +} + +impl RecordList { + fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) { + self.records.push(Record::new(epoch, ct, data)); + } + + #[allow(clippy::unused_self)] + unsafe extern "C" fn ingest( + _fd: *mut ssl::PRFileDesc, + epoch: ssl::PRUint16, + ct: ssl::SSLContentType::Type, + data: *const ssl::PRUint8, + len: c_uint, + arg: *mut c_void, + ) -> ssl::SECStatus { + let a = arg as *mut Self; + let records = a.as_mut().unwrap(); + + let slice = std::slice::from_raw_parts(data, len as usize); + records.append(epoch, ContentType::try_from(ct).unwrap(), slice); + ssl::SECSuccess + } + + /// Create a new record list. + pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> { + let mut records = Box::pin(Self::default()); + unsafe { + ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records)) + }?; + Ok(records) + } +} + +impl Deref for RecordList { + type Target = Vec<Record>; + #[must_use] + fn deref(&self) -> &Vec<Record> { + &self.records + } +} + +pub struct RecordListIter(std::vec::IntoIter<Record>); + +impl Iterator for RecordListIter { + type Item = Record; + fn next(&mut self) -> Option<Self::Item> { + self.0.next() + } +} + +impl IntoIterator for RecordList { + type Item = Record; + type IntoIter = RecordListIter; + #[must_use] + fn into_iter(self) -> Self::IntoIter { + RecordListIter(self.records.into_iter()) + } +} + +pub struct AgentIoInputContext<'a> { + input: &'a mut AgentIoInput, +} + +impl<'a> Drop for AgentIoInputContext<'a> { + fn drop(&mut self) { + self.input.reset(); + } +} + +#[derive(Debug)] +struct AgentIoInput { + // input is data that is read by TLS. + input: *const u8, + // input_available is how much data is left for reading. + available: usize, +} + +impl AgentIoInput { + fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { + assert!(self.input.is_null()); + self.input = input.as_ptr(); + self.available = input.len(); + qtrace!("AgentIoInput wrap {:p}", self.input); + AgentIoInputContext { input: self } + } + + // Take the data provided as input and provide it to the TLS stack. + fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> { + let amount = min(self.available, count); + if amount == 0 { + unsafe { PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0) }; + return Err(Error::NoDataAvailable); + } + + let src = unsafe { std::slice::from_raw_parts(self.input, amount) }; + qtrace!([self], "read {}", hex(src)); + let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) }; + dst.copy_from_slice(&src); + self.input = self.input.wrapping_add(amount); + self.available -= amount; + Ok(amount) + } + + fn reset(&mut self) { + qtrace!([self], "reset"); + self.input = null(); + self.available = 0; + } +} + +impl ::std::fmt::Display for AgentIoInput { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "AgentIoInput {:p}", self.input) + } +} + +#[derive(Debug)] +pub struct AgentIo { + // input collects the input we might provide to TLS. + input: AgentIoInput, + + // output contains data that is written by TLS. + output: Vec<u8>, +} + +impl AgentIo { + pub fn new() -> Self { + Self { + input: AgentIoInput { + input: null(), + available: 0, + }, + output: Vec::new(), + } + } + + unsafe fn borrow(fd: &mut PrFd) -> &mut Self { + #[allow(clippy::cast_ptr_alignment)] + let io = (**fd).secret as *mut Self; + io.as_mut().unwrap() + } + + pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { + assert_eq!(self.output.len(), 0); + self.input.wrap(input) + } + + // Stage output from TLS into the output buffer. + fn save_output(&mut self, buf: *const u8, count: usize) { + let slice = unsafe { std::slice::from_raw_parts(buf, count) }; + qtrace!([self], "save output {}", hex(slice)); + self.output.extend_from_slice(slice); + } + + pub fn take_output(&mut self) -> Vec<u8> { + qtrace!([self], "take output"); + mem::replace(&mut self.output, Vec::new()) + } +} + +impl ::std::fmt::Display for AgentIo { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "AgentIo") + } +} + +unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus { + (*fd).secret = null_mut(); + if let Some(dtor) = (*fd).dtor { + dtor(fd); + } + PR_SUCCESS +} + +unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus { + let io = AgentIo::borrow(&mut fd); + if let Ok(a) = usize::try_from(amount) { + match io.input.read_input(buf as *mut u8, a) { + Ok(_) => PR_SUCCESS, + Err(_) => PR_FAILURE, + } + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_recv( + mut fd: PrFd, + buf: *mut c_void, + amount: prio::PRInt32, + flags: prio::PRIntn, + _timeout: prio::PRIntervalTime, +) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + if flags != 0 { + return PR_FAILURE; + } + if let Ok(a) = usize::try_from(amount) { + match io.input.read_input(buf as *mut u8, a) { + Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE), + Err(_) => PR_FAILURE, + } + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_write( + mut fd: PrFd, + buf: *const c_void, + amount: prio::PRInt32, +) -> PrStatus { + let io = AgentIo::borrow(&mut fd); + if let Ok(a) = usize::try_from(amount) { + io.save_output(buf as *const u8, a); + amount + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_send( + mut fd: PrFd, + buf: *const c_void, + amount: prio::PRInt32, + flags: prio::PRIntn, + _timeout: prio::PRIntervalTime, +) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + + if flags != 0 { + return PR_FAILURE; + } + if let Ok(a) = usize::try_from(amount) { + io.save_output(buf as *const u8, a); + amount + } else { + PR_FAILURE + } +} + +unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 { + let io = AgentIo::borrow(&mut fd); + io.input.available.try_into().unwrap_or(PR_FAILURE) +} + +unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 { + let io = AgentIo::borrow(&mut fd); + io.input + .available + .try_into() + .unwrap_or_else(|_| PR_FAILURE.into()) +} + +#[allow(clippy::cast_possible_truncation)] +unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus { + let a = addr.as_mut().unwrap(); + // Cast is safe because prio::PR_AF_INET is 2 + a.inet.family = prio::PR_AF_INET as prio::PRUint16; + a.inet.port = 0; + a.inet.ip = 0; + PR_SUCCESS +} + +unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus { + let o = opt.as_mut().unwrap(); + if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking { + o.value.non_blocking = 1; + return PR_SUCCESS; + } + PR_FAILURE +} + +pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods { + file_type: prio::PRDescType::PR_DESC_LAYERED, + close: Some(agent_close), + read: Some(agent_read), + write: Some(agent_write), + available: Some(agent_available), + available64: Some(agent_available64), + fsync: None, + seek: None, + seek64: None, + fileInfo: None, + fileInfo64: None, + writev: None, + connect: None, + accept: None, + bind: None, + listen: None, + shutdown: None, + recv: Some(agent_recv), + send: Some(agent_send), + recvfrom: None, + sendto: None, + poll: None, + acceptread: None, + transmitfile: None, + getsockname: Some(agent_getname), + getpeername: Some(agent_getname), + reserved_fn_6: None, + reserved_fn_5: None, + getsocketoption: Some(agent_getsockopt), + setsocketoption: None, + sendfile: None, + connectcontinue: None, + reserved_fn_3: None, + reserved_fn_2: None, + reserved_fn_1: None, + reserved_fn_0: None, +}; diff --git a/third_party/rust/neqo-crypto/src/auth.rs b/third_party/rust/neqo-crypto/src/auth.rs new file mode 100644 index 0000000000..a1c7de0aa0 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/auth.rs @@ -0,0 +1,100 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::err::{mozpkix, sec, ssl, PRErrorCode}; + +/// The outcome of authentication. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum AuthenticationStatus { + Ok, + CaInvalid, + CaNotV3, + CertAlgorithmDisabled, + CertExpired, + CertInvalidTime, + CertIsCa, + CertKeyUsage, + CertMitm, + CertNotYetValid, + CertRevoked, + CertSelfSigned, + CertSubjectInvalid, + CertUntrusted, + CertWeakKey, + IssuerEmptyName, + IssuerExpired, + IssuerNotYetValid, + IssuerUnknown, + IssuerUntrusted, + PolicyRejection, + Unknown, +} + +impl Into<PRErrorCode> for AuthenticationStatus { + #[must_use] + fn into(self) -> PRErrorCode { + match self { + Self::Ok => 0, + Self::CaInvalid => sec::SEC_ERROR_CA_CERT_INVALID, + Self::CaNotV3 => mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA, + Self::CertAlgorithmDisabled => sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED, + Self::CertExpired => sec::SEC_ERROR_EXPIRED_CERTIFICATE, + Self::CertInvalidTime => sec::SEC_ERROR_INVALID_TIME, + Self::CertIsCa => mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY, + Self::CertKeyUsage => sec::SEC_ERROR_INADEQUATE_KEY_USAGE, + Self::CertMitm => mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED, + Self::CertNotYetValid => mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE, + Self::CertRevoked => sec::SEC_ERROR_REVOKED_CERTIFICATE, + Self::CertSelfSigned => mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT, + Self::CertSubjectInvalid => ssl::SSL_ERROR_BAD_CERT_DOMAIN, + Self::CertUntrusted => sec::SEC_ERROR_UNTRUSTED_CERT, + Self::CertWeakKey => mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE, + Self::IssuerEmptyName => mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME, + Self::IssuerExpired => sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE, + Self::IssuerNotYetValid => mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE, + Self::IssuerUnknown => sec::SEC_ERROR_UNKNOWN_ISSUER, + Self::IssuerUntrusted => sec::SEC_ERROR_UNTRUSTED_ISSUER, + Self::PolicyRejection => { + mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED + } + Self::Unknown => sec::SEC_ERROR_LIBRARY_FAILURE, + } + } +} + +// Note that this mapping should be removed after gecko eventually learns how to +// map into the enumerated type. +impl From<PRErrorCode> for AuthenticationStatus { + #[must_use] + fn from(v: PRErrorCode) -> Self { + match v { + 0 => Self::Ok, + sec::SEC_ERROR_CA_CERT_INVALID => Self::CaInvalid, + mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA => Self::CaNotV3, + sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED => Self::CertAlgorithmDisabled, + sec::SEC_ERROR_EXPIRED_CERTIFICATE => Self::CertExpired, + sec::SEC_ERROR_INVALID_TIME => Self::CertInvalidTime, + mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY => Self::CertIsCa, + sec::SEC_ERROR_INADEQUATE_KEY_USAGE => Self::CertKeyUsage, + mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED => Self::CertMitm, + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE => Self::CertNotYetValid, + sec::SEC_ERROR_REVOKED_CERTIFICATE => Self::CertRevoked, + mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT => Self::CertSelfSigned, + ssl::SSL_ERROR_BAD_CERT_DOMAIN => Self::CertSubjectInvalid, + sec::SEC_ERROR_UNTRUSTED_CERT => Self::CertUntrusted, + mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE => Self::CertWeakKey, + mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME => Self::IssuerEmptyName, + sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE => Self::IssuerExpired, + mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE => Self::IssuerNotYetValid, + sec::SEC_ERROR_UNKNOWN_ISSUER => Self::IssuerUnknown, + sec::SEC_ERROR_UNTRUSTED_ISSUER => Self::IssuerUntrusted, + mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED => { + Self::PolicyRejection + } + _ => Self::Unknown, + } + } +} diff --git a/third_party/rust/neqo-crypto/src/cert.rs b/third_party/rust/neqo-crypto/src/cert.rs new file mode 100644 index 0000000000..05bd5ca20e --- /dev/null +++ b/third_party/rust/neqo-crypto/src/cert.rs @@ -0,0 +1,126 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::err::secstatus_to_res; +use crate::p11::{ + CERTCertList, CERTCertListNode, CERT_GetCertificateDer, CertList, PRCList, SECItem, + SECItemArray, SECItemType, +}; +use crate::ssl::{ + PRFileDesc, SSL_PeerCertificateChain, SSL_PeerSignedCertTimestamps, + SSL_PeerStapledOCSPResponses, +}; +use neqo_common::qerror; + +use std::convert::TryFrom; +use std::ptr::{null_mut, NonNull}; + +use std::slice; + +pub struct CertificateInfo { + certs: CertList, + cursor: *const CERTCertListNode, + /// stapled_ocsp_responses and signed_cert_timestamp are properties + /// associated with each of the certificates. Right now, NSS only + /// reports the value for the end-entity certificate (the first). + stapled_ocsp_responses: Option<Vec<Vec<u8>>>, + signed_cert_timestamp: Option<Vec<u8>>, +} + +fn peer_certificate_chain(fd: *mut PRFileDesc) -> Option<(CertList, *const CERTCertListNode)> { + let chain = unsafe { SSL_PeerCertificateChain(fd) }; + let certs = match NonNull::new(chain as *mut CERTCertList) { + Some(certs_ptr) => CertList::new(certs_ptr), + None => return None, + }; + let cursor = CertificateInfo::head(&certs); + Some((certs, cursor)) +} + +// As explained in rfc6961, an OCSPResponseList can have at most +// 2^24 items. Casting its length is therefore safe even on 32 bits targets. +fn stapled_ocsp_responses(fd: *mut PRFileDesc) -> Option<Vec<Vec<u8>>> { + let ocsp_nss = unsafe { SSL_PeerStapledOCSPResponses(fd) }; + match NonNull::new(ocsp_nss as *mut SECItemArray) { + Some(ocsp_ptr) => { + let mut ocsp_helper: Vec<Vec<u8>> = Vec::new(); + let len = if let Ok(l) = isize::try_from(unsafe { ocsp_ptr.as_ref().len }) { + l + } else { + qerror!([format!("{:p}", fd)], "Received illegal OSCP length"); + return None; + }; + for idx in 0..len { + let itemp = unsafe { ocsp_ptr.as_ref().items.offset(idx) as *const SECItem }; + let item = unsafe { slice::from_raw_parts((*itemp).data, (*itemp).len as usize) }; + ocsp_helper.push(item.to_owned()); + } + Some(ocsp_helper) + } + None => None, + } +} + +fn signed_cert_timestamp(fd: *mut PRFileDesc) -> Option<Vec<u8>> { + let sct_nss = unsafe { SSL_PeerSignedCertTimestamps(fd) }; + match NonNull::new(sct_nss as *mut SECItem) { + Some(sct_ptr) => { + let sct_slice = unsafe { + slice::from_raw_parts(sct_ptr.as_ref().data, sct_ptr.as_ref().len as usize) + }; + Some(sct_slice.to_owned()) + } + None => None, + } +} + +impl CertificateInfo { + pub(crate) fn new(fd: *mut PRFileDesc) -> Option<Self> { + match peer_certificate_chain(fd) { + Some((certs, cursor)) => Some(Self { + certs, + cursor, + stapled_ocsp_responses: stapled_ocsp_responses(fd), + signed_cert_timestamp: signed_cert_timestamp(fd), + }), + None => None, + } + } + + fn head(certs: &CertList) -> *const CERTCertListNode { + // Three stars: one for the reference, one for the wrapper, one to deference the pointer. + unsafe { &(***certs).list as *const PRCList as *const CERTCertListNode } + } +} + +impl<'a> Iterator for &'a mut CertificateInfo { + type Item = &'a [u8]; + fn next(&mut self) -> Option<&'a [u8]> { + self.cursor = unsafe { *self.cursor }.links.next as *const CERTCertListNode; + if self.cursor == CertificateInfo::head(&self.certs) { + return None; + } + let mut item = SECItem { + type_: SECItemType::siBuffer, + data: null_mut(), + len: 0, + }; + let cert = unsafe { *self.cursor }.cert; + secstatus_to_res(unsafe { CERT_GetCertificateDer(cert, &mut item) }) + .expect("getting DER from certificate should work"); + Some(unsafe { std::slice::from_raw_parts(item.data, item.len as usize) }) + } +} + +impl CertificateInfo { + pub fn stapled_ocsp_responses(&mut self) -> &Option<Vec<Vec<u8>>> { + &self.stapled_ocsp_responses + } + + pub fn signed_cert_timestamp(&mut self) -> &Option<Vec<u8>> { + &self.signed_cert_timestamp + } +} diff --git a/third_party/rust/neqo-crypto/src/constants.rs b/third_party/rust/neqo-crypto/src/constants.rs new file mode 100644 index 0000000000..48467cb1d1 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/constants.rs @@ -0,0 +1,145 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] + +use crate::ssl; + +// Ideally all of these would be enums, but size matters and we need to allow +// for values outside of those that are defined here. + +pub type Alert = u8; + +pub type Epoch = u16; +// TLS doesn't really have an "initial" concept that maps to QUIC so directly, +// but this should be clear enough. +pub const TLS_EPOCH_INITIAL: Epoch = 0 as Epoch; +pub const TLS_EPOCH_ZERO_RTT: Epoch = 1 as Epoch; +pub const TLS_EPOCH_HANDSHAKE: Epoch = 2 as Epoch; +// Also, we don't use TLS epochs > 3. +pub const TLS_EPOCH_APPLICATION_DATA: Epoch = 3 as Epoch; + +/// Rather than defining a type alias and a bunch of constants, which leads to a ton of repetition, +/// use this macro. +macro_rules! remap_enum { + { $t:ident: $s:ty { $( $n:ident = $v:path ),+ $(,)? } } => { + pub type $t = $s; + $( pub const $n: $t = $v as $t; )+ + }; + { $t:ident: $s:ty => $e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => { + remap_enum!{ $t: $s { $( $n = $e::$v ),+ } } + }; + { $t:ident: $s:ty => $p:ident::$e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => { + remap_enum!{ $t: $s { $( $n = $p::$e::$v ),+ } } + }; +} + +remap_enum! { + Version: u16 => ssl { + TLS_VERSION_1_2 = SSL_LIBRARY_VERSION_TLS_1_2, + TLS_VERSION_1_3 = SSL_LIBRARY_VERSION_TLS_1_3, + } +} + +mod ciphers { + include!(concat!(env!("OUT_DIR"), "/nss_ciphers.rs")); +} + +remap_enum! { + Cipher: u16 => ciphers { + TLS_AES_128_GCM_SHA256 = TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384 = TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256 = TLS_CHACHA20_POLY1305_SHA256, + } +} + +remap_enum! { + Group: u16 => ssl::SSLNamedGroup { + TLS_GRP_EC_SECP256R1 = ssl_grp_ec_secp256r1, + TLS_GRP_EC_SECP384R1 = ssl_grp_ec_secp384r1, + TLS_GRP_EC_SECP521R1 = ssl_grp_ec_secp521r1, + TLS_GRP_EC_X25519 = ssl_grp_ec_curve25519, + } +} + +remap_enum! { + HandshakeMessage: u8 => ssl::SSLHandshakeType { + TLS_HS_HELLO_REQUEST = ssl_hs_hello_request, + TLS_HS_CLIENT_HELLO = ssl_hs_client_hello, + TLS_HS_SERVER_HELLO = ssl_hs_server_hello, + TLS_HS_HELLO_VERIFY_REQUEST = ssl_hs_hello_verify_request, + TLS_HS_NEW_SESSION_TICKET = ssl_hs_new_session_ticket, + TLS_HS_END_OF_EARLY_DATA = ssl_hs_end_of_early_data, + TLS_HS_HELLO_RETRY_REQUEST = ssl_hs_hello_retry_request, + TLS_HS_ENCRYPTED_EXTENSIONS = ssl_hs_encrypted_extensions, + TLS_HS_CERTIFICATE = ssl_hs_certificate, + TLS_HS_SERVER_KEY_EXCHANGE = ssl_hs_server_key_exchange, + TLS_HS_CERTIFICATE_REQUEST = ssl_hs_certificate_request, + TLS_HS_SERVER_HELLO_DONE = ssl_hs_server_hello_done, + TLS_HS_CERTIFICATE_VERIFY = ssl_hs_certificate_verify, + TLS_HS_CLIENT_KEY_EXCHANGE = ssl_hs_client_key_exchange, + TLS_HS_FINISHED = ssl_hs_finished, + TLS_HS_CERT_STATUS = ssl_hs_certificate_status, + TLS_HS_KEY_UDPATE = ssl_hs_key_update, + } +} + +remap_enum! { + ContentType: u8 => ssl::SSLContentType { + TLS_CT_CHANGE_CIPHER_SPEC = ssl_ct_change_cipher_spec, + TLS_CT_ALERT = ssl_ct_alert, + TLS_CT_HANDSHAKE = ssl_ct_handshake, + TLS_CT_APPLICATION_DATA = ssl_ct_application_data, + TLS_CT_ACK = ssl_ct_ack, + } +} + +remap_enum! { + Extension: u16 => ssl::SSLExtensionType { + TLS_EXT_SERVER_NAME = ssl_server_name_xtn, + TLS_EXT_CERT_STATUS = ssl_cert_status_xtn, + TLS_EXT_GROUPS = ssl_supported_groups_xtn, + TLS_EXT_EC_POINT_FORMATS = ssl_ec_point_formats_xtn, + TLS_EXT_SIG_SCHEMES = ssl_signature_algorithms_xtn, + TLS_EXT_USE_SRTP = ssl_use_srtp_xtn, + TLS_EXT_ALPN = ssl_app_layer_protocol_xtn, + TLS_EXT_SCT = ssl_signed_cert_timestamp_xtn, + TLS_EXT_PADDING = ssl_padding_xtn, + TLS_EXT_EMS = ssl_extended_master_secret_xtn, + TLS_EXT_RECORD_SIZE = ssl_record_size_limit_xtn, + TLS_EXT_SESSION_TICKET = ssl_session_ticket_xtn, + TLS_EXT_PSK = ssl_tls13_pre_shared_key_xtn, + TLS_EXT_EARLY_DATA = ssl_tls13_early_data_xtn, + TLS_EXT_VERSIONS = ssl_tls13_supported_versions_xtn, + TLS_EXT_COOKIE = ssl_tls13_cookie_xtn, + TLS_EXT_PSK_MODES = ssl_tls13_psk_key_exchange_modes_xtn, + TLS_EXT_CA = ssl_tls13_certificate_authorities_xtn, + TLS_EXT_POST_HS_AUTH = ssl_tls13_post_handshake_auth_xtn, + TLS_EXT_CERT_SIG_SCHEMES = ssl_signature_algorithms_cert_xtn, + TLS_EXT_KEY_SHARE = ssl_tls13_key_share_xtn, + TLS_EXT_RENEGOTIATION_INFO = ssl_renegotiation_info_xtn, + } +} + +remap_enum! { + SignatureScheme: u16 => ssl::SSLSignatureScheme { + TLS_SIG_NONE = ssl_sig_none, + TLS_SIG_RSA_PKCS1_SHA256 = ssl_sig_rsa_pkcs1_sha256, + TLS_SIG_RSA_PKCS1_SHA384 = ssl_sig_rsa_pkcs1_sha384, + TLS_SIG_RSA_PKCS1_SHA512 = ssl_sig_rsa_pkcs1_sha512, + TLS_SIG_ECDSA_SECP256R1_SHA256 = ssl_sig_ecdsa_secp256r1_sha256, + TLS_SIG_ECDSA_SECP384R1_SHA384 = ssl_sig_ecdsa_secp384r1_sha384, + TLS_SIG_ECDSA_SECP512R1_SHA512 = ssl_sig_ecdsa_secp521r1_sha512, + TLS_SIG_RSA_PSS_RSAE_SHA256 = ssl_sig_rsa_pss_rsae_sha256, + TLS_SIG_RSA_PSS_RSAE_SHA384 = ssl_sig_rsa_pss_rsae_sha384, + TLS_SIG_RSA_PSS_RSAE_SHA512 = ssl_sig_rsa_pss_rsae_sha512, + TLS_SIG_ED25519 = ssl_sig_ed25519, + TLS_SIG_ED448 = ssl_sig_ed448, + TLS_SIG_RSA_PSS_PSS_SHA256 = ssl_sig_rsa_pss_pss_sha256, + TLS_SIG_RSA_PSS_PSS_SHA384 = ssl_sig_rsa_pss_pss_sha384, + TLS_SIG_RSA_PSS_PSS_SHA512 = ssl_sig_rsa_pss_pss_sha512, + } +} diff --git a/third_party/rust/neqo-crypto/src/err.rs b/third_party/rust/neqo-crypto/src/err.rs new file mode 100644 index 0000000000..5ae34468c1 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/err.rs @@ -0,0 +1,194 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] + +use std::os::raw::c_char; + +use crate::ssl::{SECStatus, SECSuccess}; + +include!(concat!(env!("OUT_DIR"), "/nspr_error.rs")); +mod codes { + #![allow(non_snake_case)] + include!(concat!(env!("OUT_DIR"), "/nss_secerr.rs")); + include!(concat!(env!("OUT_DIR"), "/nss_sslerr.rs")); + include!(concat!(env!("OUT_DIR"), "/mozpkix.rs")); +} +pub use codes::mozilla_pkix_ErrorCode as mozpkix; +pub use codes::SECErrorCodes as sec; +pub use codes::SSLErrorCodes as ssl; +pub mod nspr { + include!(concat!(env!("OUT_DIR"), "/nspr_err.rs")); +} + +pub type Res<T> = Result<T, Error>; + +#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)] +#[allow(clippy::pub_enum_variant_names)] +pub enum Error { + AeadInitFailure, + AeadError, + CertificateLoading, + CreateSslSocket, + HkdfError, + InternalError, + IntegerOverflow, + InvalidEpoch, + MixedHandshakeMethod, + NoDataAvailable, + NssError { + name: String, + code: PRErrorCode, + desc: String, + }, + OverrunError, + SelfEncryptFailure, + TimeTravelError, + UnsupportedCipher, + UnsupportedVersion, +} + +impl std::error::Error for Error { + #[must_use] + fn cause(&self) -> Option<&dyn std::error::Error> { + None + } + #[must_use] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Error: {:?}", self) + } +} + +impl From<std::num::TryFromIntError> for Error { + #[must_use] + fn from(_: std::num::TryFromIntError) -> Self { + Self::IntegerOverflow + } +} +impl From<std::ffi::NulError> for Error { + #[must_use] + fn from(_: std::ffi::NulError) -> Self { + Self::InternalError + } +} + +use std::ffi::CStr; + +fn wrap_str_fn<F>(f: F, dflt: &str) -> String +where + F: FnOnce() -> *const c_char, +{ + unsafe { + let p = f(); + if p.is_null() { + return dflt.to_string(); + } + CStr::from_ptr(p).to_string_lossy().into_owned() + } +} + +pub fn secstatus_to_res(rv: SECStatus) -> Res<()> { + if rv == SECSuccess { + return Ok(()); + } + + let code = unsafe { PR_GetError() }; + let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR"); + let desc = wrap_str_fn( + || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) }, + "...", + ); + Err(Error::NssError { name, code, desc }) +} + +pub fn is_blocked(result: &Res<()>) -> bool { + match result { + Err(Error::NssError { code, .. }) => *code == nspr::PR_WOULD_BLOCK_ERROR, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use crate::err::{self, is_blocked, secstatus_to_res, Error, PRErrorCode, PR_SetError}; + use crate::ssl::{SECFailure, SECSuccess}; + use test_fixture::fixture_init; + + fn set_error_code(code: PRErrorCode) { + // This code doesn't work without initializing NSS first. + fixture_init(); + unsafe { PR_SetError(code, 0) }; + } + + #[test] + fn error_code() { + fixture_init(); + assert_eq!(15 - 0x3000, err::ssl::SSL_ERROR_BAD_MAC_READ); + assert_eq!(166 - 0x2000, err::sec::SEC_ERROR_LIBPKIX_INTERNAL); + assert_eq!(-5998, err::nspr::PR_WOULD_BLOCK_ERROR); + } + + #[test] + fn is_ok() { + assert!(secstatus_to_res(SECSuccess).is_ok()); + } + + #[test] + fn is_err() { + set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "SSL_ERROR_BAD_MAC_READ"); + assert_eq!(code, -12273); + assert_eq!( + desc, + "SSL received a record with an incorrect Message Authentication Code." + ); + } + _ => unreachable!(), + } + } + + #[test] + fn is_err_zero_code() { + set_error_code(0); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, .. } => { + assert_eq!(name, "UNKNOWN_ERROR"); + assert_eq!(code, 0); + // Note that we don't test |desc| here because that comes from + // strerror(0), which is platform-dependent. + } + _ => unreachable!(), + } + } + + #[test] + fn blocked() { + set_error_code(err::nspr::PR_WOULD_BLOCK_ERROR); + let r = secstatus_to_res(SECFailure); + assert!(r.is_err()); + assert!(is_blocked(&r)); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "PR_WOULD_BLOCK_ERROR"); + assert_eq!(code, -5998); + assert_eq!(desc, "The operation would have blocked"); + } + _ => panic!("bad error type"), + } + } +} diff --git a/third_party/rust/neqo-crypto/src/exp.rs b/third_party/rust/neqo-crypto/src/exp.rs new file mode 100644 index 0000000000..59b6cb0ba5 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/exp.rs @@ -0,0 +1,23 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +macro_rules! experimental_api { + ( $n:ident ( $( $a:ident : $t:ty ),* $(,)? ) ) => { + #[allow(non_snake_case)] + #[allow(clippy::too_many_arguments)] + pub(crate) unsafe fn $n ( $( $a : $t ),* ) -> Result<(), crate::err::Error> { + const EXP_FUNCTION: &str = stringify!($n); + let n = ::std::ffi::CString::new(EXP_FUNCTION)?; + let f = crate::ssl::SSL_GetExperimentalAPI(n.as_ptr()); + if f.is_null() { + return Err(crate::err::Error::InternalError); + } + let f: unsafe extern "C" fn( $( $t ),* ) -> crate::ssl::SECStatus = ::std::mem::transmute(f); + let rv = f( $( $a ),* ); + crate::err::secstatus_to_res(rv) + } + }; +} diff --git a/third_party/rust/neqo-crypto/src/ext.rs b/third_party/rust/neqo-crypto/src/ext.rs new file mode 100644 index 0000000000..30401f7ec1 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/ext.rs @@ -0,0 +1,166 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::agentio::as_c_void; +use crate::constants::{ + Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS, +}; +use crate::err::Res; +use crate::ssl::{ + PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription, + SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType, +}; + +use std::cell::RefCell; +use std::convert::TryFrom; +use std::os::raw::{c_uint, c_void}; +use std::pin::Pin; +use std::rc::Rc; + +experimental_api!(SSL_InstallExtensionHooks( + fd: *mut PRFileDesc, + extension: u16, + writer: SSLExtensionWriter, + writer_arg: *mut c_void, + handler: SSLExtensionHandler, + handler_arg: *mut c_void, +)); + +pub enum ExtensionWriterResult { + Write(usize), + Skip, +} + +pub enum ExtensionHandlerResult { + Ok, + Alert(crate::constants::Alert), +} + +pub trait ExtensionHandler { + fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult { + match msg { + TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0), + _ => ExtensionWriterResult::Skip, + } + } + + fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult { + match msg { + TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok, + _ => ExtensionHandlerResult::Alert(110), // unsupported_extension + } + } +} + +type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>; + +pub struct ExtensionTracker { + extension: Extension, + handler: Pin<Box<BoxedExtensionHandler>>, +} + +impl ExtensionTracker { + // Technically the as_mut() call here is the only unsafe bit, + // but don't call this function lightly. + unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T + where + F: FnOnce(&mut dyn ExtensionHandler) -> T, + { + let handler_ptr = arg as *mut BoxedExtensionHandler; + let rc = handler_ptr.as_mut().unwrap(); + f(&mut *rc.borrow_mut()) + } + + #[allow(clippy::cast_possible_truncation)] + unsafe extern "C" fn extension_writer( + _fd: *mut PRFileDesc, + message: SSLHandshakeType::Type, + data: *mut u8, + len: *mut c_uint, + max_len: c_uint, + arg: *mut c_void, + ) -> PRBool { + let d = std::slice::from_raw_parts_mut(data, max_len as usize); + Self::wrap_handler_call(arg, |handler| { + // Cast is safe here because the message type is always part of the enum + match handler.write(message as HandshakeMessage, d) { + ExtensionWriterResult::Write(sz) => { + *len = c_uint::try_from(sz).expect("integer overflow from extension writer"); + 1 + } + ExtensionWriterResult::Skip => 0, + } + }) + } + + unsafe extern "C" fn extension_handler( + _fd: *mut PRFileDesc, + message: SSLHandshakeType::Type, + data: *const u8, + len: c_uint, + alert: *mut SSLAlertDescription, + arg: *mut c_void, + ) -> SECStatus { + let d = std::slice::from_raw_parts(data, len as usize); + #[allow(clippy::cast_possible_truncation)] + Self::wrap_handler_call(arg, |handler| { + // Cast is safe here because the message type is always part of the enum + match handler.handle(message as HandshakeMessage, d) { + ExtensionHandlerResult::Ok => SECSuccess, + ExtensionHandlerResult::Alert(a) => { + *alert = a; + SECFailure + } + } + }) + } + + /// Use the provided handler to manage an extension. This is quite unsafe. + /// + /// # Safety + /// The holder of this `ExtensionTracker` needs to ensure that it lives at + /// least as long as the file descriptor, as NSS provides no way to remove + /// an extension handler once it is configured. + /// + /// # Errors + /// If the underlying NSS API fails to register a handler. + pub unsafe fn new( + fd: *mut PRFileDesc, + extension: Extension, + handler: Rc<RefCell<dyn ExtensionHandler>>, + ) -> Res<Self> { + // The ergonomics here aren't great for users of this API, but it's + // horrific here. The pinned outer box gives us a stable pointer to the inner + // box. This is the pointer that is passed to NSS. + // + // The inner box points to the reference-counted object. This inner box is + // what we end up with a reference to in callbacks. That extra wrapper around + // the Rc avoid any touching of reference counts in callbacks, which would + // inevitably lead to leaks as we don't control how many times the callback + // is invoked. + // + // This way, only this "outer" code deals with the reference count. + let mut tracker = Self { + extension, + handler: Box::pin(Box::new(handler)), + }; + SSL_InstallExtensionHooks( + fd, + extension, + Some(Self::extension_writer), + as_c_void(&mut tracker.handler), + Some(Self::extension_handler), + as_c_void(&mut tracker.handler), + )?; + Ok(tracker) + } +} + +impl std::fmt::Debug for ExtensionTracker { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "ExtensionTracker: {:?}", self.extension) + } +} diff --git a/third_party/rust/neqo-crypto/src/hkdf.rs b/third_party/rust/neqo-crypto/src/hkdf.rs new file mode 100644 index 0000000000..13e1dd5758 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/hkdf.rs @@ -0,0 +1,152 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::constants::{ + Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, + TLS_VERSION_1_3, +}; +use crate::err::{Error, Res}; +use crate::p11::{ + random, PK11Origin, PK11SymKey, PK11_GetInternalSlot, PK11_ImportSymKey, SECItem, SECItemType, + Slot, SymKey, CKA_DERIVE, CKM_NSS_HKDF_SHA256, CKM_NSS_HKDF_SHA384, CK_ATTRIBUTE_TYPE, + CK_MECHANISM_TYPE, +}; + +use std::convert::TryFrom; +use std::os::raw::{c_char, c_uchar, c_uint}; +use std::ptr::{null_mut, NonNull}; + +experimental_api!(SSL_HkdfExtract( + version: Version, + cipher: Cipher, + salt: *mut PK11SymKey, + ikm: *mut PK11SymKey, + prk: *mut *mut PK11SymKey, +)); +experimental_api!(SSL_HkdfExpandLabel( + version: Version, + cipher: Cipher, + prk: *mut PK11SymKey, + handshake_hash: *const u8, + handshake_hash_len: c_uint, + label: *const c_char, + label_len: c_uint, + secret: *mut *mut PK11SymKey, +)); + +fn key_size(version: Version, cipher: Cipher) -> Res<usize> { + if version != TLS_VERSION_1_3 { + return Err(Error::UnsupportedVersion); + } + Ok(match cipher { + TLS_AES_128_GCM_SHA256 | TLS_CHACHA20_POLY1305_SHA256 => 32, + TLS_AES_256_GCM_SHA384 => 48, + _ => return Err(Error::UnsupportedCipher), + }) +} + +/// Generate a random key of the right size for the given suite. +/// +/// # Errors +/// Only if NSS fails. +pub fn generate_key(version: Version, cipher: Cipher) -> Res<SymKey> { + import_key(version, cipher, &random(key_size(version, cipher)?)) +} + +/// Import a symmetric key for use with HKDF. +/// +/// # Errors +/// Errors returned if the key buffer is an incompatible size or the NSS functions fail. +pub fn import_key(version: Version, cipher: Cipher, buf: &[u8]) -> Res<SymKey> { + if version != TLS_VERSION_1_3 { + return Err(Error::UnsupportedVersion); + } + let mech = match cipher { + TLS_AES_128_GCM_SHA256 | TLS_CHACHA20_POLY1305_SHA256 => CKM_NSS_HKDF_SHA256, + TLS_AES_256_GCM_SHA384 => CKM_NSS_HKDF_SHA384, + _ => return Err(Error::UnsupportedCipher), + }; + let mut item = SECItem { + type_: SECItemType::siBuffer, + data: buf.as_ptr() as *mut c_uchar, + len: c_uint::try_from(buf.len())?, + }; + let slot_ptr = unsafe { PK11_GetInternalSlot() }; + let slot = match NonNull::new(slot_ptr) { + Some(p) => Slot::new(p), + None => return Err(Error::InternalError), + }; + let key_ptr = unsafe { + PK11_ImportSymKey( + *slot, + CK_MECHANISM_TYPE::from(mech), + PK11Origin::PK11_OriginUnwrap, + CK_ATTRIBUTE_TYPE::from(CKA_DERIVE), + &mut item, + null_mut(), + ) + }; + match NonNull::new(key_ptr) { + Some(p) => Ok(SymKey::new(p)), + None => Err(Error::InternalError), + } +} + +/// Extract a PRK from the given salt and IKM using the algorithm defined in RFC 5869. +/// +/// # Errors +/// Errors returned if inputs are too large or the NSS functions fail. +pub fn extract( + version: Version, + cipher: Cipher, + salt: Option<&SymKey>, + ikm: &SymKey, +) -> Res<SymKey> { + let mut prk: *mut PK11SymKey = null_mut(); + let salt_ptr: *mut PK11SymKey = match salt { + Some(s) => **s, + None => null_mut(), + }; + unsafe { SSL_HkdfExtract(version, cipher, salt_ptr, **ikm, &mut prk) }?; + match NonNull::new(prk) { + Some(p) => Ok(SymKey::new(p)), + None => Err(Error::InternalError), + } +} + +/// Expand a PRK using the HKDF-Expand-Label function defined in RFC 8446. +/// +/// # Errors +/// Errors returned if inputs are too large or the NSS functions fail. +pub fn expand_label( + version: Version, + cipher: Cipher, + prk: &SymKey, + handshake_hash: &[u8], + label: &str, +) -> Res<SymKey> { + let l = label.as_bytes(); + let mut secret: *mut PK11SymKey = null_mut(); + + // Note that this doesn't allow for passing null() for the handshake hash. + // A zero-length slice produces an identical result. + unsafe { + SSL_HkdfExpandLabel( + version, + cipher, + **prk, + handshake_hash.as_ptr(), + c_uint::try_from(handshake_hash.len())?, + l.as_ptr() as *const c_char, + c_uint::try_from(l.len())?, + &mut secret, + ) + }?; + match NonNull::new(secret) { + Some(p) => Ok(SymKey::new(p)), + None => Err(Error::HkdfError), + } +} diff --git a/third_party/rust/neqo-crypto/src/hp.rs b/third_party/rust/neqo-crypto/src/hp.rs new file mode 100644 index 0000000000..e64fb0dcfe --- /dev/null +++ b/third_party/rust/neqo-crypto/src/hp.rs @@ -0,0 +1,133 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::constants::{ + Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, +}; +use crate::err::{secstatus_to_res, Error, Res}; +use crate::p11::{ + PK11SymKey, PK11_Encrypt, PK11_GetBlockSize, PK11_GetMechanism, SECItem, SECItemType, SymKey, + CKM_AES_ECB, CKM_NSS_CHACHA20_CTR, CK_MECHANISM_TYPE, +}; + +use std::convert::TryFrom; +use std::fmt::{self, Debug}; +use std::os::raw::{c_char, c_uint}; +use std::ptr::{null, null_mut, NonNull}; + +experimental_api!(SSL_HkdfExpandLabelWithMech( + version: Version, + cipher: Cipher, + prk: *mut PK11SymKey, + handshake_hash: *const u8, + handshake_hash_len: c_uint, + label: *const c_char, + label_len: c_uint, + mech: CK_MECHANISM_TYPE, + key_size: c_uint, + secret: *mut *mut PK11SymKey, +)); + +#[derive(Clone)] +pub struct HpKey(SymKey); + +impl Debug for HpKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "HP-{:?}", self.0) + } +} + +impl HpKey { + /// QUIC-specific API for extracting a header-protection key. + /// + /// # Errors + /// Errors if HKDF fails or if the label is too long to fit in a `c_uint`. + pub fn extract(version: Version, cipher: Cipher, prk: &SymKey, label: &str) -> Res<Self> { + let l = label.as_bytes(); + let mut secret: *mut PK11SymKey = null_mut(); + + let (mech, key_size) = match cipher { + TLS_AES_128_GCM_SHA256 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 16), + TLS_AES_256_GCM_SHA384 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 32), + TLS_CHACHA20_POLY1305_SHA256 => (CK_MECHANISM_TYPE::from(CKM_NSS_CHACHA20_CTR), 32), + _ => unreachable!(), + }; + + // Note that this doesn't allow for passing null() for the handshake hash. + // A zero-length slice produces an identical result. + unsafe { + SSL_HkdfExpandLabelWithMech( + version, + cipher, + **prk, + null(), + 0, + l.as_ptr() as *const c_char, + c_uint::try_from(l.len())?, + mech, + key_size, + &mut secret, + ) + }?; + match NonNull::new(secret) { + None => Err(Error::HkdfError), + Some(p) => Ok(Self(SymKey::new(p))), + } + } + + /// Get the sample size, which is also the output size. + #[allow(clippy::cast_sign_loss)] + #[must_use] + pub fn sample_size(&self) -> usize { + let k: *mut PK11SymKey = *self.0; + let mech = unsafe { PK11_GetMechanism(k) }; + // Cast is safe because block size is always greater than or equal to 0 + (unsafe { PK11_GetBlockSize(mech, null_mut()) }) as usize + } + + /// Generate a header protection mask for QUIC. + /// + /// # Errors + /// An error is returned if the NSS functions fail; a sample of the + /// wrong size is the obvious cause. + pub fn mask(&self, sample: &[u8]) -> Res<Vec<u8>> { + let k: *mut PK11SymKey = *self.0; + let mech = unsafe { PK11_GetMechanism(k) }; + let block_size = self.sample_size(); + + let mut output = vec![0_u8; block_size]; + let output_slice = &mut output[..]; + let mut output_len: c_uint = 0; + + let mut item = SECItem { + type_: SECItemType::siBuffer, + data: sample.as_ptr() as *mut u8, + len: c_uint::try_from(sample.len())?, + }; + let zero = vec![0_u8; block_size]; + let (iv, inbuf) = match () { + _ if mech == CK_MECHANISM_TYPE::from(CKM_AES_ECB) => (null_mut(), sample), + _ if mech == CK_MECHANISM_TYPE::from(CKM_NSS_CHACHA20_CTR) => { + (&mut item as *mut SECItem, &zero[..]) + } + _ => unreachable!(), + }; + secstatus_to_res(unsafe { + PK11_Encrypt( + k, + mech, + iv, + output_slice.as_mut_ptr(), + &mut output_len, + c_uint::try_from(output.len())?, + inbuf.as_ptr() as *const u8, + c_uint::try_from(inbuf.len())?, + ) + })?; + assert_eq!(output_len as usize, block_size); + Ok(output) + } +} diff --git a/third_party/rust/neqo-crypto/src/lib.rs b/third_party/rust/neqo-crypto/src/lib.rs new file mode 100644 index 0000000000..92f297f13f --- /dev/null +++ b/third_party/rust/neqo-crypto/src/lib.rs @@ -0,0 +1,167 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![cfg_attr(feature = "deny-warnings", deny(warnings))] +#![warn(clippy::pedantic)] +// Bindgen auto generated code +// won't adhere to the clippy rules below +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::unseparated_literal_suffix)] +#![allow(clippy::used_underscore_binding)] + +#[macro_use] +mod exp; +#[macro_use] +mod p11; + +pub mod aead; +pub mod agent; +mod agentio; +mod auth; +mod cert; +pub mod constants; +mod err; +pub mod ext; +pub mod hkdf; +pub mod hp; +mod once; +mod prio; +mod replay; +mod secrets; +pub mod selfencrypt; +mod ssl; +mod time; + +pub use self::agent::{ + Agent, AllowZeroRtt, Client, HandshakeState, Record, RecordList, ResumptionToken, SecretAgent, + SecretAgentInfo, SecretAgentPreInfo, Server, ZeroRttCheckResult, ZeroRttChecker, +}; +pub use self::auth::AuthenticationStatus; +pub use self::constants::*; +pub use self::err::{Error, PRErrorCode, Res}; +pub use self::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult}; +pub use self::p11::{random, SymKey}; +pub use self::replay::AntiReplay; +pub use self::secrets::SecretDirection; +pub use self::ssl::Opt; + +use self::once::OnceResult; + +use std::ffi::CString; +use std::os::raw::c_char; +use std::path::{Path, PathBuf}; +use std::ptr::null; + +mod nss { + #![allow(clippy::redundant_static_lifetimes, non_upper_case_globals)] + include!(concat!(env!("OUT_DIR"), "/nss_init.rs")); +} + +// Need to map the types through. +fn secstatus_to_res(code: nss::SECStatus) -> Res<()> { + crate::err::secstatus_to_res(code as crate::ssl::SECStatus) +} + +enum NssLoaded { + External, + NoDb, + Db(Box<Path>), +} + +impl Drop for NssLoaded { + fn drop(&mut self) { + match self { + Self::NoDb | Self::Db(_) => unsafe { + secstatus_to_res(nss::NSS_Shutdown()).expect("NSS Shutdown failed") + }, + _ => {} + } + } +} + +static mut INITIALIZED: OnceResult<NssLoaded> = OnceResult::new(); + +fn already_initialized() -> bool { + unsafe { nss::NSS_IsInitialized() != 0 } +} + +/// Initialize NSS. This only executes the initialization routines once, so if there is any chance that +pub fn init() { + // Set time zero. + time::init(); + unsafe { + INITIALIZED.call_once(|| { + if already_initialized() { + return NssLoaded::External; + } + + secstatus_to_res(nss::NSS_NoDB_Init(null())).expect("NSS_NoDB_Init failed"); + secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed"); + + NssLoaded::NoDb + }); + } +} + +/// This enables SSLTRACE by calling a simple, harmless function to trigger its +/// side effects. SSLTRACE is not enabled in NSS until a socket is made or +/// global options are accessed. Reading an option is the least impact approach. +/// This allows us to use SSLTRACE in all of our unit tests and programs. +#[cfg(debug_assertions)] +fn enable_ssl_trace() { + let opt = ssl::Opt::Locking.as_int(); + let mut _v: ::std::os::raw::c_int = 0; + secstatus_to_res(unsafe { ssl::SSL_OptionGetDefault(opt, &mut _v) }) + .expect("SSL_OptionGetDefault failed"); +} + +pub fn init_db<P: Into<PathBuf>>(dir: P) { + time::init(); + unsafe { + INITIALIZED.call_once(|| { + if already_initialized() { + return NssLoaded::External; + } + + let path = dir.into(); + assert!(path.is_dir()); + let pathstr = path.to_str().expect("path converts to string").to_string(); + let dircstr = CString::new(pathstr).expect("new CString"); + let empty = CString::new("").expect("new empty CString"); + secstatus_to_res(nss::NSS_Initialize( + dircstr.as_ptr(), + empty.as_ptr(), + empty.as_ptr(), + nss::SECMOD_DB.as_ptr() as *const c_char, + nss::NSS_INIT_READONLY, + )) + .expect("NSS_Initialize failed"); + + secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed"); + secstatus_to_res(ssl::SSL_ConfigServerSessionIDCache( + 1024, + 0, + 0, + dircstr.as_ptr(), + )) + .expect("SSL_ConfigServerSessionIDCache failed"); + + #[cfg(debug_assertions)] + enable_ssl_trace(); + + NssLoaded::Db(path.into_boxed_path()) + }); + } +} + +/// Panic if NSS isn't initialized. +pub fn assert_initialized() { + unsafe { + INITIALIZED.call_once(|| { + panic!("NSS not initialized with init or init_db"); + }); + } +} diff --git a/third_party/rust/neqo-crypto/src/once.rs b/third_party/rust/neqo-crypto/src/once.rs new file mode 100644 index 0000000000..80657cfe26 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/once.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::sync::Once; + +#[allow(clippy::module_name_repetitions)] +pub struct OnceResult<T> { + once: Once, + v: Option<T>, +} + +impl<T> OnceResult<T> { + #[must_use] + pub const fn new() -> Self { + Self { + once: Once::new(), + v: None, + } + } + + pub fn call_once<F: FnOnce() -> T>(&mut self, f: F) -> &T { + let v = &mut self.v; + self.once.call_once(|| { + *v = Some(f()); + }); + self.v.as_ref().unwrap() + } +} + +#[cfg(test)] +mod test { + use super::OnceResult; + + static mut STATIC_ONCE_RESULT: OnceResult<u64> = OnceResult::new(); + + #[test] + fn static_update() { + assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 23) }, 23); + assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 24) }, 23); + } +} diff --git a/third_party/rust/neqo-crypto/src/p11.rs b/third_party/rust/neqo-crypto/src/p11.rs new file mode 100644 index 0000000000..add5e2ccfe --- /dev/null +++ b/third_party/rust/neqo-crypto/src/p11.rs @@ -0,0 +1,126 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +use crate::err::{secstatus_to_res, Error, Res}; + +use neqo_common::hex_with_len; + +use std::convert::TryInto; +use std::ops::{Deref, DerefMut}; +use std::ptr::NonNull; + +#[allow(clippy::unreadable_literal)] +mod nss_p11 { + include!(concat!(env!("OUT_DIR"), "/nss_p11.rs")); +} + +pub use nss_p11::*; + +macro_rules! scoped_ptr { + ($scoped:ident, $target:ty, $dtor:path) => { + pub struct $scoped { + ptr: *mut $target, + } + + impl $scoped { + #[must_use] + pub fn new(ptr: NonNull<$target>) -> Self { + Self { ptr: ptr.as_ptr() } + } + } + + impl Deref for $scoped { + type Target = *mut $target; + #[must_use] + fn deref(&self) -> &*mut $target { + &self.ptr + } + } + + impl DerefMut for $scoped { + fn deref_mut(&mut self) -> &mut *mut $target { + &mut self.ptr + } + } + + impl Drop for $scoped { + fn drop(&mut self) { + let _ = unsafe { $dtor(self.ptr) }; + } + } + }; +} + +scoped_ptr!(Certificate, CERTCertificate, CERT_DestroyCertificate); +scoped_ptr!(CertList, CERTCertList, CERT_DestroyCertList); +scoped_ptr!(PrivateKey, SECKEYPrivateKey, SECKEY_DestroyPrivateKey); +scoped_ptr!(SymKey, PK11SymKey, PK11_FreeSymKey); +scoped_ptr!(Slot, PK11SlotInfo, PK11_FreeSlot); + +impl SymKey { + /// You really don't want to use this. + /// + /// # Errors + /// Internal errors in case of failures in NSS. + pub fn as_bytes<'a>(&'a self) -> Res<&'a [u8]> { + secstatus_to_res(unsafe { PK11_ExtractKeyValue(self.ptr) })?; + + let key_item = unsafe { PK11_GetKeyData(self.ptr) }; + // This is accessing a value attached to the key, so we can treat this as a borrow. + match unsafe { key_item.as_mut() } { + None => Err(Error::InternalError), + Some(key) => Ok(unsafe { std::slice::from_raw_parts(key.data, key.len as usize) }), + } + } +} + +impl Clone for SymKey { + #[must_use] + fn clone(&self) -> Self { + let ptr = unsafe { PK11_ReferenceSymKey(self.ptr) }; + assert!(!ptr.is_null()); + Self { ptr } + } +} + +impl std::fmt::Debug for SymKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Ok(b) = self.as_bytes() { + write!(f, "SymKey {}", hex_with_len(b)) + } else { + write!(f, "Opaque SymKey") + } + } +} + +/// Generate a randomized buffer. +#[must_use] +pub fn random(size: usize) -> Vec<u8> { + let mut buf = vec![0; size]; + secstatus_to_res(unsafe { + PK11_GenerateRandom(buf.as_mut_ptr(), buf.len().try_into().unwrap()) + }) + .unwrap(); + buf +} + +#[cfg(test)] +mod test { + use super::random; + use test_fixture::fixture_init; + + #[test] + fn randomness() { + fixture_init(); + // If this ever fails, there is either a bug, or it's time to buy a lottery ticket. + assert_ne!(random(16), random(16)); + } +} diff --git a/third_party/rust/neqo-crypto/src/prio.rs b/third_party/rust/neqo-crypto/src/prio.rs new file mode 100644 index 0000000000..c561277b0c --- /dev/null +++ b/third_party/rust/neqo-crypto/src/prio.rs @@ -0,0 +1,20 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code, non_upper_case_globals, non_snake_case)] +#![allow( + clippy::cognitive_complexity, + clippy::empty_enum, + clippy::too_many_lines +)] + +include!(concat!(env!("OUT_DIR"), "/nspr_io.rs")); + +pub enum PRFileInfo {} +pub enum PRFileInfo64 {} +pub enum PRFilePrivate {} +pub enum PRIOVec {} +pub enum PRSendFileData {} diff --git a/third_party/rust/neqo-crypto/src/replay.rs b/third_party/rust/neqo-crypto/src/replay.rs new file mode 100644 index 0000000000..7e57e3b66f --- /dev/null +++ b/third_party/rust/neqo-crypto/src/replay.rs @@ -0,0 +1,80 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::err::{Error, Res}; +use crate::ssl::PRFileDesc; +use crate::time::{Interval, PRTime, Time}; + +use std::convert::{TryFrom, TryInto}; +use std::ops::{Deref, DerefMut}; +use std::os::raw::c_uint; +use std::ptr::{null_mut, NonNull}; +use std::time::{Duration, Instant}; + +// This is an opaque struct in NSS. +#[allow(clippy::empty_enum)] +pub enum SSLAntiReplayContext {} + +experimental_api!(SSL_CreateAntiReplayContext( + now: PRTime, + window: PRTime, + k: c_uint, + bits: c_uint, + ctx: *mut *mut SSLAntiReplayContext, +)); +experimental_api!(SSL_ReleaseAntiReplayContext(ctx: *mut SSLAntiReplayContext)); +experimental_api!(SSL_SetAntiReplayContext( + fd: *mut PRFileDesc, + ctx: *mut SSLAntiReplayContext, +)); + +scoped_ptr!( + AntiReplayContext, + SSLAntiReplayContext, + SSL_ReleaseAntiReplayContext +); + +/// `AntiReplay` is used by servers when processing 0-RTT handshakes. +/// It limits the exposure of servers to replay attack by rejecting 0-RTT +/// if it appears to be a replay. There is a false-positive rate that can be +/// managed by tuning the parameters used to create the context. +#[allow(clippy::module_name_repetitions)] +pub struct AntiReplay { + ctx: AntiReplayContext, +} + +impl AntiReplay { + /// Make a new anti-replay context. + /// See the documentation in NSS for advice on how to set these values. + /// + /// # Errors + /// Returns an error if `now` is in the past relative to our baseline or + /// NSS is unable to generate an anti-replay context. + pub fn new(now: Instant, window: Duration, k: usize, bits: usize) -> Res<Self> { + let mut ctx: *mut SSLAntiReplayContext = null_mut(); + unsafe { + SSL_CreateAntiReplayContext( + Time::from(now).try_into()?, + Interval::from(window).try_into()?, + c_uint::try_from(k)?, + c_uint::try_from(bits)?, + &mut ctx, + ) + }?; + + match NonNull::new(ctx) { + Some(ctx_nn) => Ok(Self { + ctx: AntiReplayContext::new(ctx_nn), + }), + None => Err(Error::InternalError), + } + } + + /// Configure the provided socket with this anti-replay context. + pub(crate) fn config_socket(&self, fd: *mut PRFileDesc) -> Res<()> { + unsafe { SSL_SetAntiReplayContext(fd, *self.ctx) } + } +} diff --git a/third_party/rust/neqo-crypto/src/result.rs b/third_party/rust/neqo-crypto/src/result.rs new file mode 100644 index 0000000000..e6148dd054 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/result.rs @@ -0,0 +1,133 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::err::{ + nspr, Error, PR_ErrorToName, PR_ErrorToString, PR_GetError, Res, PR_LANGUAGE_I_DEFAULT, +}; +use crate::ssl; + +use std::ffi::CStr; + +pub fn result(rv: ssl::SECStatus) -> Res<()> { + let _ = result_helper(rv, false)?; + Ok(()) +} + +pub fn result_or_blocked(rv: ssl::SECStatus) -> Res<bool> { + result_helper(rv, true) +} + +fn wrap_str_fn<F>(f: F, dflt: &str) -> String +where + F: FnOnce() -> *const i8, +{ + unsafe { + let p = f(); + if p.is_null() { + return dflt.to_string(); + } + CStr::from_ptr(p).to_string_lossy().into_owned() + } +} + +fn result_helper(rv: ssl::SECStatus, allow_blocked: bool) -> Res<bool> { + if rv == ssl::_SECStatus_SECSuccess { + return Ok(false); + } + + let code = unsafe { PR_GetError() }; + if allow_blocked && code == nspr::PR_WOULD_BLOCK_ERROR { + return Ok(true); + } + + let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR"); + let desc = wrap_str_fn( + || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) }, + "...", + ); + Err(Error::NssError { name, code, desc }) +} + +#[cfg(test)] +mod tests { + use super::{result, result_or_blocked}; + use crate::err::{self, nspr, Error, PRErrorCode, PR_SetError}; + use crate::ssl; + use test_fixture::fixture_init; + + fn set_error_code(code: PRErrorCode) { + unsafe { PR_SetError(code, 0) }; + } + + #[test] + fn is_ok() { + assert!(result(ssl::SECSuccess).is_ok()); + } + + #[test] + fn is_err() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "SSL_ERROR_BAD_MAC_READ"); + assert_eq!(code, -12273); + assert_eq!( + desc, + "SSL received a record with an incorrect Message Authentication Code." + ); + } + _ => unreachable!(), + } + } + + #[test] + fn is_err_zero_code() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(0); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, .. } => { + assert_eq!(name, "UNKNOWN_ERROR"); + assert_eq!(code, 0); + // Note that we don't test |desc| here because that comes from + // strerror(0), which is platform-dependent. + } + _ => unreachable!(), + } + } + + #[test] + fn blocked_as_error() { + // This code doesn't work without initializing NSS first. + fixture_init(); + + set_error_code(nspr::PR_WOULD_BLOCK_ERROR); + let r = result(ssl::SECFailure); + assert!(r.is_err()); + match r.unwrap_err() { + Error::NssError { name, code, desc } => { + assert_eq!(name, "PR_WOULD_BLOCK_ERROR"); + assert_eq!(code, -5998); + assert_eq!(desc, "The operation would have blocked"); + } + _ => panic!("bad error type"), + } + } + + #[test] + fn is_blocked() { + set_error_code(nspr::PR_WOULD_BLOCK_ERROR); + assert!(result_or_blocked(ssl::SECFailure).unwrap()); + } +} diff --git a/third_party/rust/neqo-crypto/src/secrets.rs b/third_party/rust/neqo-crypto/src/secrets.rs new file mode 100644 index 0000000000..7353fbfa47 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/secrets.rs @@ -0,0 +1,132 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::agentio::as_c_void; +use crate::constants::Epoch; +use crate::err::Res; +use crate::p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey}; +use crate::ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection}; + +use neqo_common::qdebug; +use std::os::raw::c_void; +use std::pin::Pin; +use std::ptr::NonNull; + +experimental_api!(SSL_SecretCallback( + fd: *mut PRFileDesc, + cb: SSLSecretCallback, + arg: *mut c_void, +)); + +#[derive(Clone, Copy, Debug)] +pub enum SecretDirection { + Read, + Write, +} + +impl From<SSLSecretDirection::Type> for SecretDirection { + #[must_use] + fn from(dir: SSLSecretDirection::Type) -> Self { + match dir { + SSLSecretDirection::ssl_secret_read => Self::Read, + SSLSecretDirection::ssl_secret_write => Self::Write, + _ => unreachable!(), + } + } +} + +#[derive(Debug, Default)] +#[allow(clippy::module_name_repetitions)] +pub struct DirectionalSecrets { + // We only need to maintain 3 secrets for the epochs used during the handshake. + secrets: [Option<SymKey>; 3], +} + +impl DirectionalSecrets { + fn put(&mut self, epoch: Epoch, key: SymKey) { + assert!(epoch > 0); + let i = (epoch - 1) as usize; + assert!(i < self.secrets.len()); + // assert!(self.secrets[i].is_none()); + self.secrets[i] = Some(key); + } + + pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> { + assert!(epoch > 0); + let i = (epoch - 1) as usize; + assert!(i < self.secrets.len()); + self.secrets[i].take() + } +} + +#[derive(Debug, Default)] +pub struct Secrets { + r: DirectionalSecrets, + w: DirectionalSecrets, +} + +impl Secrets { + #[allow(clippy::unused_self)] + unsafe extern "C" fn secret_available( + _fd: *mut PRFileDesc, + epoch: u16, + dir: SSLSecretDirection::Type, + secret: *mut PK11SymKey, + arg: *mut c_void, + ) { + let secrets_ptr = arg as *mut Self; + let secrets = secrets_ptr.as_mut().unwrap(); + secrets.put_raw(epoch, dir, secret); + } + + fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) { + let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) }; + let key = match NonNull::new(key_ptr) { + None => panic!("NSS shouldn't be passing out NULL secrets"), + Some(p) => SymKey::new(p), + }; + self.put(SecretDirection::from(dir), epoch, key); + } + + fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { + qdebug!("{:?} secret available for {:?}", dir, epoch); + let keys = match dir { + SecretDirection::Read => &mut self.r, + SecretDirection::Write => &mut self.w, + }; + keys.put(epoch, key); + } +} + +#[derive(Debug)] +pub struct SecretHolder { + secrets: Pin<Box<Secrets>>, +} + +impl SecretHolder { + /// This registers with NSS. The lifetime of this object needs to match the lifetime + /// of the connection, or bad things might happen. + pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> { + let p = as_c_void(&mut self.secrets); + unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) } + } + + pub fn take_read(&mut self, epoch: Epoch) -> Option<SymKey> { + self.secrets.r.take(epoch) + } + + pub fn take_write(&mut self, epoch: Epoch) -> Option<SymKey> { + self.secrets.w.take(epoch) + } +} + +impl Default for SecretHolder { + fn default() -> Self { + Self { + secrets: Box::pin(Secrets::default()), + } + } +} diff --git a/third_party/rust/neqo-crypto/src/selfencrypt.rs b/third_party/rust/neqo-crypto/src/selfencrypt.rs new file mode 100644 index 0000000000..ef76e44ee3 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/selfencrypt.rs @@ -0,0 +1,155 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::aead::Aead; +use crate::constants::{Cipher, Version}; +use crate::err::{Error, Res}; +use crate::hkdf; +use crate::p11::{random, SymKey}; + +use neqo_common::{hex, qinfo, qtrace, Encoder}; + +use std::mem; + +#[derive(Debug)] +pub struct SelfEncrypt { + version: Version, + cipher: Cipher, + key_id: u8, + key: SymKey, + old_key: Option<SymKey>, +} + +impl SelfEncrypt { + const VERSION: u8 = 1; + const SALT_LENGTH: usize = 16; + + /// # Errors + /// Failure to generate a new HKDF key using NSS results in an error. + pub fn new(version: Version, cipher: Cipher) -> Res<Self> { + let key = hkdf::generate_key(version, cipher)?; + Ok(Self { + version, + cipher, + key_id: 0, + key, + old_key: None, + }) + } + + fn make_aead(&self, k: &SymKey, salt: &[u8]) -> Res<Aead> { + debug_assert_eq!(salt.len(), Self::SALT_LENGTH); + let salt = hkdf::import_key(self.version, self.cipher, salt)?; + let secret = hkdf::extract(self.version, self.cipher, Some(&salt), k)?; + Aead::new(self.version, self.cipher, &secret, "neqo self") + } + + /// Rotate keys. This causes any previous key that is being held to be replaced by the current key. + /// + /// # Errors + /// Failure to generate a new HKDF key using NSS results in an error. + pub fn rotate(&mut self) -> Res<()> { + let new_key = hkdf::generate_key(self.version, self.cipher)?; + self.old_key = Some(mem::replace(&mut self.key, new_key)); + let (kid, _) = self.key_id.overflowing_add(1); + self.key_id = kid; + qinfo!(["SelfEncrypt"], "Rotated keys to {}", self.key_id); + Ok(()) + } + + /// Seal an item using the underlying key. This produces a single buffer that contains + /// the encrypted `plaintext`, plus a version number and salt. + /// `aad` is only used as input to the AEAD, it is not included in the output; the + /// caller is responsible for carrying the AAD as appropriate. + /// + /// # Errors + /// Failure to protect using NSS AEAD APIs produces an error. + pub fn seal(&self, aad: &[u8], plaintext: &[u8]) -> Res<Vec<u8>> { + // Format is: + // struct { + // uint8 version; + // uint8 key_id; + // uint8 salt[16]; + // opaque aead_encrypted(plaintext)[length as expanded]; + // }; + // AAD covers the entire header, plus the value of the AAD parameter that is provided. + let salt = random(Self::SALT_LENGTH); + let cipher = self.make_aead(&self.key, &salt)?; + let encoded_len = 2 + salt.len() + plaintext.len() + cipher.expansion(); + + let mut enc = Encoder::with_capacity(encoded_len); + enc.encode_byte(Self::VERSION); + enc.encode_byte(self.key_id); + enc.encode(&salt); + + let mut extended_aad = enc.clone(); + extended_aad.encode(aad); + + let offset = enc.len(); + let mut output: Vec<u8> = enc.into(); + output.resize(encoded_len, 0); + cipher.encrypt(0, &extended_aad, plaintext, &mut output[offset..])?; + qtrace!( + ["SelfEncrypt"], + "seal {} {} -> {}", + hex(aad), + hex(plaintext), + hex(&output) + ); + Ok(output) + } + + fn select_key(&self, kid: u8) -> Option<&SymKey> { + if kid == self.key_id { + Some(&self.key) + } else { + let (prev_key_id, _) = self.key_id.overflowing_sub(1); + if kid == prev_key_id { + self.old_key.as_ref() + } else { + None + } + } + } + + /// Open the protected `ciphertext`. + /// + /// # Errors + /// Returns an error when the self-encrypted object is invalid; + /// when the keys have been rotated; or when NSS fails. + #[allow(clippy::similar_names)] // aad is similar to aead + pub fn open(&self, aad: &[u8], ciphertext: &[u8]) -> Res<Vec<u8>> { + if ciphertext[0] != Self::VERSION { + return Err(Error::SelfEncryptFailure); + } + let key = if let Some(k) = self.select_key(ciphertext[1]) { + k + } else { + return Err(Error::SelfEncryptFailure); + }; + let offset = 2 + Self::SALT_LENGTH; + + let mut extended_aad = Encoder::with_capacity(offset + aad.len()); + extended_aad.encode(&ciphertext[0..offset]); + extended_aad.encode(aad); + + let aead = self.make_aead(key, &ciphertext[2..offset])?; + // NSS insists on having extra space available for decryption. + let padded_len = ciphertext.len() - offset; + let mut output = vec![0; padded_len]; + let decrypted = aead.decrypt(0, &extended_aad, &ciphertext[offset..], &mut output)?; + let final_len = decrypted.len(); + output.truncate(final_len); + qtrace!( + ["SelfEncrypt"], + "open {} {} -> {}", + hex(aad), + hex(ciphertext), + hex(&output) + ); + Ok(output) + } +} diff --git a/third_party/rust/neqo-crypto/src/ssl.rs b/third_party/rust/neqo-crypto/src/ssl.rs new file mode 100644 index 0000000000..7389d57812 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/ssl.rs @@ -0,0 +1,141 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![allow(dead_code, non_upper_case_globals, non_snake_case)] +#![allow(clippy::cognitive_complexity, clippy::too_many_lines)] + +use crate::constants::Epoch; +use crate::err::{secstatus_to_res, Res}; + +use std::os::raw::{c_uint, c_void}; + +include!(concat!(env!("OUT_DIR"), "/nss_ssl.rs")); +mod SSLOption { + include!(concat!(env!("OUT_DIR"), "/nss_sslopt.rs")); +} + +// I clearly don't understand how bindgen operates. +#[allow(clippy::empty_enum)] +pub enum PLArenaPool {} +#[allow(clippy::empty_enum)] +pub enum PRFileDesc {} + +// Remap some constants. +pub const SECSuccess: SECStatus = _SECStatus_SECSuccess; +pub const SECFailure: SECStatus = _SECStatus_SECFailure; + +#[derive(Debug, Copy, Clone)] +pub enum Opt { + Locking, + Tickets, + OcspStapling, + Alpn, + ExtendedMasterSecret, + SignedCertificateTimestamps, + EarlyData, + RecordSizeLimit, + Tls13CompatMode, + HelloDowngradeCheck, + SuppressEndOfEarlyData, +} + +impl Opt { + // Cast is safe here because SSLOptions are within the i32 range + #[allow(clippy::cast_possible_wrap)] + pub(crate) fn as_int(self) -> PRInt32 { + let i = match self { + Self::Locking => SSLOption::SSL_NO_LOCKS, + Self::Tickets => SSLOption::SSL_ENABLE_SESSION_TICKETS, + Self::OcspStapling => SSLOption::SSL_ENABLE_OCSP_STAPLING, + Self::Alpn => SSLOption::SSL_ENABLE_ALPN, + Self::ExtendedMasterSecret => SSLOption::SSL_ENABLE_EXTENDED_MASTER_SECRET, + Self::SignedCertificateTimestamps => SSLOption::SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + Self::EarlyData => SSLOption::SSL_ENABLE_0RTT_DATA, + Self::RecordSizeLimit => SSLOption::SSL_RECORD_SIZE_LIMIT, + Self::Tls13CompatMode => SSLOption::SSL_ENABLE_TLS13_COMPAT_MODE, + Self::HelloDowngradeCheck => SSLOption::SSL_ENABLE_HELLO_DOWNGRADE_CHECK, + Self::SuppressEndOfEarlyData => SSLOption::SSL_SUPPRESS_END_OF_EARLY_DATA, + }; + i as PRInt32 + } + + // Some options are backwards, like SSL_NO_LOCKS, so use this to manage that. + fn map_enabled(self, enabled: bool) -> PRIntn { + let v = match self { + Self::Locking => !enabled, + _ => enabled, + }; + PRIntn::from(v) + } + + pub(crate) fn set(self, fd: *mut PRFileDesc, value: bool) -> Res<()> { + secstatus_to_res(unsafe { SSL_OptionSet(fd, self.as_int(), self.map_enabled(value)) }) + } +} + +experimental_api!(SSL_GetCurrentEpoch( + fd: *mut PRFileDesc, + read_epoch: *mut u16, + write_epoch: *mut u16, +)); +experimental_api!(SSL_HelloRetryRequestCallback( + fd: *mut PRFileDesc, + cb: SSLHelloRetryRequestCallback, + arg: *mut c_void, +)); +experimental_api!(SSL_RecordLayerWriteCallback( + fd: *mut PRFileDesc, + cb: SSLRecordWriteCallback, + arg: *mut c_void, +)); +experimental_api!(SSL_RecordLayerData( + fd: *mut PRFileDesc, + epoch: Epoch, + ct: SSLContentType::Type, + data: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SendSessionTicket( + fd: *mut PRFileDesc, + extra: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SetMaxEarlyDataSize(fd: *mut PRFileDesc, size: u32)); +experimental_api!(SSL_SetResumptionToken( + fd: *mut PRFileDesc, + token: *const u8, + len: c_uint, +)); +experimental_api!(SSL_SetResumptionTokenCallback( + fd: *mut PRFileDesc, + cb: SSLResumptionTokenCallback, + arg: *mut c_void, +)); + +experimental_api!(SSL_GetResumptionTokenInfo( + token: *const u8, + token_len: c_uint, + info: *mut SSLResumptionTokenInfo, + len: c_uint, +)); + +experimental_api!(SSL_DestroyResumptionTokenInfo( + info: *mut SSLResumptionTokenInfo, +)); + +#[cfg(test)] +mod tests { + use super::{SSL_GetNumImplementedCiphers, SSL_NumImplementedCiphers}; + + #[test] + fn num_ciphers() { + assert!(unsafe { SSL_NumImplementedCiphers } > 0); + assert!(unsafe { SSL_GetNumImplementedCiphers() } > 0); + assert_eq!(unsafe { SSL_NumImplementedCiphers }, unsafe { + SSL_GetNumImplementedCiphers() + }); + } +} diff --git a/third_party/rust/neqo-crypto/src/time.rs b/third_party/rust/neqo-crypto/src/time.rs new file mode 100644 index 0000000000..789e2f9e88 --- /dev/null +++ b/third_party/rust/neqo-crypto/src/time.rs @@ -0,0 +1,248 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::agentio::as_c_void; +use crate::err::{Error, Res}; +use crate::once::OnceResult; +use crate::ssl::{PRFileDesc, SSLTimeFunc}; + +use std::boxed::Box; +use std::convert::{TryFrom, TryInto}; +use std::ops::Deref; +use std::os::raw::c_void; +use std::pin::Pin; +use std::time::{Duration, Instant}; + +include!(concat!(env!("OUT_DIR"), "/nspr_time.rs")); + +experimental_api!(SSL_SetTimeFunc( + fd: *mut PRFileDesc, + cb: SSLTimeFunc, + arg: *mut c_void, +)); + +/// This struct holds the zero time used for converting between `Instant` and `PRTime`. +#[derive(Debug)] +struct TimeZero { + instant: Instant, + prtime: PRTime, +} + +impl TimeZero { + /// This function sets a baseline from an instance of `Instant`. + /// This allows for the possibility that code that uses these APIs will create + /// instances of `Instant` before any of this code is run. If `Instant`s older than + /// `BASE_TIME` are used with these conversion functions, they will fail. + /// To avoid that, we make sure that this sets the base time using the first value + /// it sees if it is in the past. If it is not, then use `Instant::now()` instead. + pub fn baseline(t: Instant) -> Self { + let now = Instant::now(); + let prnow = unsafe { PR_Now() }; + + if now <= t { + // `t` is in the future, just use `now`. + Self { + instant: now, + prtime: prnow, + } + } else { + let elapsed = Interval::from(now.duration_since(now)); + // An error from these unwrap functions would require + // ridiculously long application running time. + let prelapsed: PRTime = elapsed.try_into().unwrap(); + Self { + instant: t, + prtime: prnow.checked_sub(prelapsed).unwrap(), + } + } + } +} + +static mut BASE_TIME: OnceResult<TimeZero> = OnceResult::new(); + +fn get_base() -> &'static TimeZero { + let f = || TimeZero { + instant: Instant::now(), + prtime: unsafe { PR_Now() }, + }; + unsafe { BASE_TIME.call_once(f) } +} + +pub(crate) fn init() { + let _ = get_base(); +} + +/// Time wraps Instant and provides conversion functions into `PRTime`. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Time { + t: Instant, +} + +impl Deref for Time { + type Target = Instant; + fn deref(&self) -> &Self::Target { + &self.t + } +} + +impl From<Instant> for Time { + /// Convert from an Instant into a Time. + fn from(t: Instant) -> Self { + // Call `TimeZero::baseline(t)` so that time zero can be set. + let f = || TimeZero::baseline(t); + let _ = unsafe { BASE_TIME.call_once(f) }; + Self { t } + } +} + +impl TryFrom<PRTime> for Time { + type Error = Error; + fn try_from(prtime: PRTime) -> Res<Self> { + let base = get_base(); + if let Some(delta) = prtime.checked_sub(base.prtime) { + let d = Duration::from_micros(delta.try_into()?); + base.instant + .checked_add(d) + .map_or(Err(Error::TimeTravelError), |t| Ok(Self { t })) + } else { + Err(Error::TimeTravelError) + } + } +} + +impl TryInto<PRTime> for Time { + type Error = Error; + fn try_into(self) -> Res<PRTime> { + let base = get_base(); + // TODO(mt) use checked_duration_since when that is available. + let delta = self.t.duration_since(base.instant); + if let Ok(d) = PRTime::try_from(delta.as_micros()) { + d.checked_add(base.prtime) + .map_or(Err(Error::TimeTravelError), Ok) + } else { + Err(Error::TimeTravelError) + } + } +} + +impl Into<Instant> for Time { + fn into(self) -> Instant { + self.t + } +} + +/// Interval wraps Duration and provides conversion functions into `PRTime`. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Interval { + d: Duration, +} + +impl Deref for Interval { + type Target = Duration; + fn deref(&self) -> &Self::Target { + &self.d + } +} + +impl TryFrom<PRTime> for Interval { + type Error = Error; + fn try_from(prtime: PRTime) -> Res<Self> { + Ok(Self { + d: Duration::from_micros(u64::try_from(prtime)?), + }) + } +} + +impl From<Duration> for Interval { + fn from(d: Duration) -> Self { + Self { d } + } +} + +impl TryInto<PRTime> for Interval { + type Error = Error; + fn try_into(self) -> Res<PRTime> { + Ok(PRTime::try_from(self.d.as_micros())?) + } +} + +/// `TimeHolder` maintains a `PRTime` value in a form that is accessible to the TLS stack. +#[derive(Debug)] +pub struct TimeHolder { + t: Pin<Box<PRTime>>, +} + +impl TimeHolder { + unsafe extern "C" fn time_func(arg: *mut c_void) -> PRTime { + let p = arg as *mut PRTime as *const PRTime; + *p.as_ref().unwrap() + } + + pub fn bind(&mut self, fd: *mut PRFileDesc) -> Res<()> { + unsafe { SSL_SetTimeFunc(fd, Some(Self::time_func), as_c_void(&mut self.t)) } + } + + pub fn set(&mut self, t: Instant) -> Res<()> { + *self.t = Time::from(t).try_into()?; + Ok(()) + } +} + +impl Default for TimeHolder { + fn default() -> Self { + TimeHolder { t: Box::pin(0) } + } +} + +#[cfg(test)] +mod test { + use super::{get_base, init, Interval, PRTime, Time}; + use crate::err::Res; + use std::convert::{TryFrom, TryInto}; + use std::time::{Duration, Instant}; + + #[test] + fn convert_stable() { + init(); + let now = Time::from(Instant::now()); + let pr: PRTime = now.try_into().expect("convert to PRTime with truncation"); + let t2 = Time::try_from(pr).expect("convert to Instant"); + let pr2: PRTime = t2.try_into().expect("convert to PRTime again"); + assert_eq!(pr, pr2); + let t3 = Time::try_from(pr2).expect("convert to Instant again"); + assert_eq!(t2, t3); + } + + #[test] + fn past_time() { + init(); + let base = get_base(); + assert!(Time::try_from(base.prtime - 1).is_err()); + } + + #[test] + fn negative_time() { + init(); + assert!(Time::try_from(-1).is_err()); + } + + #[test] + fn negative_interval() { + init(); + assert!(Interval::try_from(-1).is_err()); + } + + #[test] + // We allow replace_consts here because + // std::u64::max_value() isn't available + // in all of our targets + fn overflow_interval() { + init(); + let interval = Interval::from(Duration::from_micros(u64::max_value())); + let res: Res<PRTime> = interval.try_into(); + assert!(res.is_err()); + } +} |