diff options
Diffstat (limited to 'third_party/rust/neqo-transport/src/crypto.rs')
-rw-r--r-- | third_party/rust/neqo-transport/src/crypto.rs | 1498 |
1 files changed, 1498 insertions, 0 deletions
diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs new file mode 100644 index 0000000000..84a1954b54 --- /dev/null +++ b/third_party/rust/neqo-transport/src/crypto.rs @@ -0,0 +1,1498 @@ +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::cell::RefCell; +use std::cmp::{max, min}; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::mem; +use std::ops::{Index, IndexMut, Range}; +use std::rc::Rc; +use std::time::Instant; + +use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role}; + +use neqo_crypto::{ + hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Epoch, Error as CryptoError, HandshakeState, + PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, ZeroRttChecker, + TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE, + TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT, + TLS_VERSION_1_3, +}; + +use crate::cid::ConnectionIdRef; +use crate::packet::{PacketBuilder, PacketNumber}; +use crate::recovery::RecoveryToken; +use crate::recv_stream::RxStreamOrderer; +use crate::send_stream::TxBuffer; +use crate::stats::FrameStats; +use crate::tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler}; +use crate::tracking::PacketNumberSpace; +use crate::version::Version; +use crate::{Error, Res}; + +const MAX_AUTH_TAG: usize = 32; +/// The number of invocations remaining on a write cipher before we try +/// to update keys. This has to be much smaller than the number returned +/// by `CryptoDxState::limit` or updates will happen too often. As we don't +/// need to ask permission to update, this can be quite small. +pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100; + +// This is a testing kludge that allows for overwriting the number of +// invocations of the next cipher to operate. With this, it is possible +// to test what happens when the number of invocations reaches 0, or +// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur. +// This is a little crude, but it saves a lot of plumbing. +#[cfg(test)] +thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default()); + +#[derive(Debug)] +pub struct Crypto { + version: Version, + protocols: Vec<String>, + pub(crate) tls: Agent, + pub(crate) streams: CryptoStreams, + pub(crate) states: CryptoStates, +} + +type TpHandler = Rc<RefCell<TransportParametersHandler>>; + +impl Crypto { + pub fn new( + version: Version, + mut agent: Agent, + protocols: Vec<String>, + tphandler: TpHandler, + ) -> Res<Self> { + agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; + agent.set_ciphers(&[ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + ])?; + agent.set_alpn(&protocols)?; + agent.disable_end_of_early_data()?; + // Always enable 0-RTT on the client, but the server needs + // more configuration passed to server_enable_0rtt. + if let Agent::Client(c) = &mut agent { + c.enable_0rtt()?; + } + let extension = match version { + Version::Version2 | Version::Version1 => 0x39, + Version::Draft29 | Version::Draft30 | Version::Draft31 | Version::Draft32 => 0xffa5, + }; + agent.extension_handler(extension, tphandler)?; + Ok(Self { + version, + protocols, + tls: agent, + streams: Default::default(), + states: Default::default(), + }) + } + + /// Get the name of the server. (Only works for the client currently). + pub fn server_name(&self) -> Option<&str> { + if let Agent::Client(c) = &self.tls { + Some(c.server_name()) + } else { + None + } + } + + /// Get the set of enabled protocols. + pub fn protocols(&self) -> &[String] { + &self.protocols + } + + pub fn server_enable_0rtt( + &mut self, + tphandler: TpHandler, + anti_replay: &AntiReplay, + zero_rtt_checker: impl ZeroRttChecker + 'static, + ) -> Res<()> { + if let Agent::Server(s) = &mut self.tls { + Ok(s.enable_0rtt( + anti_replay, + 0xffff_ffff, + TpZeroRttChecker::wrap(tphandler, zero_rtt_checker), + )?) + } else { + panic!("not a server"); + } + } + + pub fn server_enable_ech( + &mut self, + config: u8, + public_name: &str, + sk: &PrivateKey, + pk: &PublicKey, + ) -> Res<()> { + if let Agent::Server(s) = &mut self.tls { + s.enable_ech(config, public_name, sk, pk)?; + Ok(()) + } else { + panic!("not a client"); + } + } + + pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> { + if let Agent::Client(c) = &mut self.tls { + c.enable_ech(ech_config_list)?; + Ok(()) + } else { + panic!("not a client"); + } + } + + /// Get the active ECH configuration, which is empty if ECH is disabled. + pub fn ech_config(&self) -> &[u8] { + self.tls.ech_config() + } + + pub fn handshake( + &mut self, + now: Instant, + space: PacketNumberSpace, + data: Option<&[u8]>, + ) -> Res<&HandshakeState> { + let input = data.map(|d| { + qtrace!("Handshake record received {:0x?} ", d); + let epoch = match space { + PacketNumberSpace::Initial => TLS_EPOCH_INITIAL, + PacketNumberSpace::Handshake => TLS_EPOCH_HANDSHAKE, + // Our epoch progresses forward, but the TLS epoch is fixed to 3. + PacketNumberSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA, + }; + Record { + ct: TLS_CT_HANDSHAKE, + epoch, + data: d.to_vec(), + } + }); + + match self.tls.handshake_raw(now, input) { + Ok(output) => { + self.buffer_records(output)?; + Ok(self.tls.state()) + } + Err(CryptoError::EchRetry(v)) => Err(Error::EchRetry(v)), + Err(e) => { + qinfo!("Handshake failed {:?}", e); + Err(match self.tls.alert() { + Some(a) => Error::CryptoAlert(*a), + _ => Error::CryptoError(e), + }) + } + } + } + + /// Enable 0-RTT and return `true` if it is enabled successfully. + pub fn enable_0rtt(&mut self, version: Version, role: Role) -> Res<bool> { + let info = self.tls.preinfo()?; + // `info.early_data()` returns false for a server, + // so use `early_data_cipher()` to tell if 0-RTT is enabled. + let cipher = info.early_data_cipher(); + if cipher.is_none() { + return Ok(false); + } + let (dir, secret) = match role { + Role::Client => ( + CryptoDxDirection::Write, + self.tls.write_secret(TLS_EPOCH_ZERO_RTT), + ), + Role::Server => ( + CryptoDxDirection::Read, + self.tls.read_secret(TLS_EPOCH_ZERO_RTT), + ), + }; + let secret = secret.ok_or(Error::InternalError(1))?; + self.states + .set_0rtt_keys(version, dir, &secret, cipher.unwrap()); + Ok(true) + } + + /// Lock in a compatible upgrade. + pub fn confirm_version(&mut self, confirmed: Version) { + self.states.confirm_version(self.version, confirmed); + self.version = confirmed; + } + + /// Returns true if new handshake keys were installed. + pub fn install_keys(&mut self, role: Role) -> Res<bool> { + if !self.tls.state().is_final() { + let installed_hs = self.install_handshake_keys()?; + if role == Role::Server { + self.maybe_install_application_write_key(self.version)?; + } + Ok(installed_hs) + } else { + Ok(false) + } + } + + fn install_handshake_keys(&mut self) -> Res<bool> { + qtrace!([self], "Attempt to install handshake keys"); + let write_secret = if let Some(secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) { + secret + } else { + // No keys is fine. + return Ok(false); + }; + let read_secret = self + .tls + .read_secret(TLS_EPOCH_HANDSHAKE) + .ok_or(Error::InternalError(2))?; + let cipher = match self.tls.info() { + None => self.tls.preinfo()?.cipher_suite(), + Some(info) => Some(info.cipher_suite()), + } + .ok_or(Error::InternalError(3))?; + self.states + .set_handshake_keys(self.version, &write_secret, &read_secret, cipher); + qdebug!([self], "Handshake keys installed"); + Ok(true) + } + + fn maybe_install_application_write_key(&mut self, version: Version) -> Res<()> { + qtrace!([self], "Attempt to install application write key"); + if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) { + self.states.set_application_write_key(version, secret)?; + qdebug!([self], "Application write key installed"); + } + Ok(()) + } + + pub fn install_application_keys(&mut self, version: Version, expire_0rtt: Instant) -> Res<()> { + self.maybe_install_application_write_key(version)?; + // The write key might have been installed earlier, but it should + // always be installed now. + debug_assert!(self.states.app_write.is_some()); + let read_secret = self + .tls + .read_secret(TLS_EPOCH_APPLICATION_DATA) + .ok_or(Error::InternalError(4))?; + self.states + .set_application_read_key(version, read_secret, expire_0rtt)?; + qdebug!([self], "application read keys installed"); + Ok(()) + } + + /// Buffer crypto records for sending. + pub fn buffer_records(&mut self, records: RecordList) -> Res<()> { + for r in records { + if r.ct != TLS_CT_HANDSHAKE { + return Err(Error::ProtocolViolation); + } + qtrace!([self], "Adding CRYPTO data {:?}", r); + self.streams.send(PacketNumberSpace::from(r.epoch), &r.data); + } + Ok(()) + } + + pub fn write_frame( + &mut self, + space: PacketNumberSpace, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + self.streams.write_frame(space, builder, tokens, stats) + } + + pub fn acked(&mut self, token: &CryptoRecoveryToken) { + qinfo!( + "Acked crypto frame space={} offset={} length={}", + token.space, + token.offset, + token.length + ); + self.streams.acked(token); + } + + pub fn lost(&mut self, token: &CryptoRecoveryToken) { + qinfo!( + "Lost crypto frame space={} offset={} length={}", + token.space, + token.offset, + token.length + ); + self.streams.lost(token); + } + + /// Mark any outstanding frames in the indicated space as "lost" so + /// that they can be sent again. + pub fn resend_unacked(&mut self, space: PacketNumberSpace) { + self.streams.resend_unacked(space); + } + + /// Discard state for a packet number space and return true + /// if something was discarded. + pub fn discard(&mut self, space: PacketNumberSpace) -> bool { + self.streams.discard(space); + self.states.discard(space) + } + + pub fn create_resumption_token( + &mut self, + new_token: Option<&[u8]>, + tps: &TransportParameters, + version: Version, + rtt: u64, + ) -> Option<ResumptionToken> { + if let Agent::Client(ref mut c) = self.tls { + if let Some(ref t) = c.resumption_token() { + qtrace!("TLS token {}", hex(t.as_ref())); + let mut enc = Encoder::default(); + enc.encode_uint(4, version.wire_version()); + enc.encode_varint(rtt); + enc.encode_vvec_with(|enc_inner| { + tps.encode(enc_inner); + }); + enc.encode_vvec(new_token.unwrap_or(&[])); + enc.encode(t.as_ref()); + qinfo!("resumption token {}", hex_snip_middle(enc.as_ref())); + Some(ResumptionToken::new(enc.into(), t.expiration_time())) + } else { + None + } + } else { + unreachable!("It is a server."); + } + } + + pub fn has_resumption_token(&self) -> bool { + if let Agent::Client(c) = &self.tls { + c.has_resumption_token() + } else { + unreachable!("It is a server."); + } + } +} + +impl ::std::fmt::Display for Crypto { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Crypto") + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CryptoDxDirection { + Read, + Write, +} + +#[derive(Debug)] +pub struct CryptoDxState { + /// The QUIC version. + version: Version, + /// Whether packets protected with this state will be read or written. + direction: CryptoDxDirection, + /// The epoch of this crypto state. This initially tracks TLS epochs + /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application. + /// But we don't need to keep that, and QUIC isn't limited in how + /// many times keys can be updated, so we don't use `u16` for this. + epoch: usize, + aead: Aead, + hpkey: HpKey, + /// This tracks the range of packet numbers that have been seen. This allows + /// for verifying that packet numbers before a key update are strictly lower + /// than packet numbers after a key update. + used_pn: Range<PacketNumber>, + /// This is the minimum packet number that is allowed. + min_pn: PacketNumber, + /// The total number of operations that are remaining before the keys + /// become exhausted and can't be used any more. + invocations: PacketNumber, +} + +impl CryptoDxState { + #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range. + pub fn new( + version: Version, + direction: CryptoDxDirection, + epoch: Epoch, + secret: &SymKey, + cipher: Cipher, + ) -> Self { + qinfo!( + "Making {:?} {} CryptoDxState, v={:?} cipher={}", + direction, + epoch, + version, + cipher, + ); + let hplabel = String::from(version.label_prefix()) + "hp"; + Self { + version, + direction, + epoch: usize::from(epoch), + aead: Aead::new(TLS_VERSION_1_3, cipher, secret, version.label_prefix()).unwrap(), + hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, &hplabel).unwrap(), + used_pn: 0..0, + min_pn: 0, + invocations: Self::limit(direction, cipher), + } + } + + pub fn new_initial( + version: Version, + direction: CryptoDxDirection, + label: &str, + dcid: &[u8], + ) -> Self { + qtrace!("new_initial {:?} {}", version, ConnectionIdRef::from(dcid)); + let salt = version.initial_salt(); + let cipher = TLS_AES_128_GCM_SHA256; + let initial_secret = hkdf::extract( + TLS_VERSION_1_3, + cipher, + Some(hkdf::import_key(TLS_VERSION_1_3, salt).as_ref().unwrap()), + hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(), + ) + .unwrap(); + + let secret = + hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap(); + + Self::new(version, direction, TLS_EPOCH_INITIAL, &secret, cipher) + } + + /// Determine the confidentiality and integrity limits for the cipher. + fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber { + match direction { + // This uses the smaller limits for 2^16 byte packets + // as we don't control incoming packet size. + CryptoDxDirection::Read => match cipher { + TLS_AES_128_GCM_SHA256 => 1 << 52, + TLS_AES_256_GCM_SHA384 => PacketNumber::MAX, + TLS_CHACHA20_POLY1305_SHA256 => 1 << 36, + _ => unreachable!(), + }, + // This uses the larger limits for 2^11 byte packets. + CryptoDxDirection::Write => match cipher { + TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28, + TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX, + _ => unreachable!(), + }, + } + } + + fn invoked(&mut self) -> Res<()> { + #[cfg(test)] + OVERWRITE_INVOCATIONS.with(|v| { + if let Some(i) = v.borrow_mut().take() { + neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i); + self.invocations = i; + } + }); + self.invocations = self + .invocations + .checked_sub(1) + .ok_or(Error::KeysExhausted)?; + Ok(()) + } + + /// Determine whether we should initiate a key update. + pub fn should_update(&self) -> bool { + // There is no point in updating read keys as the limit is global. + debug_assert_eq!(self.direction, CryptoDxDirection::Write); + self.invocations <= UPDATE_WRITE_KEYS_AT + } + + pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self { + let pn = self.next_pn(); + // We count invocations of each write key just for that key, but all + // attempts to invocations to read count toward a single limit. + // This doesn't count use of Handshake keys. + let invocations = if self.direction == CryptoDxDirection::Read { + self.invocations + } else { + Self::limit(CryptoDxDirection::Write, cipher) + }; + Self { + version: self.version, + direction: self.direction, + epoch: self.epoch + 1, + aead: Aead::new( + TLS_VERSION_1_3, + cipher, + next_secret, + self.version.label_prefix(), + ) + .unwrap(), + hpkey: self.hpkey.clone(), + used_pn: pn..pn, + min_pn: pn, + invocations, + } + } + + #[must_use] + pub fn version(&self) -> Version { + self.version + } + + #[must_use] + pub fn key_phase(&self) -> bool { + // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ... + self.epoch & 1 != 1 + } + + /// This is a continuation of a previous, so adjust the range accordingly. + /// Fail if the two ranges overlap. Do nothing if the directions don't match. + pub fn continuation(&mut self, prev: &Self) -> Res<()> { + debug_assert_eq!(self.direction, prev.direction); + let next = prev.next_pn(); + self.min_pn = next; + if self.used_pn.is_empty() { + self.used_pn = next..next; + Ok(()) + } else if prev.used_pn.end > self.used_pn.start { + qdebug!( + [self], + "Found packet with too new packet number {} > {}, compared to {}", + self.used_pn.start, + prev.used_pn.end, + prev, + ); + Err(Error::PacketNumberOverlap) + } else { + self.used_pn.start = next; + Ok(()) + } + } + + /// Mark a packet number as used. If this is too low, reject it. + /// Note that this won't catch a value that is too high if packets protected with + /// old keys are received after a key update. That needs to be caught elsewhere. + pub fn used(&mut self, pn: PacketNumber) -> Res<()> { + if pn < self.min_pn { + qdebug!( + [self], + "Found packet with too old packet number: {} < {}", + pn, + self.min_pn + ); + return Err(Error::PacketNumberOverlap); + } + if self.used_pn.start == self.used_pn.end { + self.used_pn.start = pn; + } + self.used_pn.end = max(pn + 1, self.used_pn.end); + Ok(()) + } + + #[must_use] + pub fn needs_update(&self) -> bool { + // Only initiate a key update if we have processed exactly one packet + // and we are in an epoch greater than 3. + self.used_pn.start + 1 == self.used_pn.end + && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA) + } + + #[must_use] + pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool { + if let Some(la) = largest_acknowledged { + self.used_pn.contains(&la) + } else { + // If we haven't received any acknowledgments, it's OK to update + // the first application data epoch. + self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA) + } + } + + pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> { + let mask = self.hpkey.mask(sample)?; + qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask)); + Ok(mask) + } + + #[must_use] + pub fn next_pn(&self) -> PacketNumber { + self.used_pn.end + } + + pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> { + debug_assert_eq!(self.direction, CryptoDxDirection::Write); + qtrace!( + [self], + "encrypt pn={} hdr={} body={}", + pn, + hex(hdr), + hex(body) + ); + // The numbers in `Self::limit` assume a maximum packet size of 2^11. + if body.len() > 2048 { + debug_assert!(false); + return Err(Error::InternalError(12)); + } + self.invoked()?; + + let size = body.len() + MAX_AUTH_TAG; + let mut out = vec![0; size]; + let res = self.aead.encrypt(pn, hdr, body, &mut out)?; + + qtrace!([self], "encrypt ct={}", hex(res)); + debug_assert_eq!(pn, self.next_pn()); + self.used(pn)?; + Ok(res.to_vec()) + } + + #[must_use] + pub fn expansion(&self) -> usize { + self.aead.expansion() + } + + pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> { + debug_assert_eq!(self.direction, CryptoDxDirection::Read); + qtrace!( + [self], + "decrypt pn={} hdr={} body={}", + pn, + hex(hdr), + hex(body) + ); + self.invoked()?; + let mut out = vec![0; body.len()]; + let res = self.aead.decrypt(pn, hdr, body, &mut out)?; + self.used(pn)?; + Ok(res.to_vec()) + } + + #[cfg(all(test, not(feature = "fuzzing")))] + pub(crate) fn test_default() -> Self { + // This matches the value in packet.rs + const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + Self::new_initial( + Version::default(), + CryptoDxDirection::Write, + "server in", + CLIENT_CID, + ) + } + + /// Get the amount of extra padding packets protected with this profile need. + /// This is the difference between the size of the header protection sample + /// and the AEAD expansion. + pub fn extra_padding(&self) -> usize { + self.hpkey + .sample_size() + .saturating_sub(self.aead.expansion()) + } +} + +impl std::fmt::Display for CryptoDxState { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "epoch {} {:?}", self.epoch, self.direction) + } +} + +#[derive(Debug)] +pub struct CryptoState { + tx: CryptoDxState, + rx: CryptoDxState, +} + +impl Index<CryptoDxDirection> for CryptoState { + type Output = CryptoDxState; + + fn index(&self, dir: CryptoDxDirection) -> &Self::Output { + match dir { + CryptoDxDirection::Read => &self.rx, + CryptoDxDirection::Write => &self.tx, + } + } +} + +impl IndexMut<CryptoDxDirection> for CryptoState { + fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output { + match dir { + CryptoDxDirection::Read => &mut self.rx, + CryptoDxDirection::Write => &mut self.tx, + } + } +} + +/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys. +/// This includes the secret needed to generate the next set of keys. +#[derive(Debug)] +pub(crate) struct CryptoDxAppData { + dx: CryptoDxState, + cipher: Cipher, + // Not the secret used to create `self.dx`, but the one needed for the next iteration. + next_secret: SymKey, +} + +impl CryptoDxAppData { + pub fn new( + version: Version, + dir: CryptoDxDirection, + secret: SymKey, + cipher: Cipher, + ) -> Res<Self> { + Ok(Self { + dx: CryptoDxState::new(version, dir, TLS_EPOCH_APPLICATION_DATA, &secret, cipher), + cipher, + next_secret: Self::update_secret(cipher, &secret)?, + }) + } + + fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> { + let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?; + Ok(next) + } + + pub fn next(&self) -> Res<Self> { + if self.dx.epoch == usize::max_value() { + // Guard against too many key updates. + return Err(Error::KeysExhausted); + } + let next_secret = Self::update_secret(self.cipher, &self.next_secret)?; + Ok(Self { + dx: self.dx.next(&self.next_secret, self.cipher), + cipher: self.cipher, + next_secret, + }) + } + + pub fn epoch(&self) -> usize { + self.dx.epoch + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum CryptoSpace { + Initial, + ZeroRtt, + Handshake, + ApplicationData, +} + +/// All of the keying material needed for a connection. +/// +/// Note that the methods on this struct take a version but those are only ever +/// used for Initial keys; a version has been selected at the time we need to +/// get other keys, so those have fixed versions. +#[derive(Debug, Default)] +pub struct CryptoStates { + initials: HashMap<Version, CryptoState>, + handshake: Option<CryptoState>, + zero_rtt: Option<CryptoDxState>, // One direction only! + cipher: Cipher, + app_write: Option<CryptoDxAppData>, + app_read: Option<CryptoDxAppData>, + app_read_next: Option<CryptoDxAppData>, + // If this is set, then we have noticed a genuine update. + // Once this time passes, we should switch in new keys. + read_update_time: Option<Instant>, +} + +impl CryptoStates { + /// Select a `CryptoDxState` and `CryptoSpace` for the given `PacketNumberSpace`. + /// This selects 0-RTT keys for `PacketNumberSpace::ApplicationData` if 1-RTT keys are + /// not yet available. + pub fn select_tx_mut( + &mut self, + version: Version, + space: PacketNumberSpace, + ) -> Option<(CryptoSpace, &mut CryptoDxState)> { + match space { + PacketNumberSpace::Initial => self + .tx_mut(version, CryptoSpace::Initial) + .map(|dx| (CryptoSpace::Initial, dx)), + PacketNumberSpace::Handshake => self + .tx_mut(version, CryptoSpace::Handshake) + .map(|dx| (CryptoSpace::Handshake, dx)), + PacketNumberSpace::ApplicationData => { + if let Some(app) = self.app_write.as_mut() { + Some((CryptoSpace::ApplicationData, &mut app.dx)) + } else { + self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx)) + } + } + } + } + + pub fn tx_mut<'a>( + &'a mut self, + version: Version, + cspace: CryptoSpace, + ) -> Option<&'a mut CryptoDxState> { + let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx); + match cspace { + CryptoSpace::Initial => tx(self.initials.get_mut(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Write), + CryptoSpace::Handshake => tx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx), + } + } + + pub fn tx<'a>(&'a self, version: Version, cspace: CryptoSpace) -> Option<&'a CryptoDxState> { + let tx = |k: Option<&'a CryptoState>| k.map(|dx| &dx.tx); + match cspace { + CryptoSpace::Initial => tx(self.initials.get(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_ref() + .filter(|z| z.direction == CryptoDxDirection::Write), + CryptoSpace::Handshake => tx(self.handshake.as_ref()), + CryptoSpace::ApplicationData => self.app_write.as_ref().map(|app| &app.dx), + } + } + + pub fn select_tx( + &self, + version: Version, + space: PacketNumberSpace, + ) -> Option<(CryptoSpace, &CryptoDxState)> { + match space { + PacketNumberSpace::Initial => self + .tx(version, CryptoSpace::Initial) + .map(|dx| (CryptoSpace::Initial, dx)), + PacketNumberSpace::Handshake => self + .tx(version, CryptoSpace::Handshake) + .map(|dx| (CryptoSpace::Handshake, dx)), + PacketNumberSpace::ApplicationData => { + if let Some(app) = self.app_write.as_ref() { + Some((CryptoSpace::ApplicationData, &app.dx)) + } else { + self.zero_rtt.as_ref().map(|dx| (CryptoSpace::ZeroRtt, dx)) + } + } + } + } + + pub fn rx_hp(&mut self, version: Version, cspace: CryptoSpace) -> Option<&mut CryptoDxState> { + if let CryptoSpace::ApplicationData = cspace { + self.app_read.as_mut().map(|ar| &mut ar.dx) + } else { + self.rx(version, cspace, false) + } + } + + pub fn rx<'a>( + &'a mut self, + version: Version, + cspace: CryptoSpace, + key_phase: bool, + ) -> Option<&'a mut CryptoDxState> { + let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx); + match cspace { + CryptoSpace::Initial => rx(self.initials.get_mut(&version)), + CryptoSpace::ZeroRtt => self + .zero_rtt + .as_mut() + .filter(|z| z.direction == CryptoDxDirection::Read), + CryptoSpace::Handshake => rx(self.handshake.as_mut()), + CryptoSpace::ApplicationData => { + let f = |a: Option<&'a mut CryptoDxAppData>| { + a.filter(|ar| ar.dx.key_phase() == key_phase) + }; + // XOR to reduce the leakage about which key is chosen. + f(self.app_read.as_mut()) + .xor(f(self.app_read_next.as_mut())) + .map(|ar| &mut ar.dx) + } + } + } + + /// Whether keys for processing packets in the indicated space are pending. + /// This allows the caller to determine whether to save a packet for later + /// when keys are not available. + /// NOTE: 0-RTT keys are not considered here. The expectation is that a + /// server will have to save 0-RTT packets in a different place. Though it + /// is possible to attribute 0-RTT packets to an existing connection if there + /// is a multi-packet Initial, that is an unusual circumstance, so we + /// don't do caching for that in those places that call this function. + pub fn rx_pending(&self, space: CryptoSpace) -> bool { + match space { + CryptoSpace::Initial | CryptoSpace::ZeroRtt => false, + CryptoSpace::Handshake => self.handshake.is_none() && !self.initials.is_empty(), + CryptoSpace::ApplicationData => self.app_read.is_none(), + } + } + + /// Create the initial crypto state. + /// Note that the version here can change and that's OK. + pub fn init<'v, V>(&mut self, versions: V, role: Role, dcid: &[u8]) + where + V: IntoIterator<Item = &'v Version>, + { + const CLIENT_INITIAL_LABEL: &str = "client in"; + const SERVER_INITIAL_LABEL: &str = "server in"; + + let (write, read) = match role { + Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL), + Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL), + }; + + for v in versions { + qinfo!( + [self], + "Creating initial cipher state v={:?}, role={:?} dcid={}", + v, + role, + hex(dcid) + ); + + let mut initial = CryptoState { + tx: CryptoDxState::new_initial(*v, CryptoDxDirection::Write, write, dcid), + rx: CryptoDxState::new_initial(*v, CryptoDxDirection::Read, read, dcid), + }; + if let Some(prev) = self.initials.get(v) { + qinfo!( + [self], + "Continue packet numbers for initial after retry (write is {:?})", + prev.rx.used_pn, + ); + initial.tx.continuation(&prev.tx).unwrap(); + } + self.initials.insert(*v, initial); + } + } + + /// At a server, we can be more targeted in initializing. + /// Initialize on demand: either to decrypt Initial packets that we receive + /// or after a version has been selected. + /// This is maybe slightly inefficient in the first case, because we might + /// not need the send keys if the packet is subsequently discarded, but + /// the overall effort is small enough to write off. + pub fn init_server(&mut self, version: Version, dcid: &[u8]) { + if !self.initials.contains_key(&version) { + self.init(&[version], Role::Server, dcid); + } + } + + pub fn confirm_version(&mut self, orig: Version, confirmed: Version) { + if orig != confirmed { + // This part where the old data is removed and then re-added is to + // appease the borrow checker. + // Note that on the server, we might not have initials for |orig| if it + // was configured for |orig| and only |confirmed| Initial packets arrived. + if let Some(prev) = self.initials.remove(&orig) { + let next = self.initials.get_mut(&confirmed).unwrap(); + next.tx.continuation(&prev.tx).unwrap(); + self.initials.insert(orig, prev); + } + } + } + + pub fn set_0rtt_keys( + &mut self, + version: Version, + dir: CryptoDxDirection, + secret: &SymKey, + cipher: Cipher, + ) { + qtrace!([self], "install 0-RTT keys"); + self.zero_rtt = Some(CryptoDxState::new( + version, + dir, + TLS_EPOCH_ZERO_RTT, + secret, + cipher, + )); + } + + /// Discard keys and return true if that happened. + pub fn discard(&mut self, space: PacketNumberSpace) -> bool { + match space { + PacketNumberSpace::Initial => { + let empty = self.initials.is_empty(); + self.initials.clear(); + !empty + } + PacketNumberSpace::Handshake => self.handshake.take().is_some(), + PacketNumberSpace::ApplicationData => panic!("Can't drop application data keys"), + } + } + + pub fn discard_0rtt_keys(&mut self) { + qtrace!([self], "discard 0-RTT keys"); + assert!( + self.app_read.is_none(), + "Can't discard 0-RTT after setting application keys" + ); + self.zero_rtt = None; + } + + pub fn set_handshake_keys( + &mut self, + version: Version, + write_secret: &SymKey, + read_secret: &SymKey, + cipher: Cipher, + ) { + self.cipher = cipher; + self.handshake = Some(CryptoState { + tx: CryptoDxState::new( + version, + CryptoDxDirection::Write, + TLS_EPOCH_HANDSHAKE, + write_secret, + cipher, + ), + rx: CryptoDxState::new( + version, + CryptoDxDirection::Read, + TLS_EPOCH_HANDSHAKE, + read_secret, + cipher, + ), + }); + } + + pub fn set_application_write_key(&mut self, version: Version, secret: SymKey) -> Res<()> { + debug_assert!(self.app_write.is_none()); + debug_assert_ne!(self.cipher, 0); + let mut app = CryptoDxAppData::new(version, CryptoDxDirection::Write, secret, self.cipher)?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Write { + app.dx.continuation(z)?; + } + } + self.zero_rtt = None; + self.app_write = Some(app); + Ok(()) + } + + pub fn set_application_read_key( + &mut self, + version: Version, + secret: SymKey, + expire_0rtt: Instant, + ) -> Res<()> { + debug_assert!(self.app_write.is_some(), "should have write keys installed"); + debug_assert!(self.app_read.is_none()); + let mut app = CryptoDxAppData::new(version, CryptoDxDirection::Read, secret, self.cipher)?; + if let Some(z) = &self.zero_rtt { + if z.direction == CryptoDxDirection::Read { + app.dx.continuation(z)?; + } + self.read_update_time = Some(expire_0rtt); + } + self.app_read_next = Some(app.next()?); + self.app_read = Some(app); + Ok(()) + } + + /// Update the write keys. + pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> { + // Only update if we are able to. We can only do this if we have + // received an acknowledgement for a packet in the current phase. + // Also, skip this if we are waiting for read keys on the existing + // key update to be rolled over. + let write = &self.app_write.as_ref().unwrap().dx; + if write.can_update(largest_acknowledged) && self.read_update_time.is_none() { + // This call additionally checks that we don't advance to the next + // epoch while a key update is in progress. + if self.maybe_update_write()? { + Ok(()) + } else { + qdebug!([self], "Write keys already updated"); + Err(Error::KeyUpdateBlocked) + } + } else { + qdebug!([self], "Waiting for ACK or blocked on read key timer"); + Err(Error::KeyUpdateBlocked) + } + } + + /// Try to update, and return true if it happened. + fn maybe_update_write(&mut self) -> Res<bool> { + // Update write keys. But only do so if the write keys are not already + // ahead of the read keys. If we initiated the key update, the write keys + // will already be ahead. + debug_assert!(self.read_update_time.is_none()); + let write = &self.app_write.as_ref().unwrap(); + let read = &self.app_read.as_ref().unwrap(); + if write.epoch() == read.epoch() { + qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1); + self.app_write = Some(write.next()?); + Ok(true) + } else { + Ok(false) + } + } + + /// Check whether write keys are close to running out of invocations. + /// If that is close, update them if possible. Failing to update at + /// this stage is cause for a fatal error. + pub fn auto_update(&mut self) -> Res<()> { + if let Some(app_write) = self.app_write.as_ref() { + if app_write.dx.should_update() { + qinfo!([self], "Initiating automatic key update"); + if !self.maybe_update_write()? { + return Err(Error::KeysExhausted); + } + } + } + Ok(()) + } + + fn has_0rtt_read(&self) -> bool { + self.zero_rtt + .as_ref() + .filter(|z| z.direction == CryptoDxDirection::Read) + .is_some() + } + + /// Prepare to update read keys. This doesn't happen immediately as + /// we want to ensure that we can continue to receive any delayed + /// packets that use the old keys. So we just set a timer. + pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> { + qtrace!([self], "Key update received"); + // If we received a key update, then we assume that the peer has + // acknowledged a packet we sent in this epoch. It's OK to do that + // because they aren't allowed to update without first having received + // something from us. If the ACK isn't in the packet that triggered this + // key update, it must be in some other packet they have sent. + let _ = self.maybe_update_write()?; + + // We shouldn't have 0-RTT keys at this point, but if we do, dump them. + debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read()); + if self.has_0rtt_read() { + self.zero_rtt = None; + } + self.read_update_time = Some(expiration); + Ok(()) + } + + #[must_use] + pub fn update_time(&self) -> Option<Instant> { + self.read_update_time + } + + /// Check if time has passed for updating key update parameters. + /// If it has, then swap keys over and allow more key updates to be initiated. + /// This is also used to discard 0-RTT read keys at the server in the same way. + pub fn check_key_update(&mut self, now: Instant) -> Res<()> { + if let Some(expiry) = self.read_update_time { + // If enough time has passed, then install new keys and clear the timer. + if now >= expiry { + if self.has_0rtt_read() { + qtrace!([self], "Discarding 0-RTT keys"); + self.zero_rtt = None; + } else { + qtrace!([self], "Rotating read keys"); + mem::swap(&mut self.app_read, &mut self.app_read_next); + self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?); + } + self.read_update_time = None; + } + } + Ok(()) + } + + /// Get the current/highest epoch. This returns (write, read) epochs. + #[cfg(test)] + pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) { + let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch); + (to_epoch(&self.app_write), to_epoch(&self.app_read)) + } + + /// While we are awaiting the completion of a key update, we might receive + /// valid packets that are protected with old keys. We need to ensure that + /// these don't carry packet numbers higher than those in packets protected + /// with the newer keys. To ensure that, this is called after every decryption. + pub fn check_pn_overlap(&mut self) -> Res<()> { + // We only need to do the check while we are waiting for read keys to be updated. + if self.read_update_time.is_some() { + qtrace!([self], "Checking for PN overlap"); + let next_dx = &mut self.app_read_next.as_mut().unwrap().dx; + next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?; + } + Ok(()) + } + + /// Make some state for removing protection in tests. + #[cfg(not(feature = "fuzzing"))] + #[cfg(test)] + pub(crate) fn test_default() -> Self { + let read = |epoch| { + let mut dx = CryptoDxState::test_default(); + dx.direction = CryptoDxDirection::Read; + dx.epoch = epoch; + dx + }; + let app_read = |epoch| CryptoDxAppData { + dx: read(epoch), + cipher: TLS_AES_128_GCM_SHA256, + next_secret: hkdf::import_key(TLS_VERSION_1_3, &[0xaa; 32]).unwrap(), + }; + let mut initials = HashMap::new(); + initials.insert( + Version::Version1, + CryptoState { + tx: CryptoDxState::test_default(), + rx: read(0), + }, + ); + Self { + initials, + handshake: None, + zero_rtt: None, + cipher: TLS_AES_128_GCM_SHA256, + // This isn't used, but the epoch is read to check for a key update. + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + } + } + + #[cfg(all(not(feature = "fuzzing"), test))] + pub(crate) fn test_chacha() -> Self { + const SECRET: &[u8] = &[ + 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad, + 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3, + 0x0f, 0x21, 0x63, 0x2b, + ]; + let secret = hkdf::import_key(TLS_VERSION_1_3, SECRET).unwrap(); + let app_read = |epoch| CryptoDxAppData { + dx: CryptoDxState { + version: Version::Version1, + direction: CryptoDxDirection::Read, + epoch, + aead: Aead::new( + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic ", // This is a v1 test so hard-code the label. + ) + .unwrap(), + hpkey: HpKey::extract( + TLS_VERSION_1_3, + TLS_CHACHA20_POLY1305_SHA256, + &secret, + "quic hp", + ) + .unwrap(), + used_pn: 0..645_971_972, + min_pn: 0, + invocations: 10, + }, + cipher: TLS_CHACHA20_POLY1305_SHA256, + next_secret: secret.clone(), + }; + Self { + initials: HashMap::new(), + handshake: None, + zero_rtt: None, + cipher: TLS_CHACHA20_POLY1305_SHA256, + app_write: Some(app_read(3)), + app_read: Some(app_read(3)), + app_read_next: Some(app_read(4)), + read_update_time: None, + } + } +} + +impl std::fmt::Display for CryptoStates { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CryptoStates") + } +} + +#[derive(Debug, Default)] +pub struct CryptoStream { + tx: TxBuffer, + rx: RxStreamOrderer, +} + +#[derive(Debug)] +#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408 +pub enum CryptoStreams { + Initial { + initial: CryptoStream, + handshake: CryptoStream, + application: CryptoStream, + }, + Handshake { + handshake: CryptoStream, + application: CryptoStream, + }, + ApplicationData { + application: CryptoStream, + }, +} + +impl CryptoStreams { + pub fn discard(&mut self, space: PacketNumberSpace) { + match space { + PacketNumberSpace::Initial => { + if let Self::Initial { + handshake, + application, + .. + } = self + { + *self = Self::Handshake { + handshake: mem::take(handshake), + application: mem::take(application), + }; + } + } + PacketNumberSpace::Handshake => { + if let Self::Handshake { application, .. } = self { + *self = Self::ApplicationData { + application: mem::take(application), + }; + } else if matches!(self, Self::Initial { .. }) { + panic!("Discarding handshake before initial discarded"); + } + } + PacketNumberSpace::ApplicationData => { + panic!("Discarding application data crypto streams") + } + } + } + + pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) { + self.get_mut(space).unwrap().tx.send(data); + } + + pub fn inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8]) { + self.get_mut(space).unwrap().rx.inbound_frame(offset, data); + } + + pub fn data_ready(&self, space: PacketNumberSpace) -> bool { + self.get(space).map_or(false, |cs| cs.rx.data_ready()) + } + + pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec<u8>) -> usize { + self.get_mut(space).unwrap().rx.read_to_end(buf) + } + + pub fn acked(&mut self, token: &CryptoRecoveryToken) { + self.get_mut(token.space) + .unwrap() + .tx + .mark_as_acked(token.offset, token.length); + } + + pub fn lost(&mut self, token: &CryptoRecoveryToken) { + // See BZ 1624800, ignore lost packets in spaces we've dropped keys + if let Some(cs) = self.get_mut(token.space) { + cs.tx.mark_as_lost(token.offset, token.length); + } + } + + /// Resend any Initial or Handshake CRYPTO frames that might be outstanding. + /// This can help speed up handshake times. + pub fn resend_unacked(&mut self, space: PacketNumberSpace) { + if space != PacketNumberSpace::ApplicationData { + if let Some(cs) = self.get_mut(space) { + cs.tx.unmark_sent(); + } + } + } + + fn get(&self, space: PacketNumberSpace) -> Option<&CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PacketNumberSpace::Initial => initial, + PacketNumberSpace::Handshake => hs, + PacketNumberSpace::ApplicationData => app, + } + } + + fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut CryptoStream> { + let (initial, hs, app) = match self { + Self::Initial { + initial, + handshake, + application, + } => (Some(initial), Some(handshake), Some(application)), + Self::Handshake { + handshake, + application, + } => (None, Some(handshake), Some(application)), + Self::ApplicationData { application } => (None, None, Some(application)), + }; + match space { + PacketNumberSpace::Initial => initial, + PacketNumberSpace::Handshake => hs, + PacketNumberSpace::ApplicationData => app, + } + } + + pub fn write_frame( + &mut self, + space: PacketNumberSpace, + builder: &mut PacketBuilder, + tokens: &mut Vec<RecoveryToken>, + stats: &mut FrameStats, + ) -> Res<()> { + let cs = self.get_mut(space).unwrap(); + if let Some((offset, data)) = cs.tx.next_bytes() { + let mut header_len = 1 + Encoder::varint_len(offset) + 1; + + // Don't bother if there isn't room for the header and some data. + if builder.remaining() < header_len + 1 { + return Ok(()); + } + // Calculate length of data based on the minimum of: + // - available data + // - remaining space, less the header, which counts only one byte + // for the length at first to avoid underestimating length + let length = min(data.len(), builder.remaining() - header_len); + header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1; + let length = min(data.len(), builder.remaining() - header_len); + + builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO); + builder.encode_varint(offset); + builder.encode_vvec(&data[..length]); + if builder.len() > builder.limit() { + return Err(Error::InternalError(15)); + } + + cs.tx.mark_as_sent(offset, length); + + qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length); + tokens.push(RecoveryToken::Crypto(CryptoRecoveryToken { + space, + offset, + length, + })); + stats.crypto += 1; + } + Ok(()) + } +} + +impl Default for CryptoStreams { + fn default() -> Self { + Self::Initial { + initial: CryptoStream::default(), + handshake: CryptoStream::default(), + application: CryptoStream::default(), + } + } +} + +#[derive(Debug, Clone)] +pub struct CryptoRecoveryToken { + space: PacketNumberSpace, + offset: u64, + length: usize, +} |