summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-crypto/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/neqo-crypto/src
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/neqo-crypto/src')
-rw-r--r--third_party/rust/neqo-crypto/src/aead.rs161
-rw-r--r--third_party/rust/neqo-crypto/src/aead_fuzzing.rs70
-rw-r--r--third_party/rust/neqo-crypto/src/agent.rs1219
-rw-r--r--third_party/rust/neqo-crypto/src/agentio.rs393
-rw-r--r--third_party/rust/neqo-crypto/src/auth.rs108
-rw-r--r--third_party/rust/neqo-crypto/src/cert.rs118
-rw-r--r--third_party/rust/neqo-crypto/src/constants.rs145
-rw-r--r--third_party/rust/neqo-crypto/src/ech.rs190
-rw-r--r--third_party/rust/neqo-crypto/src/err.rs214
-rw-r--r--third_party/rust/neqo-crypto/src/exp.rs23
-rw-r--r--third_party/rust/neqo-crypto/src/ext.rs165
-rw-r--r--third_party/rust/neqo-crypto/src/hkdf.rs128
-rw-r--r--third_party/rust/neqo-crypto/src/hp.rs187
-rw-r--r--third_party/rust/neqo-crypto/src/lib.rs204
-rw-r--r--third_party/rust/neqo-crypto/src/once.rs44
-rw-r--r--third_party/rust/neqo-crypto/src/p11.rs303
-rw-r--r--third_party/rust/neqo-crypto/src/prio.rs25
-rw-r--r--third_party/rust/neqo-crypto/src/replay.rs78
-rw-r--r--third_party/rust/neqo-crypto/src/result.rs133
-rw-r--r--third_party/rust/neqo-crypto/src/secrets.rs127
-rw-r--r--third_party/rust/neqo-crypto/src/selfencrypt.rs155
-rw-r--r--third_party/rust/neqo-crypto/src/ssl.rs149
-rw-r--r--third_party/rust/neqo-crypto/src/time.rs252
23 files changed, 4591 insertions, 0 deletions
diff --git a/third_party/rust/neqo-crypto/src/aead.rs b/third_party/rust/neqo-crypto/src/aead.rs
new file mode 100644
index 0000000000..74f0ae0ce2
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/aead.rs
@@ -0,0 +1,161 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{Cipher, Version};
+use crate::err::Res;
+use crate::p11::{PK11SymKey, SymKey};
+use crate::ssl;
+use crate::ssl::{PRUint16, PRUint64, PRUint8, SSLAeadContext};
+
+use std::convert::{TryFrom, TryInto};
+use std::fmt;
+use std::ops::{Deref, DerefMut};
+use std::os::raw::{c_char, c_uint};
+use std::ptr::null_mut;
+
+experimental_api!(SSL_MakeAead(
+ version: PRUint16,
+ cipher: PRUint16,
+ secret: *mut PK11SymKey,
+ label_prefix: *const c_char,
+ label_prefix_len: c_uint,
+ ctx: *mut *mut SSLAeadContext,
+));
+experimental_api!(SSL_AeadEncrypt(
+ ctx: *const SSLAeadContext,
+ counter: PRUint64,
+ aad: *const PRUint8,
+ aad_len: c_uint,
+ input: *const PRUint8,
+ input_len: c_uint,
+ output: *const PRUint8,
+ output_len: *mut c_uint,
+ max_output: c_uint
+));
+experimental_api!(SSL_AeadDecrypt(
+ ctx: *const SSLAeadContext,
+ counter: PRUint64,
+ aad: *const PRUint8,
+ aad_len: c_uint,
+ input: *const PRUint8,
+ input_len: c_uint,
+ output: *const PRUint8,
+ output_len: *mut c_uint,
+ max_output: c_uint
+));
+experimental_api!(SSL_DestroyAead(ctx: *mut SSLAeadContext));
+scoped_ptr!(AeadContext, SSLAeadContext, SSL_DestroyAead);
+
+pub struct Aead {
+ ctx: AeadContext,
+}
+
+impl Aead {
+ /// Create a new AEAD based on the indicated TLS version and cipher suite.
+ ///
+ /// # Errors
+ /// Returns `Error` when the supporting NSS functions fail.
+ pub fn new(version: Version, cipher: Cipher, secret: &SymKey, prefix: &str) -> Res<Self> {
+ let s: *mut PK11SymKey = **secret;
+ unsafe { Self::from_raw(version, cipher, s, prefix) }
+ }
+
+ #[must_use]
+ #[allow(clippy::unused_self)]
+ pub fn expansion(&self) -> usize {
+ 16
+ }
+
+ unsafe fn from_raw(
+ version: Version,
+ cipher: Cipher,
+ secret: *mut PK11SymKey,
+ prefix: &str,
+ ) -> Res<Self> {
+ let p = prefix.as_bytes();
+ let mut ctx: *mut ssl::SSLAeadContext = null_mut();
+ SSL_MakeAead(
+ version,
+ cipher,
+ secret,
+ p.as_ptr().cast(),
+ c_uint::try_from(p.len())?,
+ &mut ctx,
+ )?;
+ Ok(Self {
+ ctx: AeadContext::from_ptr(ctx)?,
+ })
+ }
+
+ /// Decrypt a plaintext.
+ ///
+ /// The space provided in `output` needs to be larger than `input` by
+ /// the value provided in `Aead::expansion`.
+ ///
+ /// # Errors
+ /// If the input can't be protected or any input is too large for NSS.
+ pub fn encrypt<'a>(
+ &self,
+ count: u64,
+ aad: &[u8],
+ input: &[u8],
+ output: &'a mut [u8],
+ ) -> Res<&'a [u8]> {
+ let mut l: c_uint = 0;
+ unsafe {
+ SSL_AeadEncrypt(
+ *self.ctx.deref(),
+ count,
+ aad.as_ptr(),
+ c_uint::try_from(aad.len())?,
+ input.as_ptr(),
+ c_uint::try_from(input.len())?,
+ output.as_mut_ptr(),
+ &mut l,
+ c_uint::try_from(output.len())?,
+ )
+ }?;
+ Ok(&output[0..(l.try_into()?)])
+ }
+
+ /// Decrypt a ciphertext.
+ ///
+ /// Note that NSS insists upon having extra space available for decryption, so
+ /// the buffer for `output` should be the same length as `input`, even though
+ /// the final result will be shorter.
+ ///
+ /// # Errors
+ /// If the input isn't authenticated or any input is too large for NSS.
+ pub fn decrypt<'a>(
+ &self,
+ count: u64,
+ aad: &[u8],
+ input: &[u8],
+ output: &'a mut [u8],
+ ) -> Res<&'a [u8]> {
+ let mut l: c_uint = 0;
+ unsafe {
+ SSL_AeadDecrypt(
+ *self.ctx.deref(),
+ count,
+ aad.as_ptr(),
+ c_uint::try_from(aad.len())?,
+ input.as_ptr(),
+ c_uint::try_from(input.len())?,
+ output.as_mut_ptr(),
+ &mut l,
+ c_uint::try_from(output.len())?,
+ )
+ }?;
+ Ok(&output[0..(l.try_into()?)])
+ }
+}
+
+impl fmt::Debug for Aead {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "[AEAD Context]")
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/aead_fuzzing.rs b/third_party/rust/neqo-crypto/src/aead_fuzzing.rs
new file mode 100644
index 0000000000..2c5cc6c56e
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/aead_fuzzing.rs
@@ -0,0 +1,70 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{Cipher, Version};
+use crate::err::{sec::SEC_ERROR_BAD_DATA, Error, Res};
+use crate::p11::SymKey;
+use std::fmt;
+
+pub const FIXED_TAG_FUZZING: &[u8] = &[0x0a; 16];
+pub struct Aead {}
+
+#[allow(clippy::unused_self)]
+impl Aead {
+ pub fn new(_version: Version, _cipher: Cipher, _secret: &SymKey, _prefix: &str) -> Res<Self> {
+ Ok(Self {})
+ }
+
+ #[must_use]
+ pub fn expansion(&self) -> usize {
+ FIXED_TAG_FUZZING.len()
+ }
+
+ pub fn encrypt<'a>(
+ &self,
+ _count: u64,
+ _aad: &[u8],
+ input: &[u8],
+ output: &'a mut [u8],
+ ) -> Res<&'a [u8]> {
+ let l = input.len();
+ output[..l].copy_from_slice(input);
+ output[l..l + 16].copy_from_slice(FIXED_TAG_FUZZING);
+ Ok(&output[..l + 16])
+ }
+
+ pub fn decrypt<'a>(
+ &self,
+ _count: u64,
+ _aad: &[u8],
+ input: &[u8],
+ output: &'a mut [u8],
+ ) -> Res<&'a [u8]> {
+ if input.len() < FIXED_TAG_FUZZING.len() {
+ return Err(Error::from(SEC_ERROR_BAD_DATA));
+ }
+
+ let len_encrypted = input.len() - FIXED_TAG_FUZZING.len();
+ // Check that:
+ // 1) expansion is all zeros and
+ // 2) if the encrypted data is also supplied that at least some values
+ // are no zero (otherwise padding will be interpreted as a valid packet)
+ if &input[len_encrypted..] == FIXED_TAG_FUZZING
+ && (len_encrypted == 0 || input[..len_encrypted].iter().any(|x| *x != 0x0))
+ {
+ output[..len_encrypted].copy_from_slice(&input[..len_encrypted]);
+ Ok(&output[..len_encrypted])
+ } else {
+ Err(Error::from(SEC_ERROR_BAD_DATA))
+ }
+ }
+}
+
+impl fmt::Debug for Aead {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "[FUZZING AEAD]")
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/agent.rs b/third_party/rust/neqo-crypto/src/agent.rs
new file mode 100644
index 0000000000..9163c4c711
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/agent.rs
@@ -0,0 +1,1219 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+pub use crate::agentio::{as_c_void, Record, RecordList};
+use crate::agentio::{AgentIo, METHODS};
+use crate::assert_initialized;
+use crate::auth::AuthenticationStatus;
+pub use crate::cert::CertificateInfo;
+use crate::constants::{
+ Alert, Cipher, Epoch, Extension, Group, SignatureScheme, Version, TLS_VERSION_1_3,
+};
+use crate::ech;
+use crate::err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res};
+use crate::ext::{ExtensionHandler, ExtensionTracker};
+use crate::p11::{self, PrivateKey, PublicKey};
+use crate::prio;
+use crate::replay::AntiReplay;
+use crate::secrets::SecretHolder;
+use crate::ssl::{self, PRBool};
+use crate::time::{Time, TimeHolder};
+
+use neqo_common::{hex_snip_middle, hex_with_len, qdebug, qinfo, qtrace, qwarn};
+use std::cell::RefCell;
+use std::convert::TryFrom;
+use std::ffi::{CStr, CString};
+use std::mem::{self, MaybeUninit};
+use std::ops::{Deref, DerefMut};
+use std::os::raw::{c_uint, c_void};
+use std::pin::Pin;
+use std::ptr::{null, null_mut};
+use std::rc::Rc;
+use std::time::Instant;
+
+/// The maximum number of tickets to remember for a given connection.
+const MAX_TICKETS: usize = 4;
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum HandshakeState {
+ New,
+ InProgress,
+ AuthenticationPending,
+ /// When encrypted client hello is enabled, the server might engage a fallback.
+ /// This is the status that is returned. The included value is the public
+ /// name of the server, which should be used to validated the certificate.
+ EchFallbackAuthenticationPending(String),
+ Authenticated(PRErrorCode),
+ Complete(SecretAgentInfo),
+ Failed(Error),
+}
+
+impl HandshakeState {
+ #[must_use]
+ pub fn is_connected(&self) -> bool {
+ matches!(self, Self::Complete(_))
+ }
+
+ #[must_use]
+ pub fn is_final(&self) -> bool {
+ matches!(self, Self::Complete(_) | Self::Failed(_))
+ }
+
+ #[must_use]
+ pub fn authentication_needed(&self) -> bool {
+ matches!(
+ self,
+ Self::AuthenticationPending | Self::EchFallbackAuthenticationPending(_)
+ )
+ }
+}
+
+fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> {
+ let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT;
+ let mut chosen = vec![0_u8; 255];
+ let mut chosen_len: c_uint = 0;
+ secstatus_to_res(unsafe {
+ ssl::SSL_GetNextProto(
+ fd,
+ &mut alpn_state,
+ chosen.as_mut_ptr(),
+ &mut chosen_len,
+ c_uint::try_from(chosen.len())?,
+ )
+ })?;
+
+ let alpn = match (pre, alpn_state) {
+ (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE)
+ | (
+ false,
+ ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED
+ | ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED,
+ ) => {
+ chosen.truncate(usize::try_from(chosen_len)?);
+ Some(match String::from_utf8(chosen) {
+ Ok(a) => a,
+ Err(_) => return Err(Error::InternalError),
+ })
+ }
+ _ => None,
+ };
+ qtrace!([format!("{:p}", fd)], "got ALPN {:?}", alpn);
+ Ok(alpn)
+}
+
+pub struct SecretAgentPreInfo {
+ info: ssl::SSLPreliminaryChannelInfo,
+ alpn: Option<String>,
+}
+
+macro_rules! preinfo_arg {
+ ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => {
+ #[must_use]
+ pub fn $v(&self) -> Option<$t> {
+ match self.info.valuesSet & ssl::$m {
+ 0 => None,
+ _ => Some($t::try_from(self.info.$f).unwrap()),
+ }
+ }
+ };
+}
+
+impl SecretAgentPreInfo {
+ fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
+ let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit();
+ secstatus_to_res(unsafe {
+ ssl::SSL_GetPreliminaryChannelInfo(
+ fd,
+ info.as_mut_ptr(),
+ c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?,
+ )
+ })?;
+
+ Ok(Self {
+ info: unsafe { info.assume_init() },
+ alpn: get_alpn(fd, true)?,
+ })
+ }
+
+ preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version);
+ preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher);
+ preinfo_arg!(
+ early_data_cipher,
+ ssl_preinfo_0rtt_cipher_suite,
+ zeroRttCipherSuite: Cipher,
+ );
+
+ #[must_use]
+ pub fn early_data(&self) -> bool {
+ self.info.canSendEarlyData != 0
+ }
+
+ /// # Panics
+ /// If `usize` is less than 32 bits and the value is too large.
+ #[must_use]
+ pub fn max_early_data(&self) -> usize {
+ usize::try_from(self.info.maxEarlyDataSize).unwrap()
+ }
+
+ /// Was ECH accepted.
+ #[must_use]
+ pub fn ech_accepted(&self) -> Option<bool> {
+ if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 {
+ None
+ } else {
+ Some(self.info.echAccepted != 0)
+ }
+ }
+
+ /// Get the ECH public name that was used. This will only be available
+ /// (that is, not `None`) if `ech_accepted()` returns `false`.
+ /// In this case, certificate validation needs to use this name rather
+ /// than the original name to validate the certificate. If
+ /// that validation passes (that is, `SecretAgent::authenticated` is called
+ /// with `AuthenticationStatus::Ok`), then the handshake will still fail.
+ /// After the failed handshake, the state will be `Error::EchRetry`,
+ /// which contains a valid ECH configuration.
+ ///
+ /// # Errors
+ /// When the public name is not valid UTF-8. (Note: names should be ASCII.)
+ pub fn ech_public_name(&self) -> Res<Option<&str>> {
+ if self.info.valuesSet & ssl::ssl_preinfo_ech == 0 || self.info.echPublicName.is_null() {
+ Ok(None)
+ } else {
+ let n = unsafe { CStr::from_ptr(self.info.echPublicName) };
+ Ok(Some(n.to_str()?))
+ }
+ }
+
+ #[must_use]
+ pub fn alpn(&self) -> Option<&String> {
+ self.alpn.as_ref()
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+pub struct SecretAgentInfo {
+ version: Version,
+ cipher: Cipher,
+ group: Group,
+ resumed: bool,
+ early_data: bool,
+ ech_accepted: bool,
+ alpn: Option<String>,
+ signature_scheme: SignatureScheme,
+}
+
+impl SecretAgentInfo {
+ fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
+ let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit();
+ secstatus_to_res(unsafe {
+ ssl::SSL_GetChannelInfo(
+ fd,
+ info.as_mut_ptr(),
+ c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?,
+ )
+ })?;
+ let info = unsafe { info.assume_init() };
+ Ok(Self {
+ version: info.protocolVersion,
+ cipher: info.cipherSuite,
+ group: Group::try_from(info.keaGroup)?,
+ resumed: info.resumed != 0,
+ early_data: info.earlyDataAccepted != 0,
+ ech_accepted: info.echAccepted != 0,
+ alpn: get_alpn(fd, false)?,
+ signature_scheme: SignatureScheme::try_from(info.signatureScheme)?,
+ })
+ }
+ #[must_use]
+ pub fn version(&self) -> Version {
+ self.version
+ }
+ #[must_use]
+ pub fn cipher_suite(&self) -> Cipher {
+ self.cipher
+ }
+ #[must_use]
+ pub fn key_exchange(&self) -> Group {
+ self.group
+ }
+ #[must_use]
+ pub fn resumed(&self) -> bool {
+ self.resumed
+ }
+ #[must_use]
+ pub fn early_data_accepted(&self) -> bool {
+ self.early_data
+ }
+ #[must_use]
+ pub fn ech_accepted(&self) -> bool {
+ self.ech_accepted
+ }
+ #[must_use]
+ pub fn alpn(&self) -> Option<&String> {
+ self.alpn.as_ref()
+ }
+ #[must_use]
+ pub fn signature_scheme(&self) -> SignatureScheme {
+ self.signature_scheme
+ }
+}
+
+/// `SecretAgent` holds the common parts of client and server.
+#[derive(Debug)]
+#[allow(clippy::module_name_repetitions)]
+pub struct SecretAgent {
+ fd: *mut ssl::PRFileDesc,
+ secrets: SecretHolder,
+ raw: Option<bool>,
+ io: Pin<Box<AgentIo>>,
+ state: HandshakeState,
+
+ /// Records whether authentication of certificates is required.
+ auth_required: Pin<Box<bool>>,
+ /// Records any fatal alert that is sent by the stack.
+ alert: Pin<Box<Option<Alert>>>,
+ /// The current time.
+ now: TimeHolder,
+
+ extension_handlers: Vec<ExtensionTracker>,
+
+ /// The encrypted client hello (ECH) configuration that is in use.
+ /// Empty if ECH is not enabled.
+ ech_config: Vec<u8>,
+}
+
+impl SecretAgent {
+ fn new() -> Res<Self> {
+ let mut io = Box::pin(AgentIo::new());
+ let fd = Self::create_fd(&mut io)?;
+ Ok(Self {
+ fd,
+ secrets: SecretHolder::default(),
+ raw: None,
+ io,
+ state: HandshakeState::New,
+
+ auth_required: Box::pin(false),
+ alert: Box::pin(None),
+ now: TimeHolder::default(),
+
+ extension_handlers: Vec::new(),
+
+ ech_config: Vec::new(),
+ })
+ }
+
+ // Create a new SSL file descriptor.
+ //
+ // Note that we create separate bindings for PRFileDesc as both
+ // ssl::PRFileDesc and prio::PRFileDesc. This keeps the bindings
+ // minimal, but it means that the two forms need casts to translate
+ // between them. ssl::PRFileDesc is left as an opaque type, as the
+ // ssl::SSL_* APIs only need an opaque type.
+ fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> {
+ assert_initialized();
+ let label = CString::new("sslwrapper")?;
+ let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) };
+
+ let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) };
+ if base_fd.is_null() {
+ return Err(Error::CreateSslSocket);
+ }
+ let fd = unsafe {
+ (*base_fd).secret = as_c_void(io).cast();
+ ssl::SSL_ImportFD(null_mut(), base_fd.cast())
+ };
+ if fd.is_null() {
+ unsafe { prio::PR_Close(base_fd) };
+ return Err(Error::CreateSslSocket);
+ }
+ Ok(fd)
+ }
+
+ unsafe extern "C" fn auth_complete_hook(
+ arg: *mut c_void,
+ _fd: *mut ssl::PRFileDesc,
+ _check_sig: ssl::PRBool,
+ _is_server: ssl::PRBool,
+ ) -> ssl::SECStatus {
+ let auth_required_ptr = arg.cast::<bool>();
+ *auth_required_ptr = true;
+ // NSS insists on getting SECWouldBlock here rather than accepting
+ // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure.
+ ssl::_SECStatus_SECWouldBlock
+ }
+
+ unsafe extern "C" fn alert_sent_cb(
+ fd: *const ssl::PRFileDesc,
+ arg: *mut c_void,
+ alert: *const ssl::SSLAlert,
+ ) {
+ let alert = alert.as_ref().unwrap();
+ if alert.level == 2 {
+ // Fatal alerts demand attention.
+ let st = arg.cast::<Option<Alert>>().as_mut().unwrap();
+ if st.is_none() {
+ *st = Some(alert.description);
+ } else {
+ qwarn!(
+ [format!("{:p}", fd)],
+ "duplicate alert {}",
+ alert.description
+ );
+ }
+ }
+ }
+
+ // Ready this for connecting.
+ fn ready(&mut self, is_server: bool) -> Res<()> {
+ secstatus_to_res(unsafe {
+ ssl::SSL_AuthCertificateHook(
+ self.fd,
+ Some(Self::auth_complete_hook),
+ as_c_void(&mut self.auth_required),
+ )
+ })?;
+
+ secstatus_to_res(unsafe {
+ ssl::SSL_AlertSentCallback(
+ self.fd,
+ Some(Self::alert_sent_cb),
+ as_c_void(&mut self.alert),
+ )
+ })?;
+
+ self.now.bind(self.fd)?;
+ self.configure()?;
+ secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, ssl::PRBool::from(is_server)) })
+ }
+
+ /// Default configuration.
+ ///
+ /// # Errors
+ /// If `set_version_range` fails.
+ fn configure(&mut self) -> Res<()> {
+ self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
+ self.set_option(ssl::Opt::Locking, false)?;
+ self.set_option(ssl::Opt::Tickets, false)?;
+ self.set_option(ssl::Opt::OcspStapling, true)?;
+ Ok(())
+ }
+
+ /// Set the versions that are supported.
+ ///
+ /// # Errors
+ /// If the range of versions isn't supported.
+ pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> {
+ let range = ssl::SSLVersionRange { min, max };
+ secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) })
+ }
+
+ /// Enable a set of ciphers. Note that the order of these is not respected.
+ ///
+ /// # Errors
+ /// If NSS can't enable or disable ciphers.
+ pub fn set_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> {
+ if self.state != HandshakeState::New {
+ qwarn!([self], "Cannot enable ciphers in state {:?}", self.state);
+ return Err(Error::InternalError);
+ }
+
+ let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() };
+ let cipher_count = usize::from(unsafe { ssl::SSL_GetNumImplementedCiphers() });
+ for i in 0..cipher_count {
+ let p = all_ciphers.wrapping_add(i);
+ secstatus_to_res(unsafe {
+ ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), ssl::PRBool::from(false))
+ })?;
+ }
+
+ for c in ciphers {
+ secstatus_to_res(unsafe {
+ ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), ssl::PRBool::from(true))
+ })?;
+ }
+ Ok(())
+ }
+
+ /// Set key exchange groups.
+ ///
+ /// # Errors
+ /// If the underlying API fails (which shouldn't happen).
+ pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> {
+ // SSLNamedGroup is a different size to Group, so copy one by one.
+ let group_vec: Vec<_> = groups
+ .iter()
+ .map(|&g| ssl::SSLNamedGroup::Type::from(g))
+ .collect();
+
+ let ptr = group_vec.as_slice().as_ptr();
+ secstatus_to_res(unsafe {
+ ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?)
+ })
+ }
+
+ /// Set TLS options.
+ ///
+ /// # Errors
+ /// Returns an error if the option or option value is invalid; i.e., never.
+ pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> {
+ opt.set(self.fd, value)
+ }
+
+ /// Enable 0-RTT.
+ ///
+ /// # Errors
+ /// See `set_option`.
+ pub fn enable_0rtt(&mut self) -> Res<()> {
+ self.set_option(ssl::Opt::EarlyData, true)
+ }
+
+ /// Disable the `EndOfEarlyData` message.
+ ///
+ /// # Errors
+ /// See `set_option`.
+ pub fn disable_end_of_early_data(&mut self) -> Res<()> {
+ self.set_option(ssl::Opt::SuppressEndOfEarlyData, true)
+ }
+
+ /// `set_alpn` sets a list of preferred protocols, starting with the most preferred.
+ /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded
+ /// strings.
+ ///
+ /// This asserts if no items are provided, or if any individual item is longer than
+ /// 255 octets in length.
+ ///
+ /// # Errors
+ /// This should always panic rather than return an error.
+ /// # Panics
+ /// If any of the provided `protocols` are more than 255 bytes long.
+ ///
+ /// [RFC7301]: https://datatracker.ietf.org/doc/html/rfc7301
+ pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> {
+ // Validate and set length.
+ let mut encoded_len = protocols.len();
+ for v in protocols {
+ assert!(v.as_ref().len() < 256);
+ assert!(!v.as_ref().is_empty());
+ encoded_len += v.as_ref().len();
+ }
+
+ // Prepare to encode.
+ let mut encoded = Vec::with_capacity(encoded_len);
+ let mut add = |v: &str| {
+ if let Ok(s) = u8::try_from(v.len()) {
+ encoded.push(s);
+ encoded.extend_from_slice(v.as_bytes());
+ }
+ };
+
+ // NSS inherited an idiosyncratic API as a result of having implemented NPN
+ // before ALPN. For that reason, we need to put the "best" option last.
+ let (first, rest) = protocols
+ .split_first()
+ .expect("at least one ALPN value needed");
+ for v in rest {
+ add(v.as_ref());
+ }
+ add(first.as_ref());
+ assert_eq!(encoded_len, encoded.len());
+
+ // Now give the result to NSS.
+ secstatus_to_res(unsafe {
+ ssl::SSL_SetNextProtoNego(
+ self.fd,
+ encoded.as_slice().as_ptr(),
+ c_uint::try_from(encoded.len())?,
+ )
+ })
+ }
+
+ /// Install an extension handler.
+ ///
+ /// This can be called multiple times with different values for `ext`. The handler is provided as
+ /// Rc<RefCell<>> so that the caller is able to hold a reference to the handler and later access any
+ /// state that it accumulates.
+ ///
+ /// # Errors
+ /// When the extension handler can't be successfully installed.
+ pub fn extension_handler(
+ &mut self,
+ ext: Extension,
+ handler: Rc<RefCell<dyn ExtensionHandler>>,
+ ) -> Res<()> {
+ let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?;
+ self.extension_handlers.push(tracker);
+ Ok(())
+ }
+
+ // This function tracks whether handshake() or handshake_raw() was used
+ // and prevents the other from being used.
+ fn set_raw(&mut self, r: bool) -> Res<()> {
+ if self.raw.is_none() {
+ self.secrets.register(self.fd)?;
+ self.raw = Some(r);
+ Ok(())
+ } else if self.raw.unwrap() == r {
+ Ok(())
+ } else {
+ Err(Error::MixedHandshakeMethod)
+ }
+ }
+
+ /// Get information about the connection.
+ /// This includes the version, ciphersuite, and ALPN.
+ ///
+ /// Calling this function returns None until the connection is complete.
+ #[must_use]
+ pub fn info(&self) -> Option<&SecretAgentInfo> {
+ match self.state {
+ HandshakeState::Complete(ref info) => Some(info),
+ _ => None,
+ }
+ }
+
+ /// Get any preliminary information about the status of the connection.
+ ///
+ /// This includes whether 0-RTT was accepted and any information related to that.
+ /// Calling this function collects all the relevant information.
+ ///
+ /// # Errors
+ /// When the underlying socket functions fail.
+ pub fn preinfo(&self) -> Res<SecretAgentPreInfo> {
+ SecretAgentPreInfo::new(self.fd)
+ }
+
+ /// Get the peer's certificate chain.
+ #[must_use]
+ pub fn peer_certificate(&self) -> Option<CertificateInfo> {
+ CertificateInfo::new(self.fd)
+ }
+
+ /// Return any fatal alert that the TLS stack might have sent.
+ #[must_use]
+ pub fn alert(&self) -> Option<&Alert> {
+ (*self.alert).as_ref()
+ }
+
+ /// Call this function to mark the peer as authenticated.
+ /// # Panics
+ /// If the handshake doesn't need to be authenticated.
+ pub fn authenticated(&mut self, status: AuthenticationStatus) {
+ assert!(self.state.authentication_needed());
+ *self.auth_required = false;
+ self.state = HandshakeState::Authenticated(status.into());
+ }
+
+ fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> {
+ if let Err(e) = res {
+ let e = ech::convert_ech_error(self.fd, e);
+ qwarn!([self], "error: {:?}", e);
+ self.state = HandshakeState::Failed(e.clone());
+ Err(e)
+ } else {
+ res
+ }
+ }
+
+ fn update_state(&mut self, res: Res<()>) -> Res<()> {
+ self.state = if is_blocked(&res) {
+ if *self.auth_required {
+ self.preinfo()?.ech_public_name()?.map_or(
+ HandshakeState::AuthenticationPending,
+ |public_name| {
+ HandshakeState::EchFallbackAuthenticationPending(public_name.to_owned())
+ },
+ )
+ } else {
+ HandshakeState::InProgress
+ }
+ } else {
+ self.capture_error(res)?;
+ let info = self.capture_error(SecretAgentInfo::new(self.fd))?;
+ HandshakeState::Complete(info)
+ };
+ qinfo!([self], "state -> {:?}", self.state);
+ Ok(())
+ }
+
+ /// Drive the TLS handshake, taking bytes from `input` and putting
+ /// any bytes necessary into `output`.
+ /// This takes the current time as `now`.
+ /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake
+ /// is complete and how many bytes were written to `output`, respectively.
+ /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this
+ /// function if you want to proceed, because this will mark the certificate as OK.
+ ///
+ /// # Errors
+ /// When the handshake fails this returns an error.
+ pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> {
+ self.now.set(now)?;
+ self.set_raw(false)?;
+
+ let rv = {
+ // Within this scope, _h maintains a mutable reference to self.io.
+ let _h = self.io.wrap(input);
+ match self.state {
+ HandshakeState::Authenticated(ref err) => unsafe {
+ ssl::SSL_AuthCertificateComplete(self.fd, *err)
+ },
+ _ => unsafe { ssl::SSL_ForceHandshake(self.fd) },
+ }
+ };
+ // Take before updating state so that we leave the output buffer empty
+ // even if there is an error.
+ let output = self.io.take_output();
+ self.update_state(secstatus_to_res(rv))?;
+ Ok(output)
+ }
+
+ /// Setup to receive records for raw handshake functions.
+ fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> {
+ self.set_raw(true)?;
+ self.capture_error(RecordList::setup(self.fd))
+ }
+
+ /// Drive the TLS handshake, but get the raw content of records, not
+ /// protected records as bytes. This function is incompatible with
+ /// `handshake()`; use either this or `handshake()` exclusively.
+ ///
+ /// Ideally, this only includes records from the current epoch.
+ /// If you send data from multiple epochs, you might end up being sad.
+ ///
+ /// # Errors
+ /// When the handshake fails this returns an error.
+ pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> {
+ self.now.set(now)?;
+ let records = self.setup_raw()?;
+
+ // Fire off any authentication we might need to complete.
+ if let HandshakeState::Authenticated(ref err) = self.state {
+ let result =
+ secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) });
+ qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result);
+ // This should return SECSuccess, so don't use update_state().
+ self.capture_error(result)?;
+ }
+
+ // Feed in any records.
+ if let Some(rec) = input {
+ self.capture_error(rec.write(self.fd))?;
+ }
+
+ // Drive the handshake once more.
+ let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) });
+ self.update_state(rv)?;
+
+ Ok(*Pin::into_inner(records))
+ }
+
+ #[allow(unknown_lints, clippy::branches_sharing_code)]
+ pub fn close(&mut self) {
+ // It should be safe to close multiple times.
+ if self.fd.is_null() {
+ return;
+ }
+ if let Some(true) = self.raw {
+ // Need to hold the record list in scope until the close is done.
+ let _records = self.setup_raw().expect("Can only close");
+ unsafe { prio::PR_Close(self.fd.cast()) };
+ } else {
+ // Need to hold the IO wrapper in scope until the close is done.
+ let _io = self.io.wrap(&[]);
+ unsafe { prio::PR_Close(self.fd.cast()) };
+ };
+ let _output = self.io.take_output();
+ self.fd = null_mut();
+ }
+
+ /// State returns the status of the handshake.
+ #[must_use]
+ pub fn state(&self) -> &HandshakeState {
+ &self.state
+ }
+
+ /// Take a read secret. This will only return a non-`None` value once.
+ #[must_use]
+ pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
+ self.secrets.take_read(epoch)
+ }
+
+ /// Take a write secret.
+ #[must_use]
+ pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
+ self.secrets.take_write(epoch)
+ }
+
+ /// Get the active ECH configuration, which is empty if ECH is disabled.
+ #[must_use]
+ pub fn ech_config(&self) -> &[u8] {
+ &self.ech_config
+ }
+}
+
+impl Drop for SecretAgent {
+ fn drop(&mut self) {
+ self.close();
+ }
+}
+
+impl ::std::fmt::Display for SecretAgent {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Agent {:p}", self.fd)
+ }
+}
+
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone)]
+pub struct ResumptionToken {
+ token: Vec<u8>,
+ expiration_time: Instant,
+}
+
+impl AsRef<[u8]> for ResumptionToken {
+ fn as_ref(&self) -> &[u8] {
+ &self.token
+ }
+}
+
+impl ResumptionToken {
+ #[must_use]
+ pub fn new(token: Vec<u8>, expiration_time: Instant) -> Self {
+ Self {
+ token,
+ expiration_time,
+ }
+ }
+
+ #[must_use]
+ pub fn expiration_time(&self) -> Instant {
+ self.expiration_time
+ }
+}
+
+/// A TLS Client.
+#[derive(Debug)]
+#[allow(
+ renamed_and_removed_lints,
+ clippy::box_vec,
+ unknown_lints,
+ clippy::box_collection
+)] // We need the Box.
+pub struct Client {
+ agent: SecretAgent,
+
+ /// The name of the server we're attempting a connection to.
+ server_name: String,
+ /// Records the resumption tokens we've received.
+ resumption: Pin<Box<Vec<ResumptionToken>>>,
+}
+
+impl Client {
+ /// Create a new client agent.
+ ///
+ /// # Errors
+ /// Errors returned if the socket can't be created or configured.
+ pub fn new(server_name: impl Into<String>) -> Res<Self> {
+ let server_name = server_name.into();
+ let mut agent = SecretAgent::new()?;
+ let url = CString::new(server_name.as_bytes())?;
+ secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?;
+ agent.ready(false)?;
+ let mut client = Self {
+ agent,
+ server_name,
+ resumption: Box::pin(Vec::new()),
+ };
+ client.ready()?;
+ Ok(client)
+ }
+
+ unsafe extern "C" fn resumption_token_cb(
+ fd: *mut ssl::PRFileDesc,
+ token: *const u8,
+ len: c_uint,
+ arg: *mut c_void,
+ ) -> ssl::SECStatus {
+ let mut info: MaybeUninit<ssl::SSLResumptionTokenInfo> = MaybeUninit::uninit();
+ if ssl::SSL_GetResumptionTokenInfo(
+ token,
+ len,
+ info.as_mut_ptr(),
+ c_uint::try_from(mem::size_of::<ssl::SSLResumptionTokenInfo>()).unwrap(),
+ )
+ .is_err()
+ {
+ // Ignore the token.
+ return ssl::SECSuccess;
+ }
+ let expiration_time = info.assume_init().expirationTime;
+ if ssl::SSL_DestroyResumptionTokenInfo(info.as_mut_ptr()).is_err() {
+ // Ignore the token.
+ return ssl::SECSuccess;
+ }
+ let resumption = arg.cast::<Vec<ResumptionToken>>().as_mut().unwrap();
+ let len = usize::try_from(len).unwrap();
+ let mut v = Vec::with_capacity(len);
+ v.extend_from_slice(std::slice::from_raw_parts(token, len));
+ qinfo!(
+ [format!("{:p}", fd)],
+ "Got resumption token {}",
+ hex_snip_middle(&v)
+ );
+
+ if resumption.len() >= MAX_TICKETS {
+ resumption.remove(0);
+ }
+ if let Ok(t) = Time::try_from(expiration_time) {
+ resumption.push(ResumptionToken::new(v, *t));
+ }
+ ssl::SECSuccess
+ }
+
+ #[must_use]
+ pub fn server_name(&self) -> &str {
+ &self.server_name
+ }
+
+ fn ready(&mut self) -> Res<()> {
+ let fd = self.fd;
+ unsafe {
+ ssl::SSL_SetResumptionTokenCallback(
+ fd,
+ Some(Self::resumption_token_cb),
+ as_c_void(&mut self.resumption),
+ )
+ }
+ }
+
+ /// Take a resumption token.
+ #[must_use]
+ pub fn resumption_token(&mut self) -> Option<ResumptionToken> {
+ (*self.resumption).pop()
+ }
+
+ /// Check if there are more resumption tokens.
+ #[must_use]
+ pub fn has_resumption_token(&self) -> bool {
+ !(*self.resumption).is_empty()
+ }
+
+ /// Enable resumption, using a token previously provided.
+ ///
+ /// # Errors
+ /// Error returned when the resumption token is invalid or
+ /// the socket is not able to use the value.
+ pub fn enable_resumption(&mut self, token: impl AsRef<[u8]>) -> Res<()> {
+ unsafe {
+ ssl::SSL_SetResumptionToken(
+ self.agent.fd,
+ token.as_ref().as_ptr(),
+ c_uint::try_from(token.as_ref().len())?,
+ )
+ }
+ }
+
+ /// Enable encrypted client hello (ECH), using the encoded `ECHConfigList`.
+ ///
+ /// When ECH is enabled, a client needs to look for `Error::EchRetry` as a
+ /// failure code. If `Error::EchRetry` is received when connecting, the
+ /// connection attempt should be retried and the included value provided
+ /// to this function (instead of what is received from DNS).
+ ///
+ /// Calling this function with an empty value for `ech_config_list` enables
+ /// ECH greasing. When that is done, there is no need to look for `EchRetry`
+ ///
+ /// # Errors
+ /// Error returned when the configuration is invalid.
+ pub fn enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> {
+ let config = ech_config_list.as_ref();
+ qdebug!([self], "Enable ECH for a server: {}", hex_with_len(config));
+ self.ech_config = Vec::from(config);
+ if config.is_empty() {
+ unsafe { ech::SSL_EnableTls13GreaseEch(self.agent.fd, PRBool::from(true)) }
+ } else {
+ unsafe {
+ ech::SSL_SetClientEchConfigs(
+ self.agent.fd,
+ config.as_ptr(),
+ c_uint::try_from(config.len())?,
+ )
+ }
+ }
+ }
+}
+
+impl Deref for Client {
+ type Target = SecretAgent;
+ #[must_use]
+ fn deref(&self) -> &SecretAgent {
+ &self.agent
+ }
+}
+
+impl DerefMut for Client {
+ fn deref_mut(&mut self) -> &mut SecretAgent {
+ &mut self.agent
+ }
+}
+
+impl ::std::fmt::Display for Client {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Client {:p}", self.agent.fd)
+ }
+}
+
+/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ZeroRttCheckResult {
+ /// Accept 0-RTT.
+ Accept,
+ /// Reject 0-RTT, but continue the handshake normally.
+ Reject,
+ /// Send HelloRetryRequest (probably not needed for QUIC).
+ HelloRetryRequest(Vec<u8>),
+ /// Fail the handshake.
+ Fail,
+}
+
+/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by `send_ticket`)
+pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin {
+ fn check(&self, token: &[u8]) -> ZeroRttCheckResult;
+}
+
+/// Using `AllowZeroRtt` for the implementation of `ZeroRttChecker` means
+/// accepting 0-RTT always. This generally isn't a great idea, so this
+/// generates a strong warning when it is used.
+#[derive(Debug)]
+pub struct AllowZeroRtt {}
+impl ZeroRttChecker for AllowZeroRtt {
+ fn check(&self, _token: &[u8]) -> ZeroRttCheckResult {
+ qwarn!("AllowZeroRtt accepting 0-RTT");
+ ZeroRttCheckResult::Accept
+ }
+}
+
+#[derive(Debug)]
+struct ZeroRttCheckState {
+ checker: Pin<Box<dyn ZeroRttChecker>>,
+}
+
+impl ZeroRttCheckState {
+ pub fn new(checker: Box<dyn ZeroRttChecker>) -> Self {
+ Self {
+ checker: Pin::new(checker),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Server {
+ agent: SecretAgent,
+ /// This holds the HRR callback context.
+ zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>,
+}
+
+impl Server {
+ /// Create a new server agent.
+ ///
+ /// # Errors
+ /// Errors returned when NSS fails.
+ pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> {
+ let mut agent = SecretAgent::new()?;
+
+ for n in certificates {
+ let c = CString::new(n.as_ref())?;
+ let cert_ptr = unsafe { p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut()) };
+ let cert = if let Ok(c) = p11::Certificate::from_ptr(cert_ptr) {
+ c
+ } else {
+ return Err(Error::CertificateLoading);
+ };
+ let key_ptr = unsafe { p11::PK11_FindKeyByAnyCert(*cert.deref(), null_mut()) };
+ let key = if let Ok(k) = p11::PrivateKey::from_ptr(key_ptr) {
+ k
+ } else {
+ return Err(Error::CertificateLoading);
+ };
+ secstatus_to_res(unsafe {
+ ssl::SSL_ConfigServerCert(agent.fd, *cert.deref(), *key.deref(), null(), 0)
+ })?;
+ }
+
+ agent.ready(true)?;
+ Ok(Self {
+ agent,
+ zero_rtt_check: None,
+ })
+ }
+
+ unsafe extern "C" fn hello_retry_cb(
+ first_hello: PRBool,
+ client_token: *const u8,
+ client_token_len: c_uint,
+ retry_token: *mut u8,
+ retry_token_len: *mut c_uint,
+ retry_token_max: c_uint,
+ arg: *mut c_void,
+ ) -> ssl::SSLHelloRetryRequestAction::Type {
+ if first_hello == 0 {
+ // On the second ClientHello after HelloRetryRequest, skip checks.
+ return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept;
+ }
+
+ let check_state = arg.cast::<ZeroRttCheckState>().as_mut().unwrap();
+ let token = if client_token.is_null() {
+ &[]
+ } else {
+ std::slice::from_raw_parts(client_token, usize::try_from(client_token_len).unwrap())
+ };
+ match check_state.checker.check(token) {
+ ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept,
+ ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail,
+ ZeroRttCheckResult::Reject => {
+ ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt
+ }
+ ZeroRttCheckResult::HelloRetryRequest(tok) => {
+ // Don't bother propagating errors from this, because it should be caught in testing.
+ assert!(tok.len() <= usize::try_from(retry_token_max).unwrap());
+ let slc = std::slice::from_raw_parts_mut(retry_token, tok.len());
+ slc.copy_from_slice(&tok);
+ *retry_token_len = c_uint::try_from(tok.len()).unwrap();
+ ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request
+ }
+ }
+ }
+
+ /// Enable 0-RTT. This shadows the function of the same name that can be accessed
+ /// via the Deref implementation on Server.
+ ///
+ /// # Errors
+ /// Returns an error if the underlying NSS functions fail.
+ pub fn enable_0rtt(
+ &mut self,
+ anti_replay: &AntiReplay,
+ max_early_data: u32,
+ checker: Box<dyn ZeroRttChecker>,
+ ) -> Res<()> {
+ let mut check_state = Box::pin(ZeroRttCheckState::new(checker));
+ unsafe {
+ ssl::SSL_HelloRetryRequestCallback(
+ self.agent.fd,
+ Some(Self::hello_retry_cb),
+ as_c_void(&mut check_state),
+ )
+ }?;
+ unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?;
+ self.zero_rtt_check = Some(check_state);
+ self.agent.enable_0rtt()?;
+ anti_replay.config_socket(self.fd)?;
+ Ok(())
+ }
+
+ /// Send a session ticket to the client.
+ /// This adds |extra| application-specific content into that ticket.
+ /// The records that are sent are captured and returned.
+ ///
+ /// # Errors
+ /// If NSS is unable to send a ticket, or if this agent is incorrectly configured.
+ pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> {
+ self.agent.now.set(now)?;
+ let records = self.setup_raw()?;
+
+ unsafe {
+ ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?)
+ }?;
+
+ Ok(*Pin::into_inner(records))
+ }
+
+ /// Enable encrypted client hello (ECH).
+ ///
+ /// # Errors
+ /// Fails when NSS cannot create a key pair.
+ pub fn enable_ech(
+ &mut self,
+ config: u8,
+ public_name: &str,
+ sk: &PrivateKey,
+ pk: &PublicKey,
+ ) -> Res<()> {
+ let cfg = ech::encode_config(config, public_name, pk)?;
+ qdebug!([self], "Enable ECH for a server: {}", hex_with_len(&cfg));
+ unsafe {
+ ech::SSL_SetServerEchConfigs(
+ self.agent.fd,
+ **pk,
+ **sk,
+ cfg.as_ptr(),
+ c_uint::try_from(cfg.len())?,
+ )?;
+ };
+ self.ech_config = cfg;
+ Ok(())
+ }
+}
+
+impl Deref for Server {
+ type Target = SecretAgent;
+ #[must_use]
+ fn deref(&self) -> &SecretAgent {
+ &self.agent
+ }
+}
+
+impl DerefMut for Server {
+ fn deref_mut(&mut self) -> &mut SecretAgent {
+ &mut self.agent
+ }
+}
+
+impl ::std::fmt::Display for Server {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "Server {:p}", self.agent.fd)
+ }
+}
+
+/// A generic container for Client or Server.
+#[derive(Debug)]
+pub enum Agent {
+ Client(crate::agent::Client),
+ Server(crate::agent::Server),
+}
+
+impl Deref for Agent {
+ type Target = SecretAgent;
+ #[must_use]
+ fn deref(&self) -> &SecretAgent {
+ match self {
+ Self::Client(c) => c,
+ Self::Server(s) => s,
+ }
+ }
+}
+
+impl DerefMut for Agent {
+ fn deref_mut(&mut self) -> &mut SecretAgent {
+ match self {
+ Self::Client(c) => c,
+ Self::Server(s) => s,
+ }
+ }
+}
+
+impl From<Client> for Agent {
+ #[must_use]
+ fn from(c: Client) -> Self {
+ Self::Client(c)
+ }
+}
+
+impl From<Server> for Agent {
+ #[must_use]
+ fn from(s: Server) -> Self {
+ Self::Server(s)
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/agentio.rs b/third_party/rust/neqo-crypto/src/agentio.rs
new file mode 100644
index 0000000000..1d39b2398a
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/agentio.rs
@@ -0,0 +1,393 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{ContentType, Epoch};
+use crate::err::{nspr, Error, PR_SetError, Res};
+use crate::prio;
+use crate::ssl;
+
+use neqo_common::{hex, hex_with_len, qtrace};
+use std::cmp::min;
+use std::convert::{TryFrom, TryInto};
+use std::fmt;
+use std::mem;
+use std::ops::Deref;
+use std::os::raw::{c_uint, c_void};
+use std::pin::Pin;
+use std::ptr::{null, null_mut};
+use std::vec::Vec;
+
+// Alias common types.
+type PrFd = *mut prio::PRFileDesc;
+type PrStatus = prio::PRStatus::Type;
+const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
+const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;
+
+/// Convert a pinned, boxed object into a void pointer.
+pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
+ (Pin::into_inner(pin.as_mut()) as *mut T).cast()
+}
+
+/// A slice of the output.
+#[derive(Default)]
+pub struct Record {
+ pub epoch: Epoch,
+ pub ct: ContentType,
+ pub data: Vec<u8>,
+}
+
+impl Record {
+ #[must_use]
+ pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
+ Self {
+ epoch,
+ ct,
+ data: data.to_vec(),
+ }
+ }
+
+ // Shoves this record into the socket, returns true if blocked.
+ pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> {
+ qtrace!("write {:?}", self);
+ unsafe {
+ ssl::SSL_RecordLayerData(
+ fd,
+ self.epoch,
+ ssl::SSLContentType::Type::from(self.ct),
+ self.data.as_ptr(),
+ c_uint::try_from(self.data.len())?,
+ )
+ }
+ }
+}
+
+impl fmt::Debug for Record {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Record {:?}:{:?} {}",
+ self.epoch,
+ self.ct,
+ hex_with_len(&self.data[..])
+ )
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct RecordList {
+ records: Vec<Record>,
+}
+
+impl RecordList {
+ fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
+ self.records.push(Record::new(epoch, ct, data));
+ }
+
+ #[allow(clippy::unused_self)]
+ unsafe extern "C" fn ingest(
+ _fd: *mut ssl::PRFileDesc,
+ epoch: ssl::PRUint16,
+ ct: ssl::SSLContentType::Type,
+ data: *const ssl::PRUint8,
+ len: c_uint,
+ arg: *mut c_void,
+ ) -> ssl::SECStatus {
+ let records = arg.cast::<Self>().as_mut().unwrap();
+
+ let slice = std::slice::from_raw_parts(data, len as usize);
+ records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
+ ssl::SECSuccess
+ }
+
+ /// Create a new record list.
+ pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> {
+ let mut records = Box::pin(Self::default());
+ unsafe {
+ ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
+ }?;
+ Ok(records)
+ }
+}
+
+impl Deref for RecordList {
+ type Target = Vec<Record>;
+ #[must_use]
+ fn deref(&self) -> &Vec<Record> {
+ &self.records
+ }
+}
+
+pub struct RecordListIter(std::vec::IntoIter<Record>);
+
+impl Iterator for RecordListIter {
+ type Item = Record;
+ fn next(&mut self) -> Option<Self::Item> {
+ self.0.next()
+ }
+}
+
+impl IntoIterator for RecordList {
+ type Item = Record;
+ type IntoIter = RecordListIter;
+ #[must_use]
+ fn into_iter(self) -> Self::IntoIter {
+ RecordListIter(self.records.into_iter())
+ }
+}
+
+pub struct AgentIoInputContext<'a> {
+ input: &'a mut AgentIoInput,
+}
+
+impl<'a> Drop for AgentIoInputContext<'a> {
+ fn drop(&mut self) {
+ self.input.reset();
+ }
+}
+
+#[derive(Debug)]
+struct AgentIoInput {
+ // input is data that is read by TLS.
+ input: *const u8,
+ // input_available is how much data is left for reading.
+ available: usize,
+}
+
+impl AgentIoInput {
+ fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
+ assert!(self.input.is_null());
+ self.input = input.as_ptr();
+ self.available = input.len();
+ qtrace!("AgentIoInput wrap {:p}", self.input);
+ AgentIoInputContext { input: self }
+ }
+
+ // Take the data provided as input and provide it to the TLS stack.
+ fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
+ let amount = min(self.available, count);
+ if amount == 0 {
+ unsafe {
+ PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
+ }
+ return Err(Error::NoDataAvailable);
+ }
+
+ let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
+ qtrace!([self], "read {}", hex(src));
+ let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
+ dst.copy_from_slice(src);
+ self.input = self.input.wrapping_add(amount);
+ self.available -= amount;
+ Ok(amount)
+ }
+
+ fn reset(&mut self) {
+ qtrace!([self], "reset");
+ self.input = null();
+ self.available = 0;
+ }
+}
+
+impl ::std::fmt::Display for AgentIoInput {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "AgentIoInput {:p}", self.input)
+ }
+}
+
+#[derive(Debug)]
+pub struct AgentIo {
+ // input collects the input we might provide to TLS.
+ input: AgentIoInput,
+
+ // output contains data that is written by TLS.
+ output: Vec<u8>,
+}
+
+impl AgentIo {
+ pub fn new() -> Self {
+ Self {
+ input: AgentIoInput {
+ input: null(),
+ available: 0,
+ },
+ output: Vec::new(),
+ }
+ }
+
+ unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
+ #[allow(clippy::cast_ptr_alignment)]
+ (**fd).secret.cast::<Self>().as_mut().unwrap()
+ }
+
+ pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
+ assert_eq!(self.output.len(), 0);
+ self.input.wrap(input)
+ }
+
+ // Stage output from TLS into the output buffer.
+ fn save_output(&mut self, buf: *const u8, count: usize) {
+ let slice = unsafe { std::slice::from_raw_parts(buf, count) };
+ qtrace!([self], "save output {}", hex(slice));
+ self.output.extend_from_slice(slice);
+ }
+
+ pub fn take_output(&mut self) -> Vec<u8> {
+ qtrace!([self], "take output");
+ mem::take(&mut self.output)
+ }
+}
+
+impl ::std::fmt::Display for AgentIo {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "AgentIo")
+ }
+}
+
+unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
+ (*fd).secret = null_mut();
+ if let Some(dtor) = (*fd).dtor {
+ dtor(fd);
+ }
+ PR_SUCCESS
+}
+
+unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus {
+ let io = AgentIo::borrow(&mut fd);
+ if let Ok(a) = usize::try_from(amount) {
+ match io.input.read_input(buf.cast(), a) {
+ Ok(_) => PR_SUCCESS,
+ Err(_) => PR_FAILURE,
+ }
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_recv(
+ mut fd: PrFd,
+ buf: *mut c_void,
+ amount: prio::PRInt32,
+ flags: prio::PRIntn,
+ _timeout: prio::PRIntervalTime,
+) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+ if flags != 0 {
+ return PR_FAILURE;
+ }
+ if let Ok(a) = usize::try_from(amount) {
+ match io.input.read_input(buf.cast(), a) {
+ Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE),
+ Err(_) => PR_FAILURE,
+ }
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_write(
+ mut fd: PrFd,
+ buf: *const c_void,
+ amount: prio::PRInt32,
+) -> PrStatus {
+ let io = AgentIo::borrow(&mut fd);
+ if let Ok(a) = usize::try_from(amount) {
+ io.save_output(buf.cast(), a);
+ amount
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_send(
+ mut fd: PrFd,
+ buf: *const c_void,
+ amount: prio::PRInt32,
+ flags: prio::PRIntn,
+ _timeout: prio::PRIntervalTime,
+) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+
+ if flags != 0 {
+ return PR_FAILURE;
+ }
+ if let Ok(a) = usize::try_from(amount) {
+ io.save_output(buf.cast(), a);
+ amount
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+ io.input.available.try_into().unwrap_or(PR_FAILURE)
+}
+
+unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {
+ let io = AgentIo::borrow(&mut fd);
+ io.input
+ .available
+ .try_into()
+ .unwrap_or_else(|_| PR_FAILURE.into())
+}
+
+#[allow(clippy::cast_possible_truncation)]
+unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
+ let a = addr.as_mut().unwrap();
+ // Cast is safe because prio::PR_AF_INET is 2
+ a.inet.family = prio::PR_AF_INET as prio::PRUint16;
+ a.inet.port = 0;
+ a.inet.ip = 0;
+ PR_SUCCESS
+}
+
+unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
+ let o = opt.as_mut().unwrap();
+ if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
+ o.value.non_blocking = 1;
+ return PR_SUCCESS;
+ }
+ PR_FAILURE
+}
+
+pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
+ file_type: prio::PRDescType::PR_DESC_LAYERED,
+ close: Some(agent_close),
+ read: Some(agent_read),
+ write: Some(agent_write),
+ available: Some(agent_available),
+ available64: Some(agent_available64),
+ fsync: None,
+ seek: None,
+ seek64: None,
+ fileInfo: None,
+ fileInfo64: None,
+ writev: None,
+ connect: None,
+ accept: None,
+ bind: None,
+ listen: None,
+ shutdown: None,
+ recv: Some(agent_recv),
+ send: Some(agent_send),
+ recvfrom: None,
+ sendto: None,
+ poll: None,
+ acceptread: None,
+ transmitfile: None,
+ getsockname: Some(agent_getname),
+ getpeername: Some(agent_getname),
+ reserved_fn_6: None,
+ reserved_fn_5: None,
+ getsocketoption: Some(agent_getsockopt),
+ setsocketoption: None,
+ sendfile: None,
+ connectcontinue: None,
+ reserved_fn_3: None,
+ reserved_fn_2: None,
+ reserved_fn_1: None,
+ reserved_fn_0: None,
+};
diff --git a/third_party/rust/neqo-crypto/src/auth.rs b/third_party/rust/neqo-crypto/src/auth.rs
new file mode 100644
index 0000000000..2932cdf2eb
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/auth.rs
@@ -0,0 +1,108 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::err::{mozpkix, sec, ssl, PRErrorCode};
+
+/// The outcome of authentication.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum AuthenticationStatus {
+ Ok,
+ CaInvalid,
+ CaNotV3,
+ CertAlgorithmDisabled,
+ CertExpired,
+ CertInvalidTime,
+ CertIsCa,
+ CertKeyUsage,
+ CertMitm,
+ CertNotYetValid,
+ CertRevoked,
+ CertSelfSigned,
+ CertSubjectInvalid,
+ CertUntrusted,
+ CertWeakKey,
+ IssuerEmptyName,
+ IssuerExpired,
+ IssuerNotYetValid,
+ IssuerUnknown,
+ IssuerUntrusted,
+ PolicyRejection,
+ Unknown,
+}
+
+impl From<AuthenticationStatus> for PRErrorCode {
+ #[must_use]
+ fn from(v: AuthenticationStatus) -> Self {
+ match v {
+ AuthenticationStatus::Ok => 0,
+ AuthenticationStatus::CaInvalid => sec::SEC_ERROR_CA_CERT_INVALID,
+ AuthenticationStatus::CaNotV3 => mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA,
+ AuthenticationStatus::CertAlgorithmDisabled => {
+ sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED
+ }
+ AuthenticationStatus::CertExpired => sec::SEC_ERROR_EXPIRED_CERTIFICATE,
+ AuthenticationStatus::CertInvalidTime => sec::SEC_ERROR_INVALID_TIME,
+ AuthenticationStatus::CertIsCa => {
+ mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY
+ }
+ AuthenticationStatus::CertKeyUsage => sec::SEC_ERROR_INADEQUATE_KEY_USAGE,
+ AuthenticationStatus::CertMitm => mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED,
+ AuthenticationStatus::CertNotYetValid => {
+ mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE
+ }
+ AuthenticationStatus::CertRevoked => sec::SEC_ERROR_REVOKED_CERTIFICATE,
+ AuthenticationStatus::CertSelfSigned => mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT,
+ AuthenticationStatus::CertSubjectInvalid => ssl::SSL_ERROR_BAD_CERT_DOMAIN,
+ AuthenticationStatus::CertUntrusted => sec::SEC_ERROR_UNTRUSTED_CERT,
+ AuthenticationStatus::CertWeakKey => mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE,
+ AuthenticationStatus::IssuerEmptyName => mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME,
+ AuthenticationStatus::IssuerExpired => sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE,
+ AuthenticationStatus::IssuerNotYetValid => {
+ mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE
+ }
+ AuthenticationStatus::IssuerUnknown => sec::SEC_ERROR_UNKNOWN_ISSUER,
+ AuthenticationStatus::IssuerUntrusted => sec::SEC_ERROR_UNTRUSTED_ISSUER,
+ AuthenticationStatus::PolicyRejection => {
+ mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED
+ }
+ AuthenticationStatus::Unknown => sec::SEC_ERROR_LIBRARY_FAILURE,
+ }
+ }
+}
+
+// Note that this mapping should be removed after gecko eventually learns how to
+// map into the enumerated type.
+impl From<PRErrorCode> for AuthenticationStatus {
+ #[must_use]
+ fn from(v: PRErrorCode) -> Self {
+ match v {
+ 0 => Self::Ok,
+ sec::SEC_ERROR_CA_CERT_INVALID => Self::CaInvalid,
+ mozpkix::MOZILLA_PKIX_ERROR_V1_CERT_USED_AS_CA => Self::CaNotV3,
+ sec::SEC_ERROR_CERT_SIGNATURE_ALGORITHM_DISABLED => Self::CertAlgorithmDisabled,
+ sec::SEC_ERROR_EXPIRED_CERTIFICATE => Self::CertExpired,
+ sec::SEC_ERROR_INVALID_TIME => Self::CertInvalidTime,
+ mozpkix::MOZILLA_PKIX_ERROR_CA_CERT_USED_AS_END_ENTITY => Self::CertIsCa,
+ sec::SEC_ERROR_INADEQUATE_KEY_USAGE => Self::CertKeyUsage,
+ mozpkix::MOZILLA_PKIX_ERROR_MITM_DETECTED => Self::CertMitm,
+ mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_CERTIFICATE => Self::CertNotYetValid,
+ sec::SEC_ERROR_REVOKED_CERTIFICATE => Self::CertRevoked,
+ mozpkix::MOZILLA_PKIX_ERROR_SELF_SIGNED_CERT => Self::CertSelfSigned,
+ ssl::SSL_ERROR_BAD_CERT_DOMAIN => Self::CertSubjectInvalid,
+ sec::SEC_ERROR_UNTRUSTED_CERT => Self::CertUntrusted,
+ mozpkix::MOZILLA_PKIX_ERROR_INADEQUATE_KEY_SIZE => Self::CertWeakKey,
+ mozpkix::MOZILLA_PKIX_ERROR_EMPTY_ISSUER_NAME => Self::IssuerEmptyName,
+ sec::SEC_ERROR_EXPIRED_ISSUER_CERTIFICATE => Self::IssuerExpired,
+ mozpkix::MOZILLA_PKIX_ERROR_NOT_YET_VALID_ISSUER_CERTIFICATE => Self::IssuerNotYetValid,
+ sec::SEC_ERROR_UNKNOWN_ISSUER => Self::IssuerUnknown,
+ sec::SEC_ERROR_UNTRUSTED_ISSUER => Self::IssuerUntrusted,
+ mozpkix::MOZILLA_PKIX_ERROR_ADDITIONAL_POLICY_CONSTRAINT_FAILED => {
+ Self::PolicyRejection
+ }
+ _ => Self::Unknown,
+ }
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/cert.rs b/third_party/rust/neqo-crypto/src/cert.rs
new file mode 100644
index 0000000000..d78f57c1b4
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/cert.rs
@@ -0,0 +1,118 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::err::secstatus_to_res;
+use crate::p11::{CERTCertListNode, CERT_GetCertificateDer, CertList, Item, SECItem, SECItemArray};
+use crate::ssl::{
+ PRFileDesc, SSL_PeerCertificateChain, SSL_PeerSignedCertTimestamps,
+ SSL_PeerStapledOCSPResponses,
+};
+use neqo_common::qerror;
+
+use std::convert::TryFrom;
+use std::ptr::{addr_of, NonNull};
+
+use std::slice;
+
+pub struct CertificateInfo {
+ certs: CertList,
+ cursor: *const CERTCertListNode,
+ /// stapled_ocsp_responses and signed_cert_timestamp are properties
+ /// associated with each of the certificates. Right now, NSS only
+ /// reports the value for the end-entity certificate (the first).
+ stapled_ocsp_responses: Option<Vec<Vec<u8>>>,
+ signed_cert_timestamp: Option<Vec<u8>>,
+}
+
+fn peer_certificate_chain(fd: *mut PRFileDesc) -> Option<(CertList, *const CERTCertListNode)> {
+ let chain = unsafe { SSL_PeerCertificateChain(fd) };
+ CertList::from_ptr(chain.cast()).ok().map(|certs| {
+ let cursor = CertificateInfo::head(&certs);
+ (certs, cursor)
+ })
+}
+
+// As explained in rfc6961, an OCSPResponseList can have at most
+// 2^24 items. Casting its length is therefore safe even on 32 bits targets.
+fn stapled_ocsp_responses(fd: *mut PRFileDesc) -> Option<Vec<Vec<u8>>> {
+ let ocsp_nss = unsafe { SSL_PeerStapledOCSPResponses(fd) };
+ match NonNull::new(ocsp_nss as *mut SECItemArray) {
+ Some(ocsp_ptr) => {
+ let mut ocsp_helper: Vec<Vec<u8>> = Vec::new();
+ let len = if let Ok(l) = isize::try_from(unsafe { ocsp_ptr.as_ref().len }) {
+ l
+ } else {
+ qerror!([format!("{:p}", fd)], "Received illegal OSCP length");
+ return None;
+ };
+ for idx in 0..len {
+ let itemp: *const SECItem = unsafe { ocsp_ptr.as_ref().items.offset(idx).cast() };
+ let item = unsafe { slice::from_raw_parts((*itemp).data, (*itemp).len as usize) };
+ ocsp_helper.push(item.to_owned());
+ }
+ Some(ocsp_helper)
+ }
+ None => None,
+ }
+}
+
+fn signed_cert_timestamp(fd: *mut PRFileDesc) -> Option<Vec<u8>> {
+ let sct_nss = unsafe { SSL_PeerSignedCertTimestamps(fd) };
+ match NonNull::new(sct_nss as *mut SECItem) {
+ Some(sct_ptr) => {
+ if unsafe { sct_ptr.as_ref().len == 0 || sct_ptr.as_ref().data.is_null() } {
+ Some(Vec::new())
+ } else {
+ let sct_slice = unsafe {
+ slice::from_raw_parts(sct_ptr.as_ref().data, sct_ptr.as_ref().len as usize)
+ };
+ Some(sct_slice.to_owned())
+ }
+ }
+ None => None,
+ }
+}
+
+impl CertificateInfo {
+ pub(crate) fn new(fd: *mut PRFileDesc) -> Option<Self> {
+ peer_certificate_chain(fd).map(|(certs, cursor)| Self {
+ certs,
+ cursor,
+ stapled_ocsp_responses: stapled_ocsp_responses(fd),
+ signed_cert_timestamp: signed_cert_timestamp(fd),
+ })
+ }
+
+ fn head(certs: &CertList) -> *const CERTCertListNode {
+ // Three stars: one for the reference, one for the wrapper, one to deference the pointer.
+ unsafe { addr_of!((***certs).list).cast() }
+ }
+}
+
+impl<'a> Iterator for &'a mut CertificateInfo {
+ type Item = &'a [u8];
+ fn next(&mut self) -> Option<&'a [u8]> {
+ self.cursor = unsafe { *self.cursor }.links.next.cast();
+ if self.cursor == CertificateInfo::head(&self.certs) {
+ return None;
+ }
+ let mut item = Item::make_empty();
+ let cert = unsafe { *self.cursor }.cert;
+ secstatus_to_res(unsafe { CERT_GetCertificateDer(cert, &mut item) })
+ .expect("getting DER from certificate should work");
+ Some(unsafe { std::slice::from_raw_parts(item.data, item.len as usize) })
+ }
+}
+
+impl CertificateInfo {
+ pub fn stapled_ocsp_responses(&mut self) -> &Option<Vec<Vec<u8>>> {
+ &self.stapled_ocsp_responses
+ }
+
+ pub fn signed_cert_timestamp(&mut self) -> &Option<Vec<u8>> {
+ &self.signed_cert_timestamp
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/constants.rs b/third_party/rust/neqo-crypto/src/constants.rs
new file mode 100644
index 0000000000..21e1a5aceb
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/constants.rs
@@ -0,0 +1,145 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(dead_code)]
+
+use crate::ssl;
+
+// Ideally all of these would be enums, but size matters and we need to allow
+// for values outside of those that are defined here.
+
+pub type Alert = u8;
+
+pub type Epoch = u16;
+// TLS doesn't really have an "initial" concept that maps to QUIC so directly,
+// but this should be clear enough.
+pub const TLS_EPOCH_INITIAL: Epoch = 0_u16;
+pub const TLS_EPOCH_ZERO_RTT: Epoch = 1_u16;
+pub const TLS_EPOCH_HANDSHAKE: Epoch = 2_u16;
+// Also, we don't use TLS epochs > 3.
+pub const TLS_EPOCH_APPLICATION_DATA: Epoch = 3_u16;
+
+/// Rather than defining a type alias and a bunch of constants, which leads to a ton of repetition,
+/// use this macro.
+macro_rules! remap_enum {
+ { $t:ident: $s:ty { $( $n:ident = $v:path ),+ $(,)? } } => {
+ pub type $t = $s;
+ $( pub const $n: $t = $v as $t; )+
+ };
+ { $t:ident: $s:ty => $e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => {
+ remap_enum!{ $t: $s { $( $n = $e::$v ),+ } }
+ };
+ { $t:ident: $s:ty => $p:ident::$e:ident { $( $n:ident = $v:ident ),+ $(,)? } } => {
+ remap_enum!{ $t: $s { $( $n = $p::$e::$v ),+ } }
+ };
+}
+
+remap_enum! {
+ Version: u16 => ssl {
+ TLS_VERSION_1_2 = SSL_LIBRARY_VERSION_TLS_1_2,
+ TLS_VERSION_1_3 = SSL_LIBRARY_VERSION_TLS_1_3,
+ }
+}
+
+mod ciphers {
+ include!(concat!(env!("OUT_DIR"), "/nss_ciphers.rs"));
+}
+
+remap_enum! {
+ Cipher: u16 => ciphers {
+ TLS_AES_128_GCM_SHA256 = TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384 = TLS_AES_256_GCM_SHA384,
+ TLS_CHACHA20_POLY1305_SHA256 = TLS_CHACHA20_POLY1305_SHA256,
+ }
+}
+
+remap_enum! {
+ Group: u16 => ssl::SSLNamedGroup {
+ TLS_GRP_EC_SECP256R1 = ssl_grp_ec_secp256r1,
+ TLS_GRP_EC_SECP384R1 = ssl_grp_ec_secp384r1,
+ TLS_GRP_EC_SECP521R1 = ssl_grp_ec_secp521r1,
+ TLS_GRP_EC_X25519 = ssl_grp_ec_curve25519,
+ }
+}
+
+remap_enum! {
+ HandshakeMessage: u8 => ssl::SSLHandshakeType {
+ TLS_HS_HELLO_REQUEST = ssl_hs_hello_request,
+ TLS_HS_CLIENT_HELLO = ssl_hs_client_hello,
+ TLS_HS_SERVER_HELLO = ssl_hs_server_hello,
+ TLS_HS_HELLO_VERIFY_REQUEST = ssl_hs_hello_verify_request,
+ TLS_HS_NEW_SESSION_TICKET = ssl_hs_new_session_ticket,
+ TLS_HS_END_OF_EARLY_DATA = ssl_hs_end_of_early_data,
+ TLS_HS_HELLO_RETRY_REQUEST = ssl_hs_hello_retry_request,
+ TLS_HS_ENCRYPTED_EXTENSIONS = ssl_hs_encrypted_extensions,
+ TLS_HS_CERTIFICATE = ssl_hs_certificate,
+ TLS_HS_SERVER_KEY_EXCHANGE = ssl_hs_server_key_exchange,
+ TLS_HS_CERTIFICATE_REQUEST = ssl_hs_certificate_request,
+ TLS_HS_SERVER_HELLO_DONE = ssl_hs_server_hello_done,
+ TLS_HS_CERTIFICATE_VERIFY = ssl_hs_certificate_verify,
+ TLS_HS_CLIENT_KEY_EXCHANGE = ssl_hs_client_key_exchange,
+ TLS_HS_FINISHED = ssl_hs_finished,
+ TLS_HS_CERT_STATUS = ssl_hs_certificate_status,
+ TLS_HS_KEY_UDPATE = ssl_hs_key_update,
+ }
+}
+
+remap_enum! {
+ ContentType: u8 => ssl::SSLContentType {
+ TLS_CT_CHANGE_CIPHER_SPEC = ssl_ct_change_cipher_spec,
+ TLS_CT_ALERT = ssl_ct_alert,
+ TLS_CT_HANDSHAKE = ssl_ct_handshake,
+ TLS_CT_APPLICATION_DATA = ssl_ct_application_data,
+ TLS_CT_ACK = ssl_ct_ack,
+ }
+}
+
+remap_enum! {
+ Extension: u16 => ssl::SSLExtensionType {
+ TLS_EXT_SERVER_NAME = ssl_server_name_xtn,
+ TLS_EXT_CERT_STATUS = ssl_cert_status_xtn,
+ TLS_EXT_GROUPS = ssl_supported_groups_xtn,
+ TLS_EXT_EC_POINT_FORMATS = ssl_ec_point_formats_xtn,
+ TLS_EXT_SIG_SCHEMES = ssl_signature_algorithms_xtn,
+ TLS_EXT_USE_SRTP = ssl_use_srtp_xtn,
+ TLS_EXT_ALPN = ssl_app_layer_protocol_xtn,
+ TLS_EXT_SCT = ssl_signed_cert_timestamp_xtn,
+ TLS_EXT_PADDING = ssl_padding_xtn,
+ TLS_EXT_EMS = ssl_extended_master_secret_xtn,
+ TLS_EXT_RECORD_SIZE = ssl_record_size_limit_xtn,
+ TLS_EXT_SESSION_TICKET = ssl_session_ticket_xtn,
+ TLS_EXT_PSK = ssl_tls13_pre_shared_key_xtn,
+ TLS_EXT_EARLY_DATA = ssl_tls13_early_data_xtn,
+ TLS_EXT_VERSIONS = ssl_tls13_supported_versions_xtn,
+ TLS_EXT_COOKIE = ssl_tls13_cookie_xtn,
+ TLS_EXT_PSK_MODES = ssl_tls13_psk_key_exchange_modes_xtn,
+ TLS_EXT_CA = ssl_tls13_certificate_authorities_xtn,
+ TLS_EXT_POST_HS_AUTH = ssl_tls13_post_handshake_auth_xtn,
+ TLS_EXT_CERT_SIG_SCHEMES = ssl_signature_algorithms_cert_xtn,
+ TLS_EXT_KEY_SHARE = ssl_tls13_key_share_xtn,
+ TLS_EXT_RENEGOTIATION_INFO = ssl_renegotiation_info_xtn,
+ }
+}
+
+remap_enum! {
+ SignatureScheme: u16 => ssl::SSLSignatureScheme {
+ TLS_SIG_NONE = ssl_sig_none,
+ TLS_SIG_RSA_PKCS1_SHA256 = ssl_sig_rsa_pkcs1_sha256,
+ TLS_SIG_RSA_PKCS1_SHA384 = ssl_sig_rsa_pkcs1_sha384,
+ TLS_SIG_RSA_PKCS1_SHA512 = ssl_sig_rsa_pkcs1_sha512,
+ TLS_SIG_ECDSA_SECP256R1_SHA256 = ssl_sig_ecdsa_secp256r1_sha256,
+ TLS_SIG_ECDSA_SECP384R1_SHA384 = ssl_sig_ecdsa_secp384r1_sha384,
+ TLS_SIG_ECDSA_SECP512R1_SHA512 = ssl_sig_ecdsa_secp521r1_sha512,
+ TLS_SIG_RSA_PSS_RSAE_SHA256 = ssl_sig_rsa_pss_rsae_sha256,
+ TLS_SIG_RSA_PSS_RSAE_SHA384 = ssl_sig_rsa_pss_rsae_sha384,
+ TLS_SIG_RSA_PSS_RSAE_SHA512 = ssl_sig_rsa_pss_rsae_sha512,
+ TLS_SIG_ED25519 = ssl_sig_ed25519,
+ TLS_SIG_ED448 = ssl_sig_ed448,
+ TLS_SIG_RSA_PSS_PSS_SHA256 = ssl_sig_rsa_pss_pss_sha256,
+ TLS_SIG_RSA_PSS_PSS_SHA384 = ssl_sig_rsa_pss_pss_sha384,
+ TLS_SIG_RSA_PSS_PSS_SHA512 = ssl_sig_rsa_pss_pss_sha512,
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/ech.rs b/third_party/rust/neqo-crypto/src/ech.rs
new file mode 100644
index 0000000000..5c774e1360
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/ech.rs
@@ -0,0 +1,190 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::err::{ssl::SSL_ERROR_ECH_RETRY_WITH_ECH, Error, Res};
+use crate::p11::{
+ self, Item, PrivateKey, PublicKey, SECITEM_FreeItem, SECItem, SECKEYPrivateKey,
+ SECKEYPublicKey, Slot,
+};
+use crate::ssl::{PRBool, PRFileDesc};
+use neqo_common::qtrace;
+use std::convert::TryFrom;
+use std::ffi::CString;
+use std::os::raw::{c_char, c_uint};
+use std::ptr::{addr_of_mut, null_mut};
+
+pub use crate::p11::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId};
+pub use crate::ssl::HpkeSymmetricSuite as SymmetricSuite;
+
+experimental_api!(SSL_EnableTls13GreaseEch(
+ fd: *mut PRFileDesc,
+ enabled: PRBool,
+));
+
+experimental_api!(SSL_GetEchRetryConfigs(
+ fd: *mut PRFileDesc,
+ config: *mut SECItem,
+));
+
+experimental_api!(SSL_SetClientEchConfigs(
+ fd: *mut PRFileDesc,
+ config_list: *const u8,
+ config_list_len: c_uint,
+));
+
+experimental_api!(SSL_SetServerEchConfigs(
+ fd: *mut PRFileDesc,
+ pk: *const SECKEYPublicKey,
+ sk: *const SECKEYPrivateKey,
+ record: *const u8,
+ record_len: c_uint,
+));
+
+experimental_api!(SSL_EncodeEchConfigId(
+ config_id: u8,
+ public_name: *const c_char,
+ max_name_len: c_uint,
+ kem_id: KemId::Type,
+ pk: *const SECKEYPublicKey,
+ hpke_suites: *const SymmetricSuite,
+ hpke_suite_count: c_uint,
+ out: *mut u8,
+ out_len: *mut c_uint,
+ max_len: c_uint,
+));
+
+/// Convert any result that contains an ECH error into a result with an `EchRetry`.
+pub fn convert_ech_error(fd: *mut PRFileDesc, err: Error) -> Error {
+ if let Error::NssError {
+ code: SSL_ERROR_ECH_RETRY_WITH_ECH,
+ ..
+ } = &err
+ {
+ let mut item = Item::make_empty();
+ if unsafe { SSL_GetEchRetryConfigs(fd, &mut item).is_err() } {
+ return Error::InternalError;
+ }
+ let buf = unsafe {
+ let slc = std::slice::from_raw_parts(item.data, usize::try_from(item.len).unwrap());
+ let buf = Vec::from(slc);
+ SECITEM_FreeItem(&mut item, PRBool::from(false));
+ buf
+ };
+ Error::EchRetry(buf)
+ } else {
+ err
+ }
+}
+
+/// Generate a key pair for encrypted client hello (ECH).
+///
+/// # Errors
+/// When NSS fails to generate a key pair or when the KEM is not supported.
+/// # Panics
+/// When underlying types aren't large enough to hold keys. So never.
+pub fn generate_keys() -> Res<(PrivateKey, PublicKey)> {
+ let slot = Slot::internal()?;
+
+ let oid_data = unsafe { p11::SECOID_FindOIDByTag(p11::SECOidTag::SEC_OID_CURVE25519) };
+ let oid = unsafe { oid_data.as_ref() }.ok_or(Error::InternalError)?;
+ let oid_slc =
+ unsafe { std::slice::from_raw_parts(oid.oid.data, usize::try_from(oid.oid.len).unwrap()) };
+ let mut params: Vec<u8> = Vec::with_capacity(oid_slc.len() + 2);
+ params.push(u8::try_from(p11::SEC_ASN1_OBJECT_ID).unwrap());
+ params.push(u8::try_from(oid.oid.len).unwrap());
+ params.extend_from_slice(oid_slc);
+
+ let mut public_ptr: *mut SECKEYPublicKey = null_mut();
+ let mut param_item = Item::wrap(&params);
+
+ // If we have tracing on, try to ensure that key data can be read.
+ let insensitive_secret_ptr = if log::log_enabled!(log::Level::Trace) {
+ unsafe {
+ p11::PK11_GenerateKeyPairWithOpFlags(
+ *slot,
+ p11::CK_MECHANISM_TYPE::from(p11::CKM_EC_KEY_PAIR_GEN),
+ addr_of_mut!(param_item).cast(),
+ &mut public_ptr,
+ p11::PK11_ATTR_SESSION | p11::PK11_ATTR_INSENSITIVE | p11::PK11_ATTR_PUBLIC,
+ p11::CK_FLAGS::from(p11::CKF_DERIVE),
+ p11::CK_FLAGS::from(p11::CKF_DERIVE),
+ null_mut(),
+ )
+ }
+ } else {
+ null_mut()
+ };
+ assert_eq!(insensitive_secret_ptr.is_null(), public_ptr.is_null());
+ let secret_ptr = if insensitive_secret_ptr.is_null() {
+ unsafe {
+ p11::PK11_GenerateKeyPairWithOpFlags(
+ *slot,
+ p11::CK_MECHANISM_TYPE::from(p11::CKM_EC_KEY_PAIR_GEN),
+ addr_of_mut!(param_item).cast(),
+ &mut public_ptr,
+ p11::PK11_ATTR_SESSION | p11::PK11_ATTR_SENSITIVE | p11::PK11_ATTR_PRIVATE,
+ p11::CK_FLAGS::from(p11::CKF_DERIVE),
+ p11::CK_FLAGS::from(p11::CKF_DERIVE),
+ null_mut(),
+ )
+ }
+ } else {
+ insensitive_secret_ptr
+ };
+ assert_eq!(secret_ptr.is_null(), public_ptr.is_null());
+ let sk = PrivateKey::from_ptr(secret_ptr)?;
+ let pk = PublicKey::from_ptr(public_ptr)?;
+ qtrace!("Generated key pair: sk={:?} pk={:?}", sk, pk);
+ Ok((sk, pk))
+}
+
+/// Encode a configuration for encrypted client hello (ECH).
+///
+/// # Errors
+/// When NSS fails to generate a valid configuration encoding (i.e., unlikely).
+pub fn encode_config(config: u8, public_name: &str, pk: &PublicKey) -> Res<Vec<u8>> {
+ // A sensible fixed value for the maximum length of a name.
+ const MAX_NAME_LEN: c_uint = 64;
+ // Enable a selection of suites.
+ // NSS supports SHA-512 as well, which could be added here.
+ const SUITES: &[SymmetricSuite] = &[
+ SymmetricSuite {
+ kdfId: KdfId::HpkeKdfHkdfSha256,
+ aeadId: AeadId::HpkeAeadAes128Gcm,
+ },
+ SymmetricSuite {
+ kdfId: KdfId::HpkeKdfHkdfSha256,
+ aeadId: AeadId::HpkeAeadChaCha20Poly1305,
+ },
+ SymmetricSuite {
+ kdfId: KdfId::HpkeKdfHkdfSha384,
+ aeadId: AeadId::HpkeAeadAes128Gcm,
+ },
+ SymmetricSuite {
+ kdfId: KdfId::HpkeKdfHkdfSha384,
+ aeadId: AeadId::HpkeAeadChaCha20Poly1305,
+ },
+ ];
+
+ let name = CString::new(public_name)?;
+ let mut encoded = [0; 1024];
+ let mut encoded_len = 0;
+ unsafe {
+ SSL_EncodeEchConfigId(
+ config,
+ name.as_ptr(),
+ MAX_NAME_LEN,
+ KemId::HpkeDhKemX25519Sha256,
+ **pk,
+ SUITES.as_ptr(),
+ c_uint::try_from(SUITES.len())?,
+ encoded.as_mut_ptr(),
+ &mut encoded_len,
+ c_uint::try_from(encoded.len())?,
+ )?;
+ }
+ Ok(Vec::from(&encoded[..usize::try_from(encoded_len)?]))
+}
diff --git a/third_party/rust/neqo-crypto/src/err.rs b/third_party/rust/neqo-crypto/src/err.rs
new file mode 100644
index 0000000000..5ba7c02d4b
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/err.rs
@@ -0,0 +1,214 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(dead_code)]
+#![allow(clippy::upper_case_acronyms)]
+
+use std::os::raw::c_char;
+use std::str::Utf8Error;
+
+use crate::ssl::{SECStatus, SECSuccess};
+
+include!(concat!(env!("OUT_DIR"), "/nspr_error.rs"));
+mod codes {
+ #![allow(non_snake_case)]
+ include!(concat!(env!("OUT_DIR"), "/nss_secerr.rs"));
+ include!(concat!(env!("OUT_DIR"), "/nss_sslerr.rs"));
+ include!(concat!(env!("OUT_DIR"), "/mozpkix.rs"));
+}
+pub use codes::mozilla_pkix_ErrorCode as mozpkix;
+pub use codes::SECErrorCodes as sec;
+pub use codes::SSLErrorCodes as ssl;
+pub mod nspr {
+ include!(concat!(env!("OUT_DIR"), "/nspr_err.rs"));
+}
+
+pub type Res<T> = Result<T, Error>;
+
+#[derive(Clone, Debug, PartialEq, PartialOrd, Ord, Eq)]
+pub enum Error {
+ AeadError,
+ CertificateLoading,
+ CipherInitFailure,
+ CreateSslSocket,
+ EchRetry(Vec<u8>),
+ HkdfError,
+ InternalError,
+ IntegerOverflow,
+ InvalidEpoch,
+ MixedHandshakeMethod,
+ NoDataAvailable,
+ NssError {
+ name: String,
+ code: PRErrorCode,
+ desc: String,
+ },
+ OverrunError,
+ SelfEncryptFailure,
+ StringError,
+ TimeTravelError,
+ UnsupportedCipher,
+ UnsupportedVersion,
+}
+
+impl Error {
+ pub(crate) fn last_nss_error() -> Self {
+ Self::from(unsafe { PR_GetError() })
+ }
+}
+
+impl std::error::Error for Error {
+ #[must_use]
+ fn cause(&self) -> Option<&dyn std::error::Error> {
+ None
+ }
+ #[must_use]
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "Error: {:?}", self)
+ }
+}
+
+impl From<std::num::TryFromIntError> for Error {
+ #[must_use]
+ fn from(_: std::num::TryFromIntError) -> Self {
+ Self::IntegerOverflow
+ }
+}
+impl From<std::ffi::NulError> for Error {
+ #[must_use]
+ fn from(_: std::ffi::NulError) -> Self {
+ Self::InternalError
+ }
+}
+impl From<Utf8Error> for Error {
+ fn from(_: Utf8Error) -> Self {
+ Self::StringError
+ }
+}
+impl From<PRErrorCode> for Error {
+ fn from(code: PRErrorCode) -> Self {
+ let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR");
+ let desc = wrap_str_fn(
+ || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) },
+ "...",
+ );
+ Self::NssError { name, code, desc }
+ }
+}
+
+use std::ffi::CStr;
+
+fn wrap_str_fn<F>(f: F, dflt: &str) -> String
+where
+ F: FnOnce() -> *const c_char,
+{
+ unsafe {
+ let p = f();
+ if p.is_null() {
+ return dflt.to_string();
+ }
+ CStr::from_ptr(p).to_string_lossy().into_owned()
+ }
+}
+
+pub fn secstatus_to_res(rv: SECStatus) -> Res<()> {
+ if rv == SECSuccess {
+ Ok(())
+ } else {
+ Err(Error::last_nss_error())
+ }
+}
+
+pub fn is_blocked(result: &Res<()>) -> bool {
+ match result {
+ Err(Error::NssError { code, .. }) => *code == nspr::PR_WOULD_BLOCK_ERROR,
+ _ => false,
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::err::{self, is_blocked, secstatus_to_res, Error, PRErrorCode, PR_SetError};
+ use crate::ssl::{SECFailure, SECSuccess};
+ use test_fixture::fixture_init;
+
+ fn set_error_code(code: PRErrorCode) {
+ // This code doesn't work without initializing NSS first.
+ fixture_init();
+ unsafe {
+ PR_SetError(code, 0);
+ }
+ }
+
+ #[test]
+ fn error_code() {
+ fixture_init();
+ assert_eq!(15 - 0x3000, err::ssl::SSL_ERROR_BAD_MAC_READ);
+ assert_eq!(166 - 0x2000, err::sec::SEC_ERROR_LIBPKIX_INTERNAL);
+ assert_eq!(-5998, err::nspr::PR_WOULD_BLOCK_ERROR);
+ }
+
+ #[test]
+ fn is_ok() {
+ assert!(secstatus_to_res(SECSuccess).is_ok());
+ }
+
+ #[test]
+ fn is_err() {
+ set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ);
+ let r = secstatus_to_res(SECFailure);
+ assert!(r.is_err());
+ match r.unwrap_err() {
+ Error::NssError { name, code, desc } => {
+ assert_eq!(name, "SSL_ERROR_BAD_MAC_READ");
+ assert_eq!(code, -12273);
+ assert_eq!(
+ desc,
+ "SSL received a record with an incorrect Message Authentication Code."
+ );
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ #[test]
+ fn is_err_zero_code() {
+ set_error_code(0);
+ let r = secstatus_to_res(SECFailure);
+ assert!(r.is_err());
+ match r.unwrap_err() {
+ Error::NssError { name, code, .. } => {
+ assert_eq!(name, "UNKNOWN_ERROR");
+ assert_eq!(code, 0);
+ // Note that we don't test |desc| here because that comes from
+ // strerror(0), which is platform-dependent.
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ #[test]
+ fn blocked() {
+ set_error_code(err::nspr::PR_WOULD_BLOCK_ERROR);
+ let r = secstatus_to_res(SECFailure);
+ assert!(r.is_err());
+ assert!(is_blocked(&r));
+ match r.unwrap_err() {
+ Error::NssError { name, code, desc } => {
+ assert_eq!(name, "PR_WOULD_BLOCK_ERROR");
+ assert_eq!(code, -5998);
+ assert_eq!(desc, "The operation would have blocked");
+ }
+ _ => panic!("bad error type"),
+ }
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/exp.rs b/third_party/rust/neqo-crypto/src/exp.rs
new file mode 100644
index 0000000000..59b6cb0ba5
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/exp.rs
@@ -0,0 +1,23 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+macro_rules! experimental_api {
+ ( $n:ident ( $( $a:ident : $t:ty ),* $(,)? ) ) => {
+ #[allow(non_snake_case)]
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) unsafe fn $n ( $( $a : $t ),* ) -> Result<(), crate::err::Error> {
+ const EXP_FUNCTION: &str = stringify!($n);
+ let n = ::std::ffi::CString::new(EXP_FUNCTION)?;
+ let f = crate::ssl::SSL_GetExperimentalAPI(n.as_ptr());
+ if f.is_null() {
+ return Err(crate::err::Error::InternalError);
+ }
+ let f: unsafe extern "C" fn( $( $t ),* ) -> crate::ssl::SECStatus = ::std::mem::transmute(f);
+ let rv = f( $( $a ),* );
+ crate::err::secstatus_to_res(rv)
+ }
+ };
+}
diff --git a/third_party/rust/neqo-crypto/src/ext.rs b/third_party/rust/neqo-crypto/src/ext.rs
new file mode 100644
index 0000000000..5c6dc1c8ff
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/ext.rs
@@ -0,0 +1,165 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::agentio::as_c_void;
+use crate::constants::{
+ Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS,
+};
+use crate::err::Res;
+use crate::ssl::{
+ PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription,
+ SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType,
+};
+
+use std::cell::RefCell;
+use std::convert::TryFrom;
+use std::os::raw::{c_uint, c_void};
+use std::pin::Pin;
+use std::rc::Rc;
+
+experimental_api!(SSL_InstallExtensionHooks(
+ fd: *mut PRFileDesc,
+ extension: u16,
+ writer: SSLExtensionWriter,
+ writer_arg: *mut c_void,
+ handler: SSLExtensionHandler,
+ handler_arg: *mut c_void,
+));
+
+pub enum ExtensionWriterResult {
+ Write(usize),
+ Skip,
+}
+
+pub enum ExtensionHandlerResult {
+ Ok,
+ Alert(crate::constants::Alert),
+}
+
+pub trait ExtensionHandler {
+ fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult {
+ match msg {
+ TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0),
+ _ => ExtensionWriterResult::Skip,
+ }
+ }
+
+ fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult {
+ match msg {
+ TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok,
+ _ => ExtensionHandlerResult::Alert(110), // unsupported_extension
+ }
+ }
+}
+
+type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>;
+
+pub struct ExtensionTracker {
+ extension: Extension,
+ handler: Pin<Box<BoxedExtensionHandler>>,
+}
+
+impl ExtensionTracker {
+ // Technically the as_mut() call here is the only unsafe bit,
+ // but don't call this function lightly.
+ unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T
+ where
+ F: FnOnce(&mut dyn ExtensionHandler) -> T,
+ {
+ let rc = arg.cast::<BoxedExtensionHandler>().as_mut().unwrap();
+ f(&mut *rc.borrow_mut())
+ }
+
+ #[allow(clippy::cast_possible_truncation)]
+ unsafe extern "C" fn extension_writer(
+ _fd: *mut PRFileDesc,
+ message: SSLHandshakeType::Type,
+ data: *mut u8,
+ len: *mut c_uint,
+ max_len: c_uint,
+ arg: *mut c_void,
+ ) -> PRBool {
+ let d = std::slice::from_raw_parts_mut(data, max_len as usize);
+ Self::wrap_handler_call(arg, |handler| {
+ // Cast is safe here because the message type is always part of the enum
+ match handler.write(message as HandshakeMessage, d) {
+ ExtensionWriterResult::Write(sz) => {
+ *len = c_uint::try_from(sz).expect("integer overflow from extension writer");
+ 1
+ }
+ ExtensionWriterResult::Skip => 0,
+ }
+ })
+ }
+
+ unsafe extern "C" fn extension_handler(
+ _fd: *mut PRFileDesc,
+ message: SSLHandshakeType::Type,
+ data: *const u8,
+ len: c_uint,
+ alert: *mut SSLAlertDescription,
+ arg: *mut c_void,
+ ) -> SECStatus {
+ let d = std::slice::from_raw_parts(data, len as usize);
+ #[allow(clippy::cast_possible_truncation)]
+ Self::wrap_handler_call(arg, |handler| {
+ // Cast is safe here because the message type is always part of the enum
+ match handler.handle(message as HandshakeMessage, d) {
+ ExtensionHandlerResult::Ok => SECSuccess,
+ ExtensionHandlerResult::Alert(a) => {
+ *alert = a;
+ SECFailure
+ }
+ }
+ })
+ }
+
+ /// Use the provided handler to manage an extension. This is quite unsafe.
+ ///
+ /// # Safety
+ /// The holder of this `ExtensionTracker` needs to ensure that it lives at
+ /// least as long as the file descriptor, as NSS provides no way to remove
+ /// an extension handler once it is configured.
+ ///
+ /// # Errors
+ /// If the underlying NSS API fails to register a handler.
+ pub unsafe fn new(
+ fd: *mut PRFileDesc,
+ extension: Extension,
+ handler: Rc<RefCell<dyn ExtensionHandler>>,
+ ) -> Res<Self> {
+ // The ergonomics here aren't great for users of this API, but it's
+ // horrific here. The pinned outer box gives us a stable pointer to the inner
+ // box. This is the pointer that is passed to NSS.
+ //
+ // The inner box points to the reference-counted object. This inner box is
+ // what we end up with a reference to in callbacks. That extra wrapper around
+ // the Rc avoid any touching of reference counts in callbacks, which would
+ // inevitably lead to leaks as we don't control how many times the callback
+ // is invoked.
+ //
+ // This way, only this "outer" code deals with the reference count.
+ let mut tracker = Self {
+ extension,
+ handler: Box::pin(Box::new(handler)),
+ };
+ SSL_InstallExtensionHooks(
+ fd,
+ extension,
+ Some(Self::extension_writer),
+ as_c_void(&mut tracker.handler),
+ Some(Self::extension_handler),
+ as_c_void(&mut tracker.handler),
+ )?;
+ Ok(tracker)
+ }
+}
+
+impl std::fmt::Debug for ExtensionTracker {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "ExtensionTracker: {:?}", self.extension)
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/hkdf.rs b/third_party/rust/neqo-crypto/src/hkdf.rs
new file mode 100644
index 0000000000..8de9c01fc8
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/hkdf.rs
@@ -0,0 +1,128 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{
+ Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256,
+ TLS_VERSION_1_3,
+};
+use crate::err::{Error, Res};
+use crate::p11::{
+ random, Item, PK11Origin, PK11SymKey, PK11_ImportDataKey, Slot, SymKey, CKA_DERIVE,
+ CKM_HKDF_DERIVE, CK_ATTRIBUTE_TYPE, CK_MECHANISM_TYPE,
+};
+
+use std::convert::TryFrom;
+use std::os::raw::{c_char, c_uint};
+use std::ptr::null_mut;
+
+experimental_api!(SSL_HkdfExtract(
+ version: Version,
+ cipher: Cipher,
+ salt: *mut PK11SymKey,
+ ikm: *mut PK11SymKey,
+ prk: *mut *mut PK11SymKey,
+));
+experimental_api!(SSL_HkdfExpandLabel(
+ version: Version,
+ cipher: Cipher,
+ prk: *mut PK11SymKey,
+ handshake_hash: *const u8,
+ handshake_hash_len: c_uint,
+ label: *const c_char,
+ label_len: c_uint,
+ secret: *mut *mut PK11SymKey,
+));
+
+fn key_size(version: Version, cipher: Cipher) -> Res<usize> {
+ if version != TLS_VERSION_1_3 {
+ return Err(Error::UnsupportedVersion);
+ }
+ Ok(match cipher {
+ TLS_AES_128_GCM_SHA256 | TLS_CHACHA20_POLY1305_SHA256 => 32,
+ TLS_AES_256_GCM_SHA384 => 48,
+ _ => return Err(Error::UnsupportedCipher),
+ })
+}
+
+/// Generate a random key of the right size for the given suite.
+///
+/// # Errors
+/// Only if NSS fails.
+pub fn generate_key(version: Version, cipher: Cipher) -> Res<SymKey> {
+ import_key(version, &random(key_size(version, cipher)?))
+}
+
+/// Import a symmetric key for use with HKDF.
+///
+/// # Errors
+/// Errors returned if the key buffer is an incompatible size or the NSS functions fail.
+pub fn import_key(version: Version, buf: &[u8]) -> Res<SymKey> {
+ if version != TLS_VERSION_1_3 {
+ return Err(Error::UnsupportedVersion);
+ }
+ let slot = Slot::internal()?;
+ let key_ptr = unsafe {
+ PK11_ImportDataKey(
+ *slot,
+ CK_MECHANISM_TYPE::from(CKM_HKDF_DERIVE),
+ PK11Origin::PK11_OriginUnwrap,
+ CK_ATTRIBUTE_TYPE::from(CKA_DERIVE),
+ &mut Item::wrap(buf),
+ null_mut(),
+ )
+ };
+ SymKey::from_ptr(key_ptr)
+}
+
+/// Extract a PRK from the given salt and IKM using the algorithm defined in RFC 5869.
+///
+/// # Errors
+/// Errors returned if inputs are too large or the NSS functions fail.
+pub fn extract(
+ version: Version,
+ cipher: Cipher,
+ salt: Option<&SymKey>,
+ ikm: &SymKey,
+) -> Res<SymKey> {
+ let mut prk: *mut PK11SymKey = null_mut();
+ let salt_ptr: *mut PK11SymKey = match salt {
+ Some(s) => **s,
+ None => null_mut(),
+ };
+ unsafe { SSL_HkdfExtract(version, cipher, salt_ptr, **ikm, &mut prk) }?;
+ SymKey::from_ptr(prk)
+}
+
+/// Expand a PRK using the HKDF-Expand-Label function defined in RFC 8446.
+///
+/// # Errors
+/// Errors returned if inputs are too large or the NSS functions fail.
+pub fn expand_label(
+ version: Version,
+ cipher: Cipher,
+ prk: &SymKey,
+ handshake_hash: &[u8],
+ label: &str,
+) -> Res<SymKey> {
+ let l = label.as_bytes();
+ let mut secret: *mut PK11SymKey = null_mut();
+
+ // Note that this doesn't allow for passing null() for the handshake hash.
+ // A zero-length slice produces an identical result.
+ unsafe {
+ SSL_HkdfExpandLabel(
+ version,
+ cipher,
+ **prk,
+ handshake_hash.as_ptr(),
+ c_uint::try_from(handshake_hash.len())?,
+ l.as_ptr().cast(),
+ c_uint::try_from(l.len())?,
+ &mut secret,
+ )
+ }?;
+ SymKey::from_ptr(secret)
+}
diff --git a/third_party/rust/neqo-crypto/src/hp.rs b/third_party/rust/neqo-crypto/src/hp.rs
new file mode 100644
index 0000000000..f968943c00
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/hp.rs
@@ -0,0 +1,187 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{
+ Cipher, Version, TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256,
+};
+use crate::err::{secstatus_to_res, Error, Res};
+use crate::p11::{
+ Context, Item, PK11SymKey, PK11_CipherOp, PK11_CreateContextBySymKey, PK11_Encrypt,
+ PK11_GetBlockSize, SymKey, CKA_ENCRYPT, CKM_AES_ECB, CKM_CHACHA20, CK_ATTRIBUTE_TYPE,
+ CK_CHACHA20_PARAMS, CK_MECHANISM_TYPE,
+};
+
+use std::cell::RefCell;
+use std::convert::TryFrom;
+use std::fmt::{self, Debug};
+use std::os::raw::{c_char, c_int, c_uint};
+use std::ptr::{addr_of_mut, null, null_mut};
+use std::rc::Rc;
+
+experimental_api!(SSL_HkdfExpandLabelWithMech(
+ version: Version,
+ cipher: Cipher,
+ prk: *mut PK11SymKey,
+ handshake_hash: *const u8,
+ handshake_hash_len: c_uint,
+ label: *const c_char,
+ label_len: c_uint,
+ mech: CK_MECHANISM_TYPE,
+ key_size: c_uint,
+ secret: *mut *mut PK11SymKey,
+));
+
+#[derive(Clone)]
+pub enum HpKey {
+ /// An AES encryption context.
+ /// Note: as we need to clone this object, we clone the pointer and
+ /// track references using `Rc`. `PK11Context` can't be used with `PK11_CloneContext`
+ /// as that is not supported for these contexts.
+ Aes(Rc<RefCell<Context>>),
+ /// The ChaCha20 mask has to invoke a new PK11_Encrypt every time as it needs to
+ /// change the counter and nonce on each invocation.
+ Chacha(SymKey),
+}
+
+impl Debug for HpKey {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "HpKey")
+ }
+}
+
+impl HpKey {
+ const SAMPLE_SIZE: usize = 16;
+
+ /// QUIC-specific API for extracting a header-protection key.
+ ///
+ /// # Errors
+ /// Errors if HKDF fails or if the label is too long to fit in a `c_uint`.
+ /// # Panics
+ /// When `cipher` is not known to this code.
+ #[allow(clippy::cast_sign_loss)] // Cast for PK11_GetBlockSize is safe.
+ pub fn extract(version: Version, cipher: Cipher, prk: &SymKey, label: &str) -> Res<Self> {
+ const ZERO: &[u8] = &[0; 12];
+
+ let l = label.as_bytes();
+ let mut secret: *mut PK11SymKey = null_mut();
+
+ let (mech, key_size) = match cipher {
+ TLS_AES_128_GCM_SHA256 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 16),
+ TLS_AES_256_GCM_SHA384 => (CK_MECHANISM_TYPE::from(CKM_AES_ECB), 32),
+ TLS_CHACHA20_POLY1305_SHA256 => (CK_MECHANISM_TYPE::from(CKM_CHACHA20), 32),
+ _ => unreachable!(),
+ };
+
+ // Note that this doesn't allow for passing null() for the handshake hash.
+ // A zero-length slice produces an identical result.
+ unsafe {
+ SSL_HkdfExpandLabelWithMech(
+ version,
+ cipher,
+ **prk,
+ null(),
+ 0,
+ l.as_ptr().cast(),
+ c_uint::try_from(l.len())?,
+ mech,
+ key_size,
+ &mut secret,
+ )
+ }?;
+ let key = SymKey::from_ptr(secret).or(Err(Error::HkdfError))?;
+
+ let res = match cipher {
+ TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => {
+ let context_ptr = unsafe {
+ PK11_CreateContextBySymKey(
+ mech,
+ CK_ATTRIBUTE_TYPE::from(CKA_ENCRYPT),
+ *key,
+ &Item::wrap(&ZERO[..0]), // Borrow a zero-length slice of ZERO.
+ )
+ };
+ let context = Context::from_ptr(context_ptr).or(Err(Error::CipherInitFailure))?;
+ Self::Aes(Rc::new(RefCell::new(context)))
+ }
+ TLS_CHACHA20_POLY1305_SHA256 => Self::Chacha(key),
+ _ => unreachable!(),
+ };
+
+ debug_assert_eq!(
+ res.block_size(),
+ usize::try_from(unsafe { PK11_GetBlockSize(mech, null_mut()) }).unwrap()
+ );
+ Ok(res)
+ }
+
+ /// Get the sample size, which is also the output size.
+ #[must_use]
+ #[allow(clippy::unused_self)] // To maintain an API contract.
+ pub fn sample_size(&self) -> usize {
+ Self::SAMPLE_SIZE
+ }
+
+ fn block_size(&self) -> usize {
+ match self {
+ Self::Aes(_) => 16,
+ Self::Chacha(_) => 64,
+ }
+ }
+
+ /// Generate a header protection mask for QUIC.
+ ///
+ /// # Errors
+ /// An error is returned if the NSS functions fail; a sample of the
+ /// wrong size is the obvious cause.
+ /// # Panics
+ /// When the mechanism for our key is not supported.
+ pub fn mask(&self, sample: &[u8]) -> Res<Vec<u8>> {
+ let mut output = vec![0_u8; self.block_size()];
+
+ match self {
+ Self::Aes(context) => {
+ let mut output_len: c_int = 0;
+ secstatus_to_res(unsafe {
+ PK11_CipherOp(
+ **context.borrow_mut(),
+ output.as_mut_ptr(),
+ &mut output_len,
+ c_int::try_from(output.len())?,
+ sample[..Self::SAMPLE_SIZE].as_ptr().cast(),
+ c_int::try_from(Self::SAMPLE_SIZE).unwrap(),
+ )
+ })?;
+ assert_eq!(usize::try_from(output_len).unwrap(), output.len());
+ Ok(output)
+ }
+
+ Self::Chacha(key) => {
+ let params: CK_CHACHA20_PARAMS = CK_CHACHA20_PARAMS {
+ pBlockCounter: sample.as_ptr() as *mut u8,
+ blockCounterBits: 32,
+ pNonce: sample[4..Self::SAMPLE_SIZE].as_ptr() as *mut _,
+ ulNonceBits: 96,
+ };
+ let mut output_len: c_uint = 0;
+ let mut param_item = Item::wrap_struct(&params);
+ secstatus_to_res(unsafe {
+ PK11_Encrypt(
+ **key,
+ CK_MECHANISM_TYPE::from(CKM_CHACHA20),
+ addr_of_mut!(param_item),
+ output[..].as_mut_ptr(),
+ &mut output_len,
+ c_uint::try_from(output.len())?,
+ output[..].as_ptr(),
+ c_uint::try_from(output.len())?,
+ )
+ })?;
+ assert_eq!(usize::try_from(output_len).unwrap(), output.len());
+ Ok(output)
+ }
+ }
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/lib.rs b/third_party/rust/neqo-crypto/src/lib.rs
new file mode 100644
index 0000000000..5067701738
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/lib.rs
@@ -0,0 +1,204 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![cfg_attr(feature = "deny-warnings", deny(warnings))]
+#![warn(clippy::pedantic)]
+// Bindgen auto generated code
+// won't adhere to the clippy rules below
+#![allow(clippy::module_name_repetitions)]
+#![allow(clippy::unseparated_literal_suffix)]
+#![allow(clippy::used_underscore_binding)]
+
+#[macro_use]
+mod exp;
+#[macro_use]
+mod p11;
+
+#[cfg(not(feature = "fuzzing"))]
+mod aead;
+
+#[cfg(feature = "fuzzing")]
+mod aead_fuzzing;
+
+pub mod agent;
+mod agentio;
+mod auth;
+mod cert;
+pub mod constants;
+mod ech;
+mod err;
+pub mod ext;
+pub mod hkdf;
+pub mod hp;
+mod once;
+mod prio;
+mod replay;
+mod secrets;
+pub mod selfencrypt;
+mod ssl;
+mod time;
+
+#[cfg(not(feature = "fuzzing"))]
+pub use self::aead::Aead;
+
+#[cfg(feature = "fuzzing")]
+pub use self::aead_fuzzing::Aead;
+
+#[cfg(feature = "fuzzing")]
+pub use self::aead_fuzzing::FIXED_TAG_FUZZING;
+
+pub use self::agent::{
+ Agent, AllowZeroRtt, Client, HandshakeState, Record, RecordList, ResumptionToken, SecretAgent,
+ SecretAgentInfo, SecretAgentPreInfo, Server, ZeroRttCheckResult, ZeroRttChecker,
+};
+pub use self::auth::AuthenticationStatus;
+pub use self::constants::*;
+pub use self::ech::{
+ encode_config as encode_ech_config, generate_keys as generate_ech_keys, AeadId, KdfId, KemId,
+ SymmetricSuite,
+};
+pub use self::err::{Error, PRErrorCode, Res};
+pub use self::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult};
+pub use self::p11::{random, PrivateKey, PublicKey, SymKey};
+pub use self::replay::AntiReplay;
+pub use self::secrets::SecretDirection;
+pub use self::ssl::Opt;
+
+use self::once::OnceResult;
+
+use std::ffi::CString;
+use std::path::{Path, PathBuf};
+use std::ptr::null;
+
+const MINIMUM_NSS_VERSION: &str = "3.74";
+
+#[allow(non_upper_case_globals, clippy::redundant_static_lifetimes)]
+#[allow(clippy::upper_case_acronyms)]
+#[allow(unknown_lints, clippy::borrow_as_ptr)]
+mod nss {
+ include!(concat!(env!("OUT_DIR"), "/nss_init.rs"));
+}
+
+// Need to map the types through.
+fn secstatus_to_res(code: nss::SECStatus) -> Res<()> {
+ crate::err::secstatus_to_res(code as crate::ssl::SECStatus)
+}
+
+enum NssLoaded {
+ External,
+ NoDb,
+ Db(Box<Path>),
+}
+
+impl Drop for NssLoaded {
+ fn drop(&mut self) {
+ if !matches!(self, Self::External) {
+ unsafe {
+ secstatus_to_res(nss::NSS_Shutdown()).expect("NSS Shutdown failed");
+ }
+ }
+ }
+}
+
+static mut INITIALIZED: OnceResult<NssLoaded> = OnceResult::new();
+
+fn already_initialized() -> bool {
+ unsafe { nss::NSS_IsInitialized() != 0 }
+}
+
+fn version_check() {
+ let min_ver = CString::new(MINIMUM_NSS_VERSION).unwrap();
+ assert_ne!(
+ unsafe { nss::NSS_VersionCheck(min_ver.as_ptr()) },
+ 0,
+ "Minimum NSS version of {} not supported",
+ MINIMUM_NSS_VERSION,
+ );
+}
+
+/// Initialize NSS. This only executes the initialization routines once, so if there is any chance that
+pub fn init() {
+ // Set time zero.
+ time::init();
+ unsafe {
+ INITIALIZED.call_once(|| {
+ version_check();
+ if already_initialized() {
+ return NssLoaded::External;
+ }
+
+ secstatus_to_res(nss::NSS_NoDB_Init(null())).expect("NSS_NoDB_Init failed");
+ secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed");
+
+ NssLoaded::NoDb
+ });
+ }
+}
+
+/// This enables SSLTRACE by calling a simple, harmless function to trigger its
+/// side effects. SSLTRACE is not enabled in NSS until a socket is made or
+/// global options are accessed. Reading an option is the least impact approach.
+/// This allows us to use SSLTRACE in all of our unit tests and programs.
+#[cfg(debug_assertions)]
+fn enable_ssl_trace() {
+ let opt = ssl::Opt::Locking.as_int();
+ let mut v: ::std::os::raw::c_int = 0;
+ secstatus_to_res(unsafe { ssl::SSL_OptionGetDefault(opt, &mut v) })
+ .expect("SSL_OptionGetDefault failed");
+}
+
+/// Initialize with a database.
+/// # Panics
+/// If NSS cannot be initialized.
+pub fn init_db<P: Into<PathBuf>>(dir: P) {
+ time::init();
+ unsafe {
+ INITIALIZED.call_once(|| {
+ version_check();
+ if already_initialized() {
+ return NssLoaded::External;
+ }
+
+ let path = dir.into();
+ assert!(path.is_dir());
+ let pathstr = path.to_str().expect("path converts to string").to_string();
+ let dircstr = CString::new(pathstr).unwrap();
+ let empty = CString::new("").unwrap();
+ secstatus_to_res(nss::NSS_Initialize(
+ dircstr.as_ptr(),
+ empty.as_ptr(),
+ empty.as_ptr(),
+ nss::SECMOD_DB.as_ptr().cast(),
+ nss::NSS_INIT_READONLY,
+ ))
+ .expect("NSS_Initialize failed");
+
+ secstatus_to_res(nss::NSS_SetDomesticPolicy()).expect("NSS_SetDomesticPolicy failed");
+ secstatus_to_res(ssl::SSL_ConfigServerSessionIDCache(
+ 1024,
+ 0,
+ 0,
+ dircstr.as_ptr(),
+ ))
+ .expect("SSL_ConfigServerSessionIDCache failed");
+
+ #[cfg(debug_assertions)]
+ enable_ssl_trace();
+
+ NssLoaded::Db(path.into_boxed_path())
+ });
+ }
+}
+
+/// # Panics
+/// If NSS isn't initialized.
+pub fn assert_initialized() {
+ unsafe {
+ INITIALIZED.call_once(|| {
+ panic!("NSS not initialized with init or init_db");
+ });
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/once.rs b/third_party/rust/neqo-crypto/src/once.rs
new file mode 100644
index 0000000000..80657cfe26
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/once.rs
@@ -0,0 +1,44 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::sync::Once;
+
+#[allow(clippy::module_name_repetitions)]
+pub struct OnceResult<T> {
+ once: Once,
+ v: Option<T>,
+}
+
+impl<T> OnceResult<T> {
+ #[must_use]
+ pub const fn new() -> Self {
+ Self {
+ once: Once::new(),
+ v: None,
+ }
+ }
+
+ pub fn call_once<F: FnOnce() -> T>(&mut self, f: F) -> &T {
+ let v = &mut self.v;
+ self.once.call_once(|| {
+ *v = Some(f());
+ });
+ self.v.as_ref().unwrap()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::OnceResult;
+
+ static mut STATIC_ONCE_RESULT: OnceResult<u64> = OnceResult::new();
+
+ #[test]
+ fn static_update() {
+ assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 23) }, 23);
+ assert_eq!(*unsafe { STATIC_ONCE_RESULT.call_once(|| 24) }, 23);
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/p11.rs b/third_party/rust/neqo-crypto/src/p11.rs
new file mode 100644
index 0000000000..7848cf08c8
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/p11.rs
@@ -0,0 +1,303 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(dead_code)]
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+
+use crate::err::{secstatus_to_res, Error, Res};
+
+use neqo_common::hex_with_len;
+
+use std::convert::TryFrom;
+use std::mem;
+use std::ops::{Deref, DerefMut};
+use std::os::raw::{c_int, c_uint};
+use std::ptr::null_mut;
+
+#[allow(clippy::upper_case_acronyms)]
+#[allow(clippy::unreadable_literal)]
+#[allow(unknown_lints, clippy::borrow_as_ptr)]
+mod nss_p11 {
+ include!(concat!(env!("OUT_DIR"), "/nss_p11.rs"));
+}
+
+pub use nss_p11::*;
+
+macro_rules! scoped_ptr {
+ ($scoped:ident, $target:ty, $dtor:path) => {
+ pub struct $scoped {
+ ptr: *mut $target,
+ }
+
+ impl $scoped {
+ /// Create a new instance of `$scoped` from a pointer.
+ ///
+ /// # Errors
+ /// When passed a null pointer generates an error.
+ pub fn from_ptr(ptr: *mut $target) -> Result<Self, crate::err::Error> {
+ if ptr.is_null() {
+ Err(crate::err::Error::last_nss_error())
+ } else {
+ Ok(Self { ptr })
+ }
+ }
+ }
+
+ impl Deref for $scoped {
+ type Target = *mut $target;
+ #[must_use]
+ fn deref(&self) -> &*mut $target {
+ &self.ptr
+ }
+ }
+
+ impl DerefMut for $scoped {
+ fn deref_mut(&mut self) -> &mut *mut $target {
+ &mut self.ptr
+ }
+ }
+
+ impl Drop for $scoped {
+ #[allow(unused_must_use)]
+ fn drop(&mut self) {
+ unsafe { $dtor(self.ptr) };
+ }
+ }
+ };
+}
+
+scoped_ptr!(Certificate, CERTCertificate, CERT_DestroyCertificate);
+scoped_ptr!(CertList, CERTCertList, CERT_DestroyCertList);
+scoped_ptr!(PublicKey, SECKEYPublicKey, SECKEY_DestroyPublicKey);
+
+impl PublicKey {
+ /// Get the HPKE serialization of the public key.
+ ///
+ /// # Errors
+ /// When the key cannot be exported, which can be because the type is not supported.
+ /// # Panics
+ /// When keys are too large to fit in `c_uint/usize`. So only on programming error.
+ pub fn key_data(&self) -> Res<Vec<u8>> {
+ let mut buf = vec![0; 100];
+ let mut len: c_uint = 0;
+ secstatus_to_res(unsafe {
+ PK11_HPKE_Serialize(
+ **self,
+ buf.as_mut_ptr(),
+ &mut len,
+ c_uint::try_from(buf.len()).unwrap(),
+ )
+ })?;
+ buf.truncate(usize::try_from(len).unwrap());
+ Ok(buf)
+ }
+}
+
+impl Clone for PublicKey {
+ #[must_use]
+ fn clone(&self) -> Self {
+ let ptr = unsafe { SECKEY_CopyPublicKey(self.ptr) };
+ assert!(!ptr.is_null());
+ Self { ptr }
+ }
+}
+
+impl std::fmt::Debug for PublicKey {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ if let Ok(b) = self.key_data() {
+ write!(f, "PublicKey {}", hex_with_len(b))
+ } else {
+ write!(f, "Opaque PublicKey")
+ }
+ }
+}
+
+scoped_ptr!(PrivateKey, SECKEYPrivateKey, SECKEY_DestroyPrivateKey);
+
+impl PrivateKey {
+ /// Get the bits of the private key.
+ ///
+ /// # Errors
+ /// When the key cannot be exported, which can be because the type is not supported
+ /// or because the key data cannot be extracted from the PKCS#11 module.
+ /// # Panics
+ /// When the values are too large to fit. So never.
+ pub fn key_data(&self) -> Res<Vec<u8>> {
+ let mut key_item = Item::make_empty();
+ secstatus_to_res(unsafe {
+ PK11_ReadRawAttribute(
+ PK11ObjectType::PK11_TypePrivKey,
+ (**self).cast(),
+ CK_ATTRIBUTE_TYPE::from(CKA_VALUE),
+ &mut key_item,
+ )
+ })?;
+ let slc = unsafe {
+ std::slice::from_raw_parts(key_item.data, usize::try_from(key_item.len).unwrap())
+ };
+ let key = Vec::from(slc);
+ // The data that `key_item` refers to needs to be freed, but we can't
+ // use the scoped `Item` implementation. This is OK as long as nothing
+ // panics between `PK11_ReadRawAttribute` succeeding and here.
+ unsafe {
+ SECITEM_FreeItem(&mut key_item, PRBool::from(false));
+ }
+ Ok(key)
+ }
+}
+unsafe impl Send for PrivateKey {}
+
+impl Clone for PrivateKey {
+ #[must_use]
+ fn clone(&self) -> Self {
+ let ptr = unsafe { SECKEY_CopyPrivateKey(self.ptr) };
+ assert!(!ptr.is_null());
+ Self { ptr }
+ }
+}
+
+impl std::fmt::Debug for PrivateKey {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ if let Ok(b) = self.key_data() {
+ write!(f, "PrivateKey {}", hex_with_len(b))
+ } else {
+ write!(f, "Opaque PrivateKey")
+ }
+ }
+}
+
+scoped_ptr!(Slot, PK11SlotInfo, PK11_FreeSlot);
+
+impl Slot {
+ pub fn internal() -> Res<Self> {
+ let p = unsafe { PK11_GetInternalSlot() };
+ Slot::from_ptr(p)
+ }
+}
+
+scoped_ptr!(SymKey, PK11SymKey, PK11_FreeSymKey);
+
+impl SymKey {
+ /// You really don't want to use this.
+ ///
+ /// # Errors
+ /// Internal errors in case of failures in NSS.
+ pub fn as_bytes(&self) -> Res<&[u8]> {
+ secstatus_to_res(unsafe { PK11_ExtractKeyValue(self.ptr) })?;
+
+ let key_item = unsafe { PK11_GetKeyData(self.ptr) };
+ // This is accessing a value attached to the key, so we can treat this as a borrow.
+ match unsafe { key_item.as_mut() } {
+ None => Err(Error::InternalError),
+ Some(key) => Ok(unsafe { std::slice::from_raw_parts(key.data, key.len as usize) }),
+ }
+ }
+}
+
+impl Clone for SymKey {
+ #[must_use]
+ fn clone(&self) -> Self {
+ let ptr = unsafe { PK11_ReferenceSymKey(self.ptr) };
+ assert!(!ptr.is_null());
+ Self { ptr }
+ }
+}
+
+impl std::fmt::Debug for SymKey {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ if let Ok(b) = self.as_bytes() {
+ write!(f, "SymKey {}", hex_with_len(b))
+ } else {
+ write!(f, "Opaque SymKey")
+ }
+ }
+}
+
+unsafe fn destroy_pk11_context(ctxt: *mut PK11Context) {
+ PK11_DestroyContext(ctxt, PRBool::from(true));
+}
+scoped_ptr!(Context, PK11Context, destroy_pk11_context);
+
+unsafe fn destroy_secitem(item: *mut SECItem) {
+ SECITEM_FreeItem(item, PRBool::from(true));
+}
+scoped_ptr!(Item, SECItem, destroy_secitem);
+
+impl Item {
+ /// Create a wrapper for a slice of this object.
+ /// Creating this object is technically safe, but using it is extremely dangerous.
+ /// Minimally, it can only be passed as a `const SECItem*` argument to functions,
+ /// or those that treat their argument as `const`.
+ pub fn wrap(buf: &[u8]) -> SECItem {
+ SECItem {
+ type_: SECItemType::siBuffer,
+ data: buf.as_ptr() as *mut u8,
+ len: c_uint::try_from(buf.len()).unwrap(),
+ }
+ }
+
+ /// Create a wrapper for a struct.
+ /// Creating this object is technically safe, but using it is extremely dangerous.
+ /// Minimally, it can only be passed as a `const SECItem*` argument to functions,
+ /// or those that treat their argument as `const`.
+ pub fn wrap_struct<T>(v: &T) -> SECItem {
+ SECItem {
+ type_: SECItemType::siBuffer,
+ data: (v as *const T as *mut T).cast(),
+ len: c_uint::try_from(mem::size_of::<T>()).unwrap(),
+ }
+ }
+
+ /// Make an empty `SECItem` for passing as a mutable `SECItem*` argument.
+ pub fn make_empty() -> SECItem {
+ SECItem {
+ type_: SECItemType::siBuffer,
+ data: null_mut(),
+ len: 0,
+ }
+ }
+
+ /// This dereferences the pointer held by the item and makes a copy of the
+ /// content that is referenced there.
+ ///
+ /// # Safety
+ /// This dereferences two pointers. It doesn't get much less safe.
+ pub unsafe fn into_vec(self) -> Vec<u8> {
+ let b = self.ptr.as_ref().unwrap();
+ // Sanity check the type, as some types don't count bytes in `Item::len`.
+ assert_eq!(b.type_, SECItemType::siBuffer);
+ let slc = std::slice::from_raw_parts(b.data, usize::try_from(b.len).unwrap());
+ Vec::from(slc)
+ }
+}
+
+/// Generate a randomized buffer.
+/// # Panics
+/// When `size` is too large or NSS fails.
+#[must_use]
+pub fn random(size: usize) -> Vec<u8> {
+ let mut buf = vec![0; size];
+ secstatus_to_res(unsafe {
+ PK11_GenerateRandom(buf.as_mut_ptr(), c_int::try_from(buf.len()).unwrap())
+ })
+ .unwrap();
+ buf
+}
+
+#[cfg(test)]
+mod test {
+ use super::random;
+ use test_fixture::fixture_init;
+
+ #[test]
+ fn randomness() {
+ fixture_init();
+ // If this ever fails, there is either a bug, or it's time to buy a lottery ticket.
+ assert_ne!(random(16), random(16));
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/prio.rs b/third_party/rust/neqo-crypto/src/prio.rs
new file mode 100644
index 0000000000..527d8739c8
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/prio.rs
@@ -0,0 +1,25 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(clippy::upper_case_acronyms)]
+#![allow(
+ dead_code,
+ non_upper_case_globals,
+ non_snake_case,
+ clippy::cognitive_complexity,
+ clippy::empty_enum,
+ clippy::too_many_lines,
+ unknown_lints,
+ clippy::borrow_as_ptr
+)]
+
+include!(concat!(env!("OUT_DIR"), "/nspr_io.rs"));
+
+pub enum PRFileInfo {}
+pub enum PRFileInfo64 {}
+pub enum PRFilePrivate {}
+pub enum PRIOVec {}
+pub enum PRSendFileData {}
diff --git a/third_party/rust/neqo-crypto/src/replay.rs b/third_party/rust/neqo-crypto/src/replay.rs
new file mode 100644
index 0000000000..7a493fbc1b
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/replay.rs
@@ -0,0 +1,78 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::err::Res;
+use crate::ssl::PRFileDesc;
+use crate::time::{Interval, PRTime, Time};
+
+use std::convert::{TryFrom, TryInto};
+use std::ops::{Deref, DerefMut};
+use std::os::raw::c_uint;
+use std::ptr::null_mut;
+use std::time::{Duration, Instant};
+
+// This is an opaque struct in NSS.
+#[allow(clippy::upper_case_acronyms)]
+#[allow(clippy::empty_enum)]
+pub enum SSLAntiReplayContext {}
+
+experimental_api!(SSL_CreateAntiReplayContext(
+ now: PRTime,
+ window: PRTime,
+ k: c_uint,
+ bits: c_uint,
+ ctx: *mut *mut SSLAntiReplayContext,
+));
+experimental_api!(SSL_ReleaseAntiReplayContext(ctx: *mut SSLAntiReplayContext));
+experimental_api!(SSL_SetAntiReplayContext(
+ fd: *mut PRFileDesc,
+ ctx: *mut SSLAntiReplayContext,
+));
+
+scoped_ptr!(
+ AntiReplayContext,
+ SSLAntiReplayContext,
+ SSL_ReleaseAntiReplayContext
+);
+
+/// `AntiReplay` is used by servers when processing 0-RTT handshakes.
+/// It limits the exposure of servers to replay attack by rejecting 0-RTT
+/// if it appears to be a replay. There is a false-positive rate that can be
+/// managed by tuning the parameters used to create the context.
+#[allow(clippy::module_name_repetitions)]
+pub struct AntiReplay {
+ ctx: AntiReplayContext,
+}
+
+impl AntiReplay {
+ /// Make a new anti-replay context.
+ /// See the documentation in NSS for advice on how to set these values.
+ ///
+ /// # Errors
+ /// Returns an error if `now` is in the past relative to our baseline or
+ /// NSS is unable to generate an anti-replay context.
+ pub fn new(now: Instant, window: Duration, k: usize, bits: usize) -> Res<Self> {
+ let mut ctx: *mut SSLAntiReplayContext = null_mut();
+ unsafe {
+ SSL_CreateAntiReplayContext(
+ Time::from(now).try_into()?,
+ Interval::from(window).try_into()?,
+ c_uint::try_from(k)?,
+ c_uint::try_from(bits)?,
+ &mut ctx,
+ )
+ }?;
+
+ Ok(Self {
+ ctx: AntiReplayContext::from_ptr(ctx)?,
+ })
+ }
+
+ /// Configure the provided socket with this anti-replay context.
+ pub(crate) fn config_socket(&self, fd: *mut PRFileDesc) -> Res<()> {
+ unsafe { SSL_SetAntiReplayContext(fd, *self.ctx) }
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/result.rs b/third_party/rust/neqo-crypto/src/result.rs
new file mode 100644
index 0000000000..e6148dd054
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/result.rs
@@ -0,0 +1,133 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::err::{
+ nspr, Error, PR_ErrorToName, PR_ErrorToString, PR_GetError, Res, PR_LANGUAGE_I_DEFAULT,
+};
+use crate::ssl;
+
+use std::ffi::CStr;
+
+pub fn result(rv: ssl::SECStatus) -> Res<()> {
+ let _ = result_helper(rv, false)?;
+ Ok(())
+}
+
+pub fn result_or_blocked(rv: ssl::SECStatus) -> Res<bool> {
+ result_helper(rv, true)
+}
+
+fn wrap_str_fn<F>(f: F, dflt: &str) -> String
+where
+ F: FnOnce() -> *const i8,
+{
+ unsafe {
+ let p = f();
+ if p.is_null() {
+ return dflt.to_string();
+ }
+ CStr::from_ptr(p).to_string_lossy().into_owned()
+ }
+}
+
+fn result_helper(rv: ssl::SECStatus, allow_blocked: bool) -> Res<bool> {
+ if rv == ssl::_SECStatus_SECSuccess {
+ return Ok(false);
+ }
+
+ let code = unsafe { PR_GetError() };
+ if allow_blocked && code == nspr::PR_WOULD_BLOCK_ERROR {
+ return Ok(true);
+ }
+
+ let name = wrap_str_fn(|| unsafe { PR_ErrorToName(code) }, "UNKNOWN_ERROR");
+ let desc = wrap_str_fn(
+ || unsafe { PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT) },
+ "...",
+ );
+ Err(Error::NssError { name, code, desc })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{result, result_or_blocked};
+ use crate::err::{self, nspr, Error, PRErrorCode, PR_SetError};
+ use crate::ssl;
+ use test_fixture::fixture_init;
+
+ fn set_error_code(code: PRErrorCode) {
+ unsafe { PR_SetError(code, 0) };
+ }
+
+ #[test]
+ fn is_ok() {
+ assert!(result(ssl::SECSuccess).is_ok());
+ }
+
+ #[test]
+ fn is_err() {
+ // This code doesn't work without initializing NSS first.
+ fixture_init();
+
+ set_error_code(err::ssl::SSL_ERROR_BAD_MAC_READ);
+ let r = result(ssl::SECFailure);
+ assert!(r.is_err());
+ match r.unwrap_err() {
+ Error::NssError { name, code, desc } => {
+ assert_eq!(name, "SSL_ERROR_BAD_MAC_READ");
+ assert_eq!(code, -12273);
+ assert_eq!(
+ desc,
+ "SSL received a record with an incorrect Message Authentication Code."
+ );
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ #[test]
+ fn is_err_zero_code() {
+ // This code doesn't work without initializing NSS first.
+ fixture_init();
+
+ set_error_code(0);
+ let r = result(ssl::SECFailure);
+ assert!(r.is_err());
+ match r.unwrap_err() {
+ Error::NssError { name, code, .. } => {
+ assert_eq!(name, "UNKNOWN_ERROR");
+ assert_eq!(code, 0);
+ // Note that we don't test |desc| here because that comes from
+ // strerror(0), which is platform-dependent.
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ #[test]
+ fn blocked_as_error() {
+ // This code doesn't work without initializing NSS first.
+ fixture_init();
+
+ set_error_code(nspr::PR_WOULD_BLOCK_ERROR);
+ let r = result(ssl::SECFailure);
+ assert!(r.is_err());
+ match r.unwrap_err() {
+ Error::NssError { name, code, desc } => {
+ assert_eq!(name, "PR_WOULD_BLOCK_ERROR");
+ assert_eq!(code, -5998);
+ assert_eq!(desc, "The operation would have blocked");
+ }
+ _ => panic!("bad error type"),
+ }
+ }
+
+ #[test]
+ fn is_blocked() {
+ set_error_code(nspr::PR_WOULD_BLOCK_ERROR);
+ assert!(result_or_blocked(ssl::SECFailure).unwrap());
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/secrets.rs b/third_party/rust/neqo-crypto/src/secrets.rs
new file mode 100644
index 0000000000..7e2739cfa5
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/secrets.rs
@@ -0,0 +1,127 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::agentio::as_c_void;
+use crate::constants::Epoch;
+use crate::err::Res;
+use crate::p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey};
+use crate::ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection};
+
+use neqo_common::qdebug;
+use std::os::raw::c_void;
+use std::pin::Pin;
+
+experimental_api!(SSL_SecretCallback(
+ fd: *mut PRFileDesc,
+ cb: SSLSecretCallback,
+ arg: *mut c_void,
+));
+
+#[derive(Clone, Copy, Debug)]
+pub enum SecretDirection {
+ Read,
+ Write,
+}
+
+impl From<SSLSecretDirection::Type> for SecretDirection {
+ #[must_use]
+ fn from(dir: SSLSecretDirection::Type) -> Self {
+ match dir {
+ SSLSecretDirection::ssl_secret_read => Self::Read,
+ SSLSecretDirection::ssl_secret_write => Self::Write,
+ _ => unreachable!(),
+ }
+ }
+}
+
+#[derive(Debug, Default)]
+#[allow(clippy::module_name_repetitions)]
+pub struct DirectionalSecrets {
+ // We only need to maintain 3 secrets for the epochs used during the handshake.
+ secrets: [Option<SymKey>; 3],
+}
+
+impl DirectionalSecrets {
+ fn put(&mut self, epoch: Epoch, key: SymKey) {
+ assert!(epoch > 0);
+ let i = (epoch - 1) as usize;
+ assert!(i < self.secrets.len());
+ // assert!(self.secrets[i].is_none());
+ self.secrets[i] = Some(key);
+ }
+
+ pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> {
+ assert!(epoch > 0);
+ let i = (epoch - 1) as usize;
+ assert!(i < self.secrets.len());
+ self.secrets[i].take()
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct Secrets {
+ r: DirectionalSecrets,
+ w: DirectionalSecrets,
+}
+
+impl Secrets {
+ #[allow(clippy::unused_self)]
+ unsafe extern "C" fn secret_available(
+ _fd: *mut PRFileDesc,
+ epoch: u16,
+ dir: SSLSecretDirection::Type,
+ secret: *mut PK11SymKey,
+ arg: *mut c_void,
+ ) {
+ let secrets = arg.cast::<Self>().as_mut().unwrap();
+ secrets.put_raw(epoch, dir, secret);
+ }
+
+ fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) {
+ let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) };
+ let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets");
+ self.put(SecretDirection::from(dir), epoch, key);
+ }
+
+ fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) {
+ qdebug!("{:?} secret available for {:?}: {:?}", dir, epoch, key);
+ let keys = match dir {
+ SecretDirection::Read => &mut self.r,
+ SecretDirection::Write => &mut self.w,
+ };
+ keys.put(epoch, key);
+ }
+}
+
+#[derive(Debug)]
+pub struct SecretHolder {
+ secrets: Pin<Box<Secrets>>,
+}
+
+impl SecretHolder {
+ /// This registers with NSS. The lifetime of this object needs to match the lifetime
+ /// of the connection, or bad things might happen.
+ pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> {
+ let p = as_c_void(&mut self.secrets);
+ unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) }
+ }
+
+ pub fn take_read(&mut self, epoch: Epoch) -> Option<SymKey> {
+ self.secrets.r.take(epoch)
+ }
+
+ pub fn take_write(&mut self, epoch: Epoch) -> Option<SymKey> {
+ self.secrets.w.take(epoch)
+ }
+}
+
+impl Default for SecretHolder {
+ fn default() -> Self {
+ Self {
+ secrets: Box::pin(Secrets::default()),
+ }
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/selfencrypt.rs b/third_party/rust/neqo-crypto/src/selfencrypt.rs
new file mode 100644
index 0000000000..bcf219a4e9
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/selfencrypt.rs
@@ -0,0 +1,155 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{Cipher, Version};
+use crate::err::{Error, Res};
+use crate::p11::{random, SymKey};
+use crate::{hkdf, Aead};
+
+use neqo_common::{hex, qinfo, qtrace, Encoder};
+
+use std::mem;
+
+#[derive(Debug)]
+pub struct SelfEncrypt {
+ version: Version,
+ cipher: Cipher,
+ key_id: u8,
+ key: SymKey,
+ old_key: Option<SymKey>,
+}
+
+impl SelfEncrypt {
+ const VERSION: u8 = 1;
+ const SALT_LENGTH: usize = 16;
+
+ /// # Errors
+ /// Failure to generate a new HKDF key using NSS results in an error.
+ pub fn new(version: Version, cipher: Cipher) -> Res<Self> {
+ let key = hkdf::generate_key(version, cipher)?;
+ Ok(Self {
+ version,
+ cipher,
+ key_id: 0,
+ key,
+ old_key: None,
+ })
+ }
+
+ fn make_aead(&self, k: &SymKey, salt: &[u8]) -> Res<Aead> {
+ debug_assert_eq!(salt.len(), Self::SALT_LENGTH);
+ let salt = hkdf::import_key(self.version, salt)?;
+ let secret = hkdf::extract(self.version, self.cipher, Some(&salt), k)?;
+ Aead::new(self.version, self.cipher, &secret, "neqo self")
+ }
+
+ /// Rotate keys. This causes any previous key that is being held to be replaced by the current key.
+ ///
+ /// # Errors
+ /// Failure to generate a new HKDF key using NSS results in an error.
+ pub fn rotate(&mut self) -> Res<()> {
+ let new_key = hkdf::generate_key(self.version, self.cipher)?;
+ self.old_key = Some(mem::replace(&mut self.key, new_key));
+ let (kid, _) = self.key_id.overflowing_add(1);
+ self.key_id = kid;
+ qinfo!(["SelfEncrypt"], "Rotated keys to {}", self.key_id);
+ Ok(())
+ }
+
+ /// Seal an item using the underlying key. This produces a single buffer that contains
+ /// the encrypted `plaintext`, plus a version number and salt.
+ /// `aad` is only used as input to the AEAD, it is not included in the output; the
+ /// caller is responsible for carrying the AAD as appropriate.
+ ///
+ /// # Errors
+ /// Failure to protect using NSS AEAD APIs produces an error.
+ pub fn seal(&self, aad: &[u8], plaintext: &[u8]) -> Res<Vec<u8>> {
+ // Format is:
+ // struct {
+ // uint8 version;
+ // uint8 key_id;
+ // uint8 salt[16];
+ // opaque aead_encrypted(plaintext)[length as expanded];
+ // };
+ // AAD covers the entire header, plus the value of the AAD parameter that is provided.
+ let salt = random(Self::SALT_LENGTH);
+ let cipher = self.make_aead(&self.key, &salt)?;
+ let encoded_len = 2 + salt.len() + plaintext.len() + cipher.expansion();
+
+ let mut enc = Encoder::with_capacity(encoded_len);
+ enc.encode_byte(Self::VERSION);
+ enc.encode_byte(self.key_id);
+ enc.encode(&salt);
+
+ let mut extended_aad = enc.clone();
+ extended_aad.encode(aad);
+
+ let offset = enc.len();
+ let mut output: Vec<u8> = enc.into();
+ output.resize(encoded_len, 0);
+ cipher.encrypt(0, extended_aad.as_ref(), plaintext, &mut output[offset..])?;
+ qtrace!(
+ ["SelfEncrypt"],
+ "seal {} {} -> {}",
+ hex(aad),
+ hex(plaintext),
+ hex(&output)
+ );
+ Ok(output)
+ }
+
+ fn select_key(&self, kid: u8) -> Option<&SymKey> {
+ if kid == self.key_id {
+ Some(&self.key)
+ } else {
+ let (prev_key_id, _) = self.key_id.overflowing_sub(1);
+ if kid == prev_key_id {
+ self.old_key.as_ref()
+ } else {
+ None
+ }
+ }
+ }
+
+ /// Open the protected `ciphertext`.
+ ///
+ /// # Errors
+ /// Returns an error when the self-encrypted object is invalid;
+ /// when the keys have been rotated; or when NSS fails.
+ #[allow(clippy::similar_names)] // aad is similar to aead
+ pub fn open(&self, aad: &[u8], ciphertext: &[u8]) -> Res<Vec<u8>> {
+ if ciphertext[0] != Self::VERSION {
+ return Err(Error::SelfEncryptFailure);
+ }
+ let key = if let Some(k) = self.select_key(ciphertext[1]) {
+ k
+ } else {
+ return Err(Error::SelfEncryptFailure);
+ };
+ let offset = 2 + Self::SALT_LENGTH;
+
+ let mut extended_aad = Encoder::with_capacity(offset + aad.len());
+ extended_aad.encode(&ciphertext[0..offset]);
+ extended_aad.encode(aad);
+
+ let aead = self.make_aead(key, &ciphertext[2..offset])?;
+ // NSS insists on having extra space available for decryption.
+ let padded_len = ciphertext.len() - offset;
+ let mut output = vec![0; padded_len];
+ let decrypted =
+ aead.decrypt(0, extended_aad.as_ref(), &ciphertext[offset..], &mut output)?;
+ let final_len = decrypted.len();
+ output.truncate(final_len);
+ qtrace!(
+ ["SelfEncrypt"],
+ "open {} {} -> {}",
+ hex(aad),
+ hex(ciphertext),
+ hex(&output)
+ );
+ Ok(output)
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/ssl.rs b/third_party/rust/neqo-crypto/src/ssl.rs
new file mode 100644
index 0000000000..b3c0c12708
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/ssl.rs
@@ -0,0 +1,149 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(
+ dead_code,
+ non_upper_case_globals,
+ non_snake_case,
+ clippy::cognitive_complexity,
+ clippy::too_many_lines,
+ clippy::upper_case_acronyms,
+ unknown_lints,
+ clippy::borrow_as_ptr
+)]
+
+use crate::constants::Epoch;
+use crate::err::{secstatus_to_res, Res};
+
+use std::os::raw::{c_uint, c_void};
+
+include!(concat!(env!("OUT_DIR"), "/nss_ssl.rs"));
+mod SSLOption {
+ include!(concat!(env!("OUT_DIR"), "/nss_sslopt.rs"));
+}
+
+// I clearly don't understand how bindgen operates.
+#[allow(clippy::empty_enum)]
+pub enum PLArenaPool {}
+#[allow(clippy::empty_enum)]
+pub enum PRFileDesc {}
+
+// Remap some constants.
+pub const SECSuccess: SECStatus = _SECStatus_SECSuccess;
+pub const SECFailure: SECStatus = _SECStatus_SECFailure;
+
+#[derive(Debug, Copy, Clone)]
+pub enum Opt {
+ Locking,
+ Tickets,
+ OcspStapling,
+ Alpn,
+ ExtendedMasterSecret,
+ SignedCertificateTimestamps,
+ EarlyData,
+ RecordSizeLimit,
+ Tls13CompatMode,
+ HelloDowngradeCheck,
+ SuppressEndOfEarlyData,
+}
+
+impl Opt {
+ // Cast is safe here because SSLOptions are within the i32 range
+ #[allow(clippy::cast_possible_wrap)]
+ pub(crate) fn as_int(self) -> PRInt32 {
+ let i = match self {
+ Self::Locking => SSLOption::SSL_NO_LOCKS,
+ Self::Tickets => SSLOption::SSL_ENABLE_SESSION_TICKETS,
+ Self::OcspStapling => SSLOption::SSL_ENABLE_OCSP_STAPLING,
+ Self::Alpn => SSLOption::SSL_ENABLE_ALPN,
+ Self::ExtendedMasterSecret => SSLOption::SSL_ENABLE_EXTENDED_MASTER_SECRET,
+ Self::SignedCertificateTimestamps => SSLOption::SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
+ Self::EarlyData => SSLOption::SSL_ENABLE_0RTT_DATA,
+ Self::RecordSizeLimit => SSLOption::SSL_RECORD_SIZE_LIMIT,
+ Self::Tls13CompatMode => SSLOption::SSL_ENABLE_TLS13_COMPAT_MODE,
+ Self::HelloDowngradeCheck => SSLOption::SSL_ENABLE_HELLO_DOWNGRADE_CHECK,
+ Self::SuppressEndOfEarlyData => SSLOption::SSL_SUPPRESS_END_OF_EARLY_DATA,
+ };
+ i as PRInt32
+ }
+
+ // Some options are backwards, like SSL_NO_LOCKS, so use this to manage that.
+ fn map_enabled(self, enabled: bool) -> PRIntn {
+ let v = match self {
+ Self::Locking => !enabled,
+ _ => enabled,
+ };
+ PRIntn::from(v)
+ }
+
+ pub(crate) fn set(self, fd: *mut PRFileDesc, value: bool) -> Res<()> {
+ secstatus_to_res(unsafe { SSL_OptionSet(fd, self.as_int(), self.map_enabled(value)) })
+ }
+}
+
+experimental_api!(SSL_GetCurrentEpoch(
+ fd: *mut PRFileDesc,
+ read_epoch: *mut u16,
+ write_epoch: *mut u16,
+));
+experimental_api!(SSL_HelloRetryRequestCallback(
+ fd: *mut PRFileDesc,
+ cb: SSLHelloRetryRequestCallback,
+ arg: *mut c_void,
+));
+experimental_api!(SSL_RecordLayerWriteCallback(
+ fd: *mut PRFileDesc,
+ cb: SSLRecordWriteCallback,
+ arg: *mut c_void,
+));
+experimental_api!(SSL_RecordLayerData(
+ fd: *mut PRFileDesc,
+ epoch: Epoch,
+ ct: SSLContentType::Type,
+ data: *const u8,
+ len: c_uint,
+));
+experimental_api!(SSL_SendSessionTicket(
+ fd: *mut PRFileDesc,
+ extra: *const u8,
+ len: c_uint,
+));
+experimental_api!(SSL_SetMaxEarlyDataSize(fd: *mut PRFileDesc, size: u32));
+experimental_api!(SSL_SetResumptionToken(
+ fd: *mut PRFileDesc,
+ token: *const u8,
+ len: c_uint,
+));
+experimental_api!(SSL_SetResumptionTokenCallback(
+ fd: *mut PRFileDesc,
+ cb: SSLResumptionTokenCallback,
+ arg: *mut c_void,
+));
+
+experimental_api!(SSL_GetResumptionTokenInfo(
+ token: *const u8,
+ token_len: c_uint,
+ info: *mut SSLResumptionTokenInfo,
+ len: c_uint,
+));
+
+experimental_api!(SSL_DestroyResumptionTokenInfo(
+ info: *mut SSLResumptionTokenInfo,
+));
+
+#[cfg(test)]
+mod tests {
+ use super::{SSL_GetNumImplementedCiphers, SSL_NumImplementedCiphers};
+
+ #[test]
+ fn num_ciphers() {
+ assert!(unsafe { SSL_NumImplementedCiphers } > 0);
+ assert!(unsafe { SSL_GetNumImplementedCiphers() } > 0);
+ assert_eq!(unsafe { SSL_NumImplementedCiphers }, unsafe {
+ SSL_GetNumImplementedCiphers()
+ });
+ }
+}
diff --git a/third_party/rust/neqo-crypto/src/time.rs b/third_party/rust/neqo-crypto/src/time.rs
new file mode 100644
index 0000000000..d3364b7f72
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/time.rs
@@ -0,0 +1,252 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![allow(clippy::upper_case_acronyms)]
+
+use crate::agentio::as_c_void;
+use crate::err::{Error, Res};
+use crate::once::OnceResult;
+use crate::ssl::{PRFileDesc, SSLTimeFunc};
+
+use std::boxed::Box;
+use std::convert::{TryFrom, TryInto};
+use std::ops::Deref;
+use std::os::raw::c_void;
+use std::pin::Pin;
+use std::time::{Duration, Instant};
+
+include!(concat!(env!("OUT_DIR"), "/nspr_time.rs"));
+
+experimental_api!(SSL_SetTimeFunc(
+ fd: *mut PRFileDesc,
+ cb: SSLTimeFunc,
+ arg: *mut c_void,
+));
+
+/// This struct holds the zero time used for converting between `Instant` and `PRTime`.
+#[derive(Debug)]
+struct TimeZero {
+ instant: Instant,
+ prtime: PRTime,
+}
+
+impl TimeZero {
+ /// This function sets a baseline from an instance of `Instant`.
+ /// This allows for the possibility that code that uses these APIs will create
+ /// instances of `Instant` before any of this code is run. If `Instant`s older than
+ /// `BASE_TIME` are used with these conversion functions, they will fail.
+ /// To avoid that, we make sure that this sets the base time using the first value
+ /// it sees if it is in the past. If it is not, then use `Instant::now()` instead.
+ pub fn baseline(t: Instant) -> Self {
+ let now = Instant::now();
+ let prnow = unsafe { PR_Now() };
+
+ if now <= t {
+ // `t` is in the future, just use `now`.
+ Self {
+ instant: now,
+ prtime: prnow,
+ }
+ } else {
+ let elapsed = Interval::from(now.duration_since(now));
+ // An error from these unwrap functions would require
+ // ridiculously long application running time.
+ let prelapsed: PRTime = elapsed.try_into().unwrap();
+ Self {
+ instant: t,
+ prtime: prnow.checked_sub(prelapsed).unwrap(),
+ }
+ }
+ }
+}
+
+static mut BASE_TIME: OnceResult<TimeZero> = OnceResult::new();
+
+fn get_base() -> &'static TimeZero {
+ let f = || TimeZero {
+ instant: Instant::now(),
+ prtime: unsafe { PR_Now() },
+ };
+ unsafe { BASE_TIME.call_once(f) }
+}
+
+pub(crate) fn init() {
+ let _ = get_base();
+}
+
+/// Time wraps Instant and provides conversion functions into `PRTime`.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct Time {
+ t: Instant,
+}
+
+impl Deref for Time {
+ type Target = Instant;
+ fn deref(&self) -> &Self::Target {
+ &self.t
+ }
+}
+
+impl From<Instant> for Time {
+ /// Convert from an Instant into a Time.
+ fn from(t: Instant) -> Self {
+ // Call `TimeZero::baseline(t)` so that time zero can be set.
+ let f = || TimeZero::baseline(t);
+ let _ = unsafe { BASE_TIME.call_once(f) };
+ Self { t }
+ }
+}
+
+impl TryFrom<PRTime> for Time {
+ type Error = Error;
+ fn try_from(prtime: PRTime) -> Res<Self> {
+ let base = get_base();
+ if let Some(delta) = prtime.checked_sub(base.prtime) {
+ let d = Duration::from_micros(delta.try_into()?);
+ base.instant
+ .checked_add(d)
+ .map_or(Err(Error::TimeTravelError), |t| Ok(Self { t }))
+ } else {
+ Err(Error::TimeTravelError)
+ }
+ }
+}
+
+impl TryInto<PRTime> for Time {
+ type Error = Error;
+ fn try_into(self) -> Res<PRTime> {
+ let base = get_base();
+ let delta = self
+ .t
+ .checked_duration_since(base.instant)
+ .ok_or(Error::TimeTravelError)?;
+ if let Ok(d) = PRTime::try_from(delta.as_micros()) {
+ d.checked_add(base.prtime).ok_or(Error::TimeTravelError)
+ } else {
+ Err(Error::TimeTravelError)
+ }
+ }
+}
+
+impl From<Time> for Instant {
+ #[must_use]
+ fn from(t: Time) -> Self {
+ t.t
+ }
+}
+
+/// Interval wraps Duration and provides conversion functions into `PRTime`.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct Interval {
+ d: Duration,
+}
+
+impl Deref for Interval {
+ type Target = Duration;
+ fn deref(&self) -> &Self::Target {
+ &self.d
+ }
+}
+
+impl TryFrom<PRTime> for Interval {
+ type Error = Error;
+ fn try_from(prtime: PRTime) -> Res<Self> {
+ Ok(Self {
+ d: Duration::from_micros(u64::try_from(prtime)?),
+ })
+ }
+}
+
+impl From<Duration> for Interval {
+ fn from(d: Duration) -> Self {
+ Self { d }
+ }
+}
+
+impl TryInto<PRTime> for Interval {
+ type Error = Error;
+ fn try_into(self) -> Res<PRTime> {
+ Ok(PRTime::try_from(self.d.as_micros())?)
+ }
+}
+
+/// `TimeHolder` maintains a `PRTime` value in a form that is accessible to the TLS stack.
+#[derive(Debug)]
+pub struct TimeHolder {
+ t: Pin<Box<PRTime>>,
+}
+
+impl TimeHolder {
+ unsafe extern "C" fn time_func(arg: *mut c_void) -> PRTime {
+ let p = arg as *const PRTime;
+ *p.as_ref().unwrap()
+ }
+
+ pub fn bind(&mut self, fd: *mut PRFileDesc) -> Res<()> {
+ unsafe { SSL_SetTimeFunc(fd, Some(Self::time_func), as_c_void(&mut self.t)) }
+ }
+
+ pub fn set(&mut self, t: Instant) -> Res<()> {
+ *self.t = Time::from(t).try_into()?;
+ Ok(())
+ }
+}
+
+impl Default for TimeHolder {
+ fn default() -> Self {
+ TimeHolder { t: Box::pin(0) }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::{get_base, init, Interval, PRTime, Time};
+ use crate::err::Res;
+ use std::convert::{TryFrom, TryInto};
+ use std::time::{Duration, Instant};
+
+ #[test]
+ fn convert_stable() {
+ init();
+ let now = Time::from(Instant::now());
+ let pr: PRTime = now.try_into().expect("convert to PRTime with truncation");
+ let t2 = Time::try_from(pr).expect("convert to Instant");
+ let pr2: PRTime = t2.try_into().expect("convert to PRTime again");
+ assert_eq!(pr, pr2);
+ let t3 = Time::try_from(pr2).expect("convert to Instant again");
+ assert_eq!(t2, t3);
+ }
+
+ #[test]
+ fn past_time() {
+ init();
+ let base = get_base();
+ assert!(Time::try_from(base.prtime - 1).is_err());
+ }
+
+ #[test]
+ fn negative_time() {
+ init();
+ assert!(Time::try_from(-1).is_err());
+ }
+
+ #[test]
+ fn negative_interval() {
+ init();
+ assert!(Interval::try_from(-1).is_err());
+ }
+
+ #[test]
+ // We allow replace_consts here because
+ // std::u64::max_value() isn't available
+ // in all of our targets
+ fn overflow_interval() {
+ init();
+ let interval = Interval::from(Duration::from_micros(u64::max_value()));
+ let res: Res<PRTime> = interval.try_into();
+ assert!(res.is_err());
+ }
+}